//===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.shape_cast' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
| 15 | #include "mlir/Dialect/UB//IR/UBOps.h" |
| 16 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 17 | #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" |
| 18 | #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" |
| 19 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h" |
| 20 | #include "mlir/IR/BuiltinTypes.h" |
| 21 | #include "mlir/IR/Location.h" |
| 22 | #include "mlir/IR/PatternMatch.h" |
| 23 | #include "mlir/IR/TypeUtilities.h" |
| 24 | #include <numeric> |
| 25 | |
| 26 | #define DEBUG_TYPE "vector-shape-cast-lowering" |
| 27 | |
| 28 | using namespace mlir; |

/// Perform the inplace update
///   rhs <- lhs + rhs
///
/// where `rhs` is a number expressed in mixed base `base` with most
/// significant dimensions on the left. For example if `rhs` is {a,b,c} and
/// `base` is {5,3,2} then `rhs` has value a*3*2 + b*2 + c.
///
/// Some examples where `base` is {5,3,2}:
///   rhs = {0,0,0}, lhs = 1  --> rhs = {0,0,1}
///   rhs = {0,0,1}, lhs = 1  --> rhs = {0,1,0}
///   rhs = {0,0,0}, lhs = 25 --> rhs = {4,0,1}
///
/// Invalid:
///   rhs = {0,0,2}, lhs = 1  : rhs not in base {5,3,2}
///
/// Overflow is not detected; the carry out of the most significant dimension
/// is dropped, so the value wraps modulo the product of the bases:
///   rhs = {4,2,1}, lhs = 2  --> rhs = {0,0,1} (29 + 2 = 31 wraps to 1)
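///
/// This function is used below as a row-major "odometer": repeatedly calling
/// inplaceAdd(1, shape, idx) steps `idx` through the multi-dimensional
/// indices of `shape` in order.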
static void inplaceAdd(int64_t lhs, ArrayRef<int64_t> base,
                       MutableArrayRef<int64_t> rhs) {

  // For dimensions in [numIndices - 1, ..., 3, 2, 1, 0]:
  for (int dim : llvm::reverse(llvm::seq<int>(0, rhs.size()))) {
    int64_t dimBase = base[dim];
    assert(rhs[dim] < dimBase && "rhs not in base");

    int64_t incremented = rhs[dim] + lhs;

    // If the incremented value reaches or exceeds the dimension base, we must
    // spill into the next most significant dimension and repeat (we might
    // need to spill to more significant dimensions multiple times).
    lhs = incremented / dimBase;
    rhs[dim] = incremented % dimBase;
    if (lhs == 0)
      break;
  }
}

namespace {

/// shape_cast is converted to a sequence of extract, extract_strided_slice,
/// insert_strided_slice, and insert operations. The running example will be:
///
///   %0 = vector.shape_cast %arg0 :
///                 vector<2x2x3x4x7x11xi8> to vector<8x6x7x11xi8>
///
/// In this example the source and result shapes share a common suffix of 7x11.
/// This means we can always decompose the shape_cast into extract, insert, and
/// their strided equivalents, on vectors with shape suffix 7x11.
///
/// The greatest common divisor (gcd) of the first dimension preceding the
/// common suffix is gcd(4,6) = 2. The algorithm implemented here will operate
/// on vectors with shapes that are `multiples` of (what we define as) the
/// 'atomic shape', 2x7x11. The atomic shape is `gcd` x `common-suffix`.
///
/// vector<2x2x3x4x7x11xi8> to
///     vector<8x6x7x11xi8>
///              | ||||
///              | ++++------------> common suffix of 7x11
///              +----------------->    gcd(4,6) is 2 | |
///                                                 | | |
///                                                 v v v
///                             atomic shape <----- 2x7x11
///
/// The decomposition implemented in this pattern consists of a sequence of
/// repeated steps:
///
///  (1) Extract vectors from the suffix of the source.
///      In our example this is 2x2x3x4x7x11 -> 4x7x11.
///
///  (2) Do extract_strided_slice down to the atomic shape.
///      In our example this is 4x7x11 -> 2x7x11.
///
///  (3) Do insert_strided_slice to the suffix of the result.
///      In our example this is 2x7x11 -> 6x7x11.
///
///  (4) Insert these vectors into the result vector.
///      In our example this is 6x7x11 -> 8x6x7x11.
///
/// These steps occur with different periods. In this example
///  (1) occurs 12 times,
///  (2) and (3) occur 24 times, and
///  (4) occurs 8 times.
///
/// Two special cases are handled independently in this pattern:
///  (i)  A shape_cast that just does leading 1 insertion/removal.
///  (ii) A shape_cast where the gcd is 1.
///
/// These two cases can have more compact IR generated by not using the
/// generic algorithm described above.
///
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {

  // Case (i) of the description.
  // Assumes source and result shapes are identical up to some leading ones.
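  //
  // For example, `vector<1x1x3x4xi8> to vector<3x4xi8>` becomes a single
  // vector.extract at indices [0, 0], inserted (at an empty position, since
  // resultLeading is 0 here) into a poison vector of the result type.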
  static LogicalResult leadingOnesLowering(vector::ShapeCastOp shapeCast,
                                           PatternRewriter &rewriter) {

    const Location loc = shapeCast.getLoc();
    const VectorType sourceType = shapeCast.getSourceVectorType();
    const VectorType resultType = shapeCast.getResultVectorType();

    const int64_t sourceRank = sourceType.getRank();
    const int64_t resultRank = resultType.getRank();
    const int64_t delta = sourceRank - resultRank;
    const int64_t sourceLeading = delta > 0 ? delta : 0;
    const int64_t resultLeading = delta > 0 ? 0 : -delta;

    const Value source = shapeCast.getSource();
    const Value poison = rewriter.create<ub::PoisonOp>(loc, resultType);
    const Value extracted = rewriter.create<vector::ExtractOp>(
        loc, source, SmallVector<int64_t>(sourceLeading, 0));
    const Value result = rewriter.create<vector::InsertOp>(
        loc, extracted, poison, SmallVector<int64_t>(resultLeading, 0));

    rewriter.replaceOp(shapeCast, result);
    return success();
  }

  // Case (ii) of the description.
  // Assumes a shape_cast where the suffix shape of the source starting at
  // `sourceDim` and the suffix shape of the result starting at `resultDim`
  // are identical.
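  //
  // For example, `vector<2x3x5xi8> to vector<3x2x5xi8>` (with sourceDim = 2
  // and resultDim = 2) lowers to 6 extract/insert pairs, each moving one
  // vector<5xi8> slice.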
  static LogicalResult noStridedSliceLowering(vector::ShapeCastOp shapeCast,
                                              int64_t sourceDim,
                                              int64_t resultDim,
                                              PatternRewriter &rewriter) {

    const Location loc = shapeCast.getLoc();

    const Value source = shapeCast.getSource();
    const ArrayRef<int64_t> sourceShape =
        shapeCast.getSourceVectorType().getShape();

    const VectorType resultType = shapeCast.getResultVectorType();
    const ArrayRef<int64_t> resultShape = resultType.getShape();

    const int64_t nSlices =
        std::accumulate(sourceShape.begin(), sourceShape.begin() + sourceDim,
                        1, std::multiplies<int64_t>());

    SmallVector<int64_t> extractIndex(sourceDim, 0);
    SmallVector<int64_t> insertIndex(resultDim, 0);
    Value result = rewriter.create<ub::PoisonOp>(loc, resultType);

    for (int i = 0; i < nSlices; ++i) {
      Value extracted =
          rewriter.create<vector::ExtractOp>(loc, source, extractIndex);

      result = rewriter.create<vector::InsertOp>(loc, extracted, result,
                                                 insertIndex);

      inplaceAdd(1, sourceShape.take_front(sourceDim), extractIndex);
      inplaceAdd(1, resultShape.take_front(resultDim), insertIndex);
    }
    rewriter.replaceOp(shapeCast, result);
    return success();
  }

public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType sourceType = op.getSourceVectorType();
    VectorType resultType = op.getResultVectorType();

    if (sourceType.isScalable() || resultType.isScalable())
      return rewriter.notifyMatchFailure(
          op,
          "shape_cast where vectors are scalable not handled by this pattern");

    const ArrayRef<int64_t> sourceShape = sourceType.getShape();
    const ArrayRef<int64_t> resultShape = resultType.getShape();
    const int64_t sourceRank = sourceType.getRank();
    const int64_t resultRank = resultType.getRank();
    const int64_t numElms = sourceType.getNumElements();
    const Value source = op.getSource();

    // Find the first dimension, counting from the end, where the source and
    // result dimension sizes differ. Using the running example:
    //
    //   dimensions:  [0 1 2 3 4 5 ]    [0 1 2 3 ]
    //   shapes:      (2,2,3,4,7,11) -> (8,6,7,11)
    //                       ^             ^
    //                       |             |
    //    sourceSuffixStartDim is 3        |
    //                                     |
    //             resultSuffixStartDim is 1
    int64_t sourceSuffixStartDim = sourceRank - 1;
    int64_t resultSuffixStartDim = resultRank - 1;
    while (sourceSuffixStartDim >= 0 && resultSuffixStartDim >= 0 &&
           (sourceType.getDimSize(sourceSuffixStartDim) ==
            resultType.getDimSize(resultSuffixStartDim))) {
      --sourceSuffixStartDim;
      --resultSuffixStartDim;
    }

    // This is the case (i) where there are just some leading ones to contend
    // with in the source or result. It can be handled with a single
    // extract/insert pair.
    if (resultSuffixStartDim < 0 || sourceSuffixStartDim < 0)
      return leadingOnesLowering(op, rewriter);

    const int64_t sourceSuffixStartDimSize =
        sourceType.getDimSize(sourceSuffixStartDim);
    const int64_t resultSuffixStartDimSize =
        resultType.getDimSize(resultSuffixStartDim);
    const int64_t greatestCommonDivisor =
        std::gcd(sourceSuffixStartDimSize, resultSuffixStartDimSize);
    const int64_t stridedSliceRank = sourceRank - sourceSuffixStartDim;
    const size_t extractPeriod =
        sourceSuffixStartDimSize / greatestCommonDivisor;
    const size_t insertPeriod =
        resultSuffixStartDimSize / greatestCommonDivisor;

    SmallVector<int64_t> atomicShape(sourceShape.begin() + sourceSuffixStartDim,
                                     sourceShape.end());
    atomicShape[0] = greatestCommonDivisor;

    const int64_t numAtomicElms = std::accumulate(
        atomicShape.begin(), atomicShape.end(), 1, std::multiplies<int64_t>());
    const size_t nAtomicSlices = numElms / numAtomicElms;
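
    // With the running example from the class comment: extractPeriod is
    // 4/2 = 2, insertPeriod is 6/2 = 3, and nAtomicSlices is 3696/154 = 24,
    // consistent with steps (1), (2)+(3), and (4) occurring 12, 24, and 8
    // times respectively.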

    // This is the case (ii) where the gcd, and hence the strided slice
    // dimension size, is 1. More compact IR is generated in this case if we
    // just extract and insert the elements directly. In other words, we don't
    // use extract_strided_slice and insert_strided_slice.
    if (greatestCommonDivisor == 1)
      return noStridedSliceLowering(op, sourceSuffixStartDim + 1,
                                    resultSuffixStartDim + 1, rewriter);

    // The insert_strided_slice result's type.
    const ArrayRef<int64_t> insertStridedShape =
        resultShape.drop_front(resultSuffixStartDim);
    const VectorType insertStridedType =
        VectorType::get(insertStridedShape, resultType.getElementType());

    SmallVector<int64_t> extractIndex(sourceSuffixStartDim, 0);
    SmallVector<int64_t> insertIndex(resultSuffixStartDim, 0);
    SmallVector<int64_t> extractOffsets(stridedSliceRank, 0);
    SmallVector<int64_t> insertOffsets(stridedSliceRank, 0);
    const SmallVector<int64_t> strides(stridedSliceRank, 1);

    Value extracted = {};
    Value extractedStrided = {};
    Value insertedSlice = {};
    Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
    const Value partResult =
        rewriter.create<ub::PoisonOp>(loc, insertStridedType);
    for (size_t i = 0; i < nAtomicSlices; ++i) {

      const size_t extractStridedPhase = i % extractPeriod;
      const size_t insertStridedPhase = i % insertPeriod;

      // vector.extract
      if (extractStridedPhase == 0) {
        extracted =
            rewriter.create<vector::ExtractOp>(loc, source, extractIndex);
        inplaceAdd(1, sourceShape.take_front(sourceSuffixStartDim),
                   extractIndex);
      }

      // vector.extract_strided_slice
      extractOffsets[0] = extractStridedPhase * greatestCommonDivisor;
      extractedStrided = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, extracted, extractOffsets, atomicShape, strides);

      // vector.insert_strided_slice
      if (insertStridedPhase == 0) {
        insertedSlice = partResult;
      }
      insertOffsets[0] = insertStridedPhase * greatestCommonDivisor;
      insertedSlice = rewriter.create<vector::InsertStridedSliceOp>(
          loc, extractedStrided, insertedSlice, insertOffsets, strides);

      // vector.insert
      if (insertStridedPhase + 1 == insertPeriod) {
        result = rewriter.create<vector::InsertOp>(loc, insertedSlice, result,
                                                   insertIndex);
        inplaceAdd(1, resultType.getShape().take_front(resultSuffixStartDim),
                   insertIndex);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// A shape_cast lowering for scalable vectors with a single trailing scalable
/// dimension. This is similar to the general shape_cast lowering but makes use
/// of vector.scalable.insert and vector.scalable.extract to move elements a
/// subvector at a time.
///
/// E.g.:
/// ```
/// // Flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Flatten scalable vector
/// %c = ub.poison : vector<[8]xi32>
/// %0 = vector.extract %arg0[0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32>
/// %2 = vector.extract %arg0[1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32>
/// ```
/// or:
/// ```
/// // Un-flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Un-flatten scalable vector
/// %c = ub.poison : vector<2x1x[4]xi32>
/// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32>
/// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32>
/// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// ```
class ScalableShapeCastOpRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    auto srcRank = sourceVectorType.getRank();
    auto resRank = resultVectorType.getRank();

    // This can only lower shape_casts where both the source and result types
    // have a single trailing scalable dimension. This is because there is no
    // legal representation of other scalable types in LLVM (and likely won't
    // be soon). There are also (currently) no operations that can index or
    // extract from >= 2-D scalable vectors or scalable vectors of fixed
    // vectors.
    if (!isTrailingDimScalable(sourceVectorType) ||
        !isTrailingDimScalable(resultVectorType)) {
      return rewriter.notifyMatchFailure(
          op, "trailing dims are not scalable, not handled by this pattern");
    }

    // The sizes of the trailing dimensions of the source and result vectors,
    // the size of the subvector to move, and the number of elements in the
    // vectors. These are "min" sizes, i.e. the sizes when vscale == 1.
    auto minSourceTrailingSize = sourceVectorType.getShape().back();
    auto minResultTrailingSize = resultVectorType.getShape().back();
    auto minExtractionSize =
        std::min(minSourceTrailingSize, minResultTrailingSize);
    int64_t minNumElts = 1;
    for (auto size : sourceVectorType.getShape())
      minNumElts *= size;

    // The subvector type to move from the source to the result. Note that this
    // is a scalable vector. This rewrite will generate code in terms of the
    // "min" size (the vscale == 1 case), that scales to any vscale.
    auto extractionVectorType = VectorType::get(
        {minExtractionSize}, sourceVectorType.getElementType(), {true});

    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
    SmallVector<int64_t> srcIdx(srcRank, 0);
    SmallVector<int64_t> resIdx(resRank, 0);

    // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
    // once D150000 lands.
    Value currentResultScalableVector;
    Value currentSourceScalableVector;
    for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
      // 1. Extract a scalable subvector from the source vector.
      if (!currentSourceScalableVector) {
        if (srcRank != 1) {
          currentSourceScalableVector = rewriter.create<vector::ExtractOp>(
              loc, op.getSource(), llvm::ArrayRef(srcIdx).drop_back());
        } else {
          currentSourceScalableVector = op.getSource();
        }
      }
      Value sourceSubVector = currentSourceScalableVector;
      if (minExtractionSize < minSourceTrailingSize) {
        sourceSubVector = rewriter.create<vector::ScalableExtractOp>(
            loc, extractionVectorType, sourceSubVector, srcIdx.back());
      }

      // 2. Insert the scalable subvector into the result vector.
      if (!currentResultScalableVector) {
        if (minExtractionSize == minResultTrailingSize) {
          currentResultScalableVector = sourceSubVector;
        } else if (resRank != 1) {
          currentResultScalableVector = rewriter.create<vector::ExtractOp>(
              loc, result, llvm::ArrayRef(resIdx).drop_back());
        } else {
          currentResultScalableVector = result;
        }
      }
      if (minExtractionSize < minResultTrailingSize) {
        currentResultScalableVector = rewriter.create<vector::ScalableInsertOp>(
            loc, sourceSubVector, currentResultScalableVector, resIdx.back());
      }

      // 3. Update the source and result scalable vectors if needed.
      if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
          currentResultScalableVector != result) {
        // Finished row of result. Insert complete scalable vector into result
        // (n-D) vector.
        result = rewriter.create<vector::InsertOp>(
            loc, currentResultScalableVector, result,
            llvm::ArrayRef(resIdx).drop_back());
        currentResultScalableVector = {};
      }
      if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
        // Finished row of source.
        currentSourceScalableVector = {};
      }

      // 4. Increment the insert/extract indices, stepping by minExtractionSize
      // for the trailing dimensions.
      inplaceAdd(minExtractionSize, sourceVectorType.getShape(), srcIdx);
      inplaceAdd(minExtractionSize, resultVectorType.getShape(), resIdx);
    }

    rewriter.replaceOp(op, result);
    return success();
  }

  static bool isTrailingDimScalable(VectorType type) {
    return type.getRank() >= 1 && type.getScalableDims().back() &&
           !llvm::is_contained(type.getScalableDims().drop_back(), true);
  }
};

} // namespace
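
// Typical usage (a sketch; `funcOp` is a hypothetical enclosing op, and the
// greedy-driver entry point varies across MLIR versions):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   vector::populateVectorShapeCastLoweringPatterns(patterns);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     return signalPassFailure();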
void mlir::vector::populateVectorShapeCastLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ShapeCastOpRewritePattern, ScalableShapeCastOpRewritePattern>(
      patterns.getContext(), benefit);
}