| 1 | //===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements patterns to do vector unrolling and vector distribution. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 14 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
| 15 | #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" |
| 16 | #include "mlir/Interfaces/VectorInterfaces.h" |
| 17 | #include "llvm/ADT/MapVector.h" |
| 18 | #include "llvm/ADT/STLExtras.h" |
| 19 | #include "llvm/Support/Debug.h" |
| 20 | #include "llvm/Support/InterleavedRange.h" |
| 21 | #include <optional> |
| 22 | |
| 23 | #define DEBUG_TYPE "vector-unroll" |
| 24 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
| 25 | #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") |
| 26 | |
| 27 | using namespace mlir; |
| 28 | using namespace mlir::vector; |
| 29 | |
| 30 | /// Compute the indices of the slice `index` for a transfer op. |
| 31 | static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets, |
| 32 | ArrayRef<Value> indices, |
| 33 | AffineMap permutationMap, |
| 34 | Location loc, |
| 35 | OpBuilder &builder) { |
| 36 | MLIRContext *ctx = builder.getContext(); |
| 37 | auto isBroadcast = [](AffineExpr expr) { |
| 38 | if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: expr)) |
| 39 | return constExpr.getValue() == 0; |
| 40 | return false; |
| 41 | }; |
| 42 | // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'. |
| 43 | SmallVector<Value> slicedIndices(indices); |
| 44 | for (const auto &dim : llvm::enumerate(First: permutationMap.getResults())) { |
| 45 | if (isBroadcast(dim.value())) |
| 46 | continue; |
| 47 | unsigned pos = cast<AffineDimExpr>(Val: dim.value()).getPosition(); |
| 48 | auto expr = getAffineDimExpr(position: 0, context: builder.getContext()) + |
| 49 | getAffineConstantExpr(constant: elementOffsets[dim.index()], context: ctx); |
| 50 | auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, result: expr); |
| 51 | slicedIndices[pos] = |
| 52 | builder.create<affine::AffineApplyOp>(loc, map, indices[pos]); |
| 53 | } |
| 54 | return slicedIndices; |
| 55 | } |
| 56 | |
| 57 | // Clones `op` into a new operations that takes `operands` and returns |
| 58 | // `resultTypes`. |
| 59 | static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, |
| 60 | Operation *op, |
| 61 | ArrayRef<Value> operands, |
| 62 | ArrayRef<Type> resultTypes) { |
| 63 | return builder.create(loc, op->getName().getIdentifier(), operands, |
| 64 | resultTypes, op->getAttrs()); |
| 65 | } |
| 66 | |
| 67 | /// Return the target shape for unrolling for the given `op`. Return |
| 68 | /// std::nullopt if the op shouldn't be or cannot be unrolled. |
| 69 | static std::optional<SmallVector<int64_t>> |
| 70 | getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) { |
| 71 | LDBG("" ); |
| 72 | LDBG("Get unroll shape for op " << op->getName().getStringRef()); |
| 73 | if (options.filterConstraint && failed(options.filterConstraint(op))) { |
| 74 | LDBG("--no filter constraint -> BAIL" ); |
| 75 | return std::nullopt; |
| 76 | } |
| 77 | assert(options.nativeShape && |
| 78 | "vector unrolling expects the native shape or native" |
| 79 | "shape call back function to be set" ); |
| 80 | auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op); |
| 81 | if (!unrollableVectorOp) { |
| 82 | LDBG("--not an unrollable op -> BAIL" ); |
| 83 | return std::nullopt; |
| 84 | } |
| 85 | auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll(); |
| 86 | if (!maybeUnrollShape) { |
| 87 | LDBG("--could not get shape of op " << *op << " -> BAIL" ); |
| 88 | return std::nullopt; |
| 89 | } |
| 90 | LDBG("--vector op shape: " << llvm::interleaved(*maybeUnrollShape)); |
| 91 | |
| 92 | std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op); |
| 93 | if (!targetShape) { |
| 94 | LDBG("--no unrolling target shape defined " << *op << "-> SKIP" ); |
| 95 | return std::nullopt; |
| 96 | } |
| 97 | LDBG("--target shape: " << llvm::interleaved(*targetShape)); |
| 98 | |
| 99 | auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape); |
| 100 | if (!maybeShapeRatio) { |
| 101 | LDBG("--could not compute integral shape ratio -> BAIL" ); |
| 102 | return std::nullopt; |
| 103 | } |
| 104 | if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) { |
| 105 | LDBG("--no unrolling needed -> SKIP" ); |
| 106 | return std::nullopt; |
| 107 | } |
| 108 | LDBG("--found an integral shape ratio to unroll to -> SUCCESS" ); |
| 109 | return targetShape; |
| 110 | } |
| 111 | |
| 112 | static SmallVector<int64_t> |
| 113 | getUnrollOrder(unsigned numLoops, Operation *op, |
| 114 | const vector::UnrollVectorOptions &options) { |
| 115 | SmallVector<int64_t> loopOrder = |
| 116 | llvm::to_vector(Range: llvm::seq<int64_t>(Begin: 0, End: static_cast<int64_t>(numLoops))); |
| 117 | if (options.traversalOrderCallback != nullptr) { |
| 118 | std::optional<SmallVector<int64_t>> order = |
| 119 | options.traversalOrderCallback(op); |
| 120 | if (order) { |
| 121 | loopOrder = std::move(*order); |
| 122 | } |
| 123 | } |
| 124 | return loopOrder; |
| 125 | } |
| 126 | |
| 127 | namespace { |
| 128 | |
| 129 | struct UnrollTransferReadPattern |
| 130 | : public OpRewritePattern<vector::TransferReadOp> { |
| 131 | UnrollTransferReadPattern(MLIRContext *context, |
| 132 | const vector::UnrollVectorOptions &options, |
| 133 | PatternBenefit benefit = 1) |
| 134 | : OpRewritePattern<vector::TransferReadOp>(context, benefit), |
| 135 | options(options) {} |
| 136 | |
| 137 | LogicalResult matchAndRewrite(vector::TransferReadOp readOp, |
| 138 | PatternRewriter &rewriter) const override { |
| 139 | // TODO: support 0-d corner case. |
| 140 | if (readOp.getTransferRank() == 0) |
| 141 | return failure(); |
| 142 | if (readOp.getMask()) |
| 143 | return failure(); |
| 144 | auto targetShape = getTargetShape(options, readOp); |
| 145 | if (!targetShape) |
| 146 | return failure(); |
| 147 | auto sourceVectorType = readOp.getVectorType(); |
| 148 | SmallVector<int64_t> strides(targetShape->size(), 1); |
| 149 | Location loc = readOp.getLoc(); |
| 150 | ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape(); |
| 151 | |
| 152 | // Prepare the result vector; |
| 153 | Value result = rewriter.create<arith::ConstantOp>( |
| 154 | loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType)); |
| 155 | auto targetType = |
| 156 | VectorType::get(*targetShape, sourceVectorType.getElementType()); |
| 157 | SmallVector<Value> originalIndices(readOp.getIndices().begin(), |
| 158 | readOp.getIndices().end()); |
| 159 | SmallVector<int64_t> loopOrder = |
| 160 | getUnrollOrder(originalSize.size(), readOp, options); |
| 161 | for (SmallVector<int64_t> elementOffsets : |
| 162 | StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { |
| 163 | SmallVector<Value> indices = |
| 164 | sliceTransferIndices(elementOffsets, originalIndices, |
| 165 | readOp.getPermutationMap(), loc, rewriter); |
| 166 | auto slicedRead = rewriter.create<vector::TransferReadOp>( |
| 167 | loc, targetType, readOp.getBase(), indices, |
| 168 | readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(), |
| 169 | readOp.getInBoundsAttr()); |
| 170 | |
| 171 | result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| 172 | loc, slicedRead, result, elementOffsets, strides); |
| 173 | } |
| 174 | rewriter.replaceOp(readOp, result); |
| 175 | return success(); |
| 176 | } |
| 177 | |
| 178 | private: |
| 179 | vector::UnrollVectorOptions options; |
| 180 | }; |
| 181 | |
| 182 | struct UnrollTransferWritePattern |
| 183 | : public OpRewritePattern<vector::TransferWriteOp> { |
| 184 | UnrollTransferWritePattern(MLIRContext *context, |
| 185 | const vector::UnrollVectorOptions &options, |
| 186 | PatternBenefit benefit = 1) |
| 187 | : OpRewritePattern<vector::TransferWriteOp>(context, benefit), |
| 188 | options(options) {} |
| 189 | |
| 190 | LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, |
| 191 | PatternRewriter &rewriter) const override { |
| 192 | // TODO: support 0-d corner case. |
| 193 | if (writeOp.getTransferRank() == 0) |
| 194 | return failure(); |
| 195 | |
| 196 | if (writeOp.getMask()) |
| 197 | return failure(); |
| 198 | auto targetShape = getTargetShape(options, writeOp); |
| 199 | if (!targetShape) |
| 200 | return failure(); |
| 201 | auto sourceVectorType = writeOp.getVectorType(); |
| 202 | SmallVector<int64_t> strides(targetShape->size(), 1); |
| 203 | Location loc = writeOp.getLoc(); |
| 204 | ArrayRef<int64_t> originalSize = sourceVectorType.getShape(); |
| 205 | SmallVector<Value> originalIndices(writeOp.getIndices().begin(), |
| 206 | writeOp.getIndices().end()); |
| 207 | SmallVector<int64_t> loopOrder = |
| 208 | getUnrollOrder(originalSize.size(), writeOp, options); |
| 209 | Value resultTensor; |
| 210 | for (SmallVector<int64_t> elementOffsets : |
| 211 | StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { |
| 212 | Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| 213 | loc, writeOp.getVector(), elementOffsets, *targetShape, strides); |
| 214 | SmallVector<Value> indices = |
| 215 | sliceTransferIndices(elementOffsets, originalIndices, |
| 216 | writeOp.getPermutationMap(), loc, rewriter); |
| 217 | Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>( |
| 218 | loc, slicedVector, resultTensor ? resultTensor : writeOp.getBase(), |
| 219 | indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr()); |
| 220 | // For the tensor case update the destination for the next transfer write. |
| 221 | if (!slicedWrite->getResults().empty()) |
| 222 | resultTensor = slicedWrite->getResult(0); |
| 223 | } |
| 224 | if (resultTensor) |
| 225 | rewriter.replaceOp(writeOp, resultTensor); |
| 226 | else |
| 227 | rewriter.eraseOp(op: writeOp); |
| 228 | return success(); |
| 229 | } |
| 230 | |
| 231 | private: |
| 232 | vector::UnrollVectorOptions options; |
| 233 | }; |
| 234 | |
| 235 | struct OffsetMapInfo { |
| 236 | static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; } |
| 237 | |
| 238 | static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; } |
| 239 | |
| 240 | static unsigned getHashValue(const SmallVector<int64_t> &v) { |
| 241 | return static_cast<unsigned>(llvm::hash_combine_range(R: v)); |
| 242 | } |
| 243 | |
| 244 | static bool isEqual(const SmallVector<int64_t> &lhs, |
| 245 | const SmallVector<int64_t> &rhs) { |
| 246 | return lhs == rhs; |
| 247 | } |
| 248 | }; |
| 249 | |
/// Unrolls vector.contract into contractions of `targetShape`, caching the
/// partial accumulator for each destination tile across the (unrolled)
/// reduction dimensions, then reassembling the tiles into the full result.
struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(contractOp.getResultType());
    // Iteration-space bounds of the contraction (not a vector shape).
    SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();

    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
    // Maps destination-tile offsets to the most recent partial accumulator
    // produced for that tile. MapVector preserves insertion order so the
    // reassembly loop below is deterministic.
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;

    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIteratorTypes().size(), contractOp, options);

    // `offsets` ranges over the iteration space in `loopOrder`.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract the slice.
      // The indexing map projects the iteration-space tile onto the operand's
      // own dimensions.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
        slicesOperands[index] =
            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
                loc, operand, operandOffets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);

      AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(accOffets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Save the accumulated value until all the loops are unrolled since
      // reduction loops keep updating the accumulator.
      accCache[dstOffets] = newOp->getResult(0);
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
| 343 | |
/// Unrolls vector.multi_reduction into reductions of `targetShape`, caching a
/// partial accumulator per destination tile (the non-reduced dimensions) and
/// reassembling the tiles into the final result vector.
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options,
                              PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    // Reductions producing a scalar cannot be reassembled with the
    // strided-slice ops used below.
    auto resultType = reductionOp->getResult(0).getType();
    if (resultType.isIntOrFloat()) {
      return rewriter.notifyMatchFailure(reductionOp,
                                         "Unrolling scalars is not supported");
    }
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    // Maps destination-tile offsets (over the non-reduced dims only) to the
    // latest partial accumulator computed for that tile. MapVector preserves
    // insertion order so the reassembly below is deterministic.
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    Location loc = reductionOp.getLoc();

    // Stride of the ratios, this gives us the offsets of sliceCount in a basis
    // of multiples of the targetShape.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> operands;
      SmallVector<int64_t> operandStrides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getSource(), offsets, *targetShape,
              operandStrides);
      operands.push_back(slicedOperand);
      // Project the source tile onto the destination space by dropping the
      // reduced dimensions.
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      Value acc;
      SmallVector<int64_t> accStrides(destOffset.size(), 1);
      // If a version of the accumulator has already been computed, use it
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        acc = accIt->second;
      else
        acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
      operands.push_back(acc);
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     operands, targetType);
      Value result = newOp->getResult(0);
      // Remember the partial result: source tiles that reduce into the same
      // destination offsets keep updating this cache entry.
      accCache[destOffset] = result;
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
| 423 | |
| 424 | struct UnrollElementwisePattern : public RewritePattern { |
| 425 | UnrollElementwisePattern(MLIRContext *context, |
| 426 | const vector::UnrollVectorOptions &options, |
| 427 | PatternBenefit benefit = 1) |
| 428 | : RewritePattern(MatchAnyOpTypeTag(), benefit, context), |
| 429 | options(options) {} |
| 430 | |
| 431 | LogicalResult matchAndRewrite(Operation *op, |
| 432 | PatternRewriter &rewriter) const override { |
| 433 | if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) |
| 434 | return failure(); |
| 435 | auto targetShape = getTargetShape(options, op); |
| 436 | if (!targetShape) |
| 437 | return failure(); |
| 438 | auto dstVecType = cast<VectorType>(op->getResult(idx: 0).getType()); |
| 439 | SmallVector<int64_t> originalSize = |
| 440 | *cast<VectorUnrollOpInterface>(op).getShapeForUnroll(); |
| 441 | // Bail-out if rank(source) != rank(target). The main limitation here is the |
| 442 | // fact that `ExtractStridedSlice` requires the rank for the input and |
| 443 | // output to match. If needed, we can relax this later. |
| 444 | if (originalSize.size() != targetShape->size()) |
| 445 | return rewriter.notifyMatchFailure( |
| 446 | arg&: op, msg: "expected input vector rank to match target shape rank" ); |
| 447 | Location loc = op->getLoc(); |
| 448 | // Prepare the result vector. |
| 449 | Value result = rewriter.create<arith::ConstantOp>( |
| 450 | loc, dstVecType, rewriter.getZeroAttr(dstVecType)); |
| 451 | SmallVector<int64_t> strides(targetShape->size(), 1); |
| 452 | VectorType newVecType = |
| 453 | VectorType::get(*targetShape, dstVecType.getElementType()); |
| 454 | |
| 455 | // Create the unrolled computation. |
| 456 | for (SmallVector<int64_t> offsets : |
| 457 | StaticTileOffsetRange(originalSize, *targetShape)) { |
| 458 | SmallVector<Value> extractOperands; |
| 459 | for (OpOperand &operand : op->getOpOperands()) { |
| 460 | auto vecType = dyn_cast<VectorType>(operand.get().getType()); |
| 461 | if (!vecType) { |
| 462 | extractOperands.push_back(operand.get()); |
| 463 | continue; |
| 464 | } |
| 465 | extractOperands.push_back( |
| 466 | rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| 467 | loc, operand.get(), offsets, *targetShape, strides)); |
| 468 | } |
| 469 | Operation *newOp = cloneOpWithOperandsAndTypes( |
| 470 | rewriter, loc, op, extractOperands, newVecType); |
| 471 | result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| 472 | loc, newOp->getResult(0), result, offsets, strides); |
| 473 | } |
| 474 | rewriter.replaceOp(op, newValues: result); |
| 475 | return success(); |
| 476 | } |
| 477 | |
| 478 | private: |
| 479 | vector::UnrollVectorOptions options; |
| 480 | }; |
| 481 | |
| 482 | struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> { |
| 483 | UnrollReductionPattern(MLIRContext *context, |
| 484 | const vector::UnrollVectorOptions &options, |
| 485 | PatternBenefit benefit = 1) |
| 486 | : OpRewritePattern<vector::ReductionOp>(context, benefit), |
| 487 | options(options) {} |
| 488 | |
| 489 | LogicalResult matchAndRewrite(vector::ReductionOp reductionOp, |
| 490 | PatternRewriter &rewriter) const override { |
| 491 | std::optional<SmallVector<int64_t>> targetShape = |
| 492 | getTargetShape(options, reductionOp); |
| 493 | if (!targetShape) |
| 494 | return failure(); |
| 495 | SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll(); |
| 496 | |
| 497 | // Create unrolled vector reduction. |
| 498 | Location loc = reductionOp.getLoc(); |
| 499 | Value accumulator = nullptr; |
| 500 | for (SmallVector<int64_t> offsets : |
| 501 | StaticTileOffsetRange(originalSize, *targetShape)) { |
| 502 | SmallVector<int64_t> strides(offsets.size(), 1); |
| 503 | Value slicedOperand = |
| 504 | rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| 505 | loc, reductionOp.getVector(), offsets, *targetShape, strides); |
| 506 | Operation *newOp = cloneOpWithOperandsAndTypes( |
| 507 | rewriter, loc, reductionOp, slicedOperand, reductionOp.getType()); |
| 508 | Value result = newOp->getResult(0); |
| 509 | |
| 510 | if (!accumulator) { |
| 511 | // This is the first reduction. |
| 512 | accumulator = result; |
| 513 | } else { |
| 514 | // On subsequent reduction, combine with the accumulator. |
| 515 | accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(), |
| 516 | accumulator, result); |
| 517 | } |
| 518 | } |
| 519 | |
| 520 | rewriter.replaceOp(reductionOp, accumulator); |
| 521 | return success(); |
| 522 | } |
| 523 | |
| 524 | private: |
| 525 | const vector::UnrollVectorOptions options; |
| 526 | }; |
| 527 | |
| 528 | struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> { |
| 529 | UnrollTransposePattern(MLIRContext *context, |
| 530 | const vector::UnrollVectorOptions &options, |
| 531 | PatternBenefit benefit = 1) |
| 532 | : OpRewritePattern<vector::TransposeOp>(context, benefit), |
| 533 | options(options) {} |
| 534 | |
| 535 | LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, |
| 536 | PatternRewriter &rewriter) const override { |
| 537 | if (transposeOp.getResultVectorType().getRank() == 0) |
| 538 | return failure(); |
| 539 | auto targetShape = getTargetShape(options, transposeOp); |
| 540 | if (!targetShape) |
| 541 | return failure(); |
| 542 | auto originalVectorType = transposeOp.getResultVectorType(); |
| 543 | SmallVector<int64_t> strides(targetShape->size(), 1); |
| 544 | Location loc = transposeOp.getLoc(); |
| 545 | ArrayRef<int64_t> originalSize = originalVectorType.getShape(); |
| 546 | |
| 547 | // Prepare the result vector; |
| 548 | Value result = rewriter.create<arith::ConstantOp>( |
| 549 | loc, originalVectorType, rewriter.getZeroAttr(originalVectorType)); |
| 550 | ArrayRef<int64_t> permutation = transposeOp.getPermutation(); |
| 551 | |
| 552 | // Unroll the computation. |
| 553 | for (SmallVector<int64_t> elementOffsets : |
| 554 | StaticTileOffsetRange(originalSize, *targetShape)) { |
| 555 | SmallVector<int64_t> permutedOffsets(elementOffsets.size()); |
| 556 | SmallVector<int64_t> permutedShape(elementOffsets.size()); |
| 557 | // Compute the source offsets and shape. |
| 558 | for (auto indices : llvm::enumerate(permutation)) { |
| 559 | permutedOffsets[indices.value()] = elementOffsets[indices.index()]; |
| 560 | permutedShape[indices.value()] = (*targetShape)[indices.index()]; |
| 561 | } |
| 562 | Value slicedOperand = |
| 563 | rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| 564 | loc, transposeOp.getVector(), permutedOffsets, permutedShape, |
| 565 | strides); |
| 566 | Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>( |
| 567 | loc, slicedOperand, permutation); |
| 568 | result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| 569 | loc, transposedSlice, result, elementOffsets, strides); |
| 570 | } |
| 571 | rewriter.replaceOp(transposeOp, result); |
| 572 | return success(); |
| 573 | } |
| 574 | |
| 575 | private: |
| 576 | vector::UnrollVectorOptions options; |
| 577 | }; |
| 578 | |
| 579 | struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> { |
| 580 | UnrollGatherPattern(MLIRContext *context, |
| 581 | const vector::UnrollVectorOptions &options, |
| 582 | PatternBenefit benefit = 1) |
| 583 | : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) { |
| 584 | } |
| 585 | |
| 586 | LogicalResult matchAndRewrite(vector::GatherOp gatherOp, |
| 587 | PatternRewriter &rewriter) const override { |
| 588 | VectorType sourceVectorType = gatherOp.getVectorType(); |
| 589 | if (sourceVectorType.getRank() == 0) |
| 590 | return failure(); |
| 591 | auto targetShape = getTargetShape(options, gatherOp); |
| 592 | if (!targetShape) |
| 593 | return failure(); |
| 594 | SmallVector<int64_t> strides(targetShape->size(), 1); |
| 595 | Location loc = gatherOp.getLoc(); |
| 596 | ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape(); |
| 597 | |
| 598 | // Prepare the result vector; |
| 599 | Value result = rewriter.create<arith::ConstantOp>( |
| 600 | loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType)); |
| 601 | auto targetType = |
| 602 | VectorType::get(*targetShape, sourceVectorType.getElementType()); |
| 603 | |
| 604 | SmallVector<int64_t> loopOrder = |
| 605 | getUnrollOrder(originalSize.size(), gatherOp, options); |
| 606 | for (SmallVector<int64_t> elementOffsets : |
| 607 | StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { |
| 608 | // To get the unrolled gather, extract the same slice based on the |
| 609 | // decomposed shape from each of the index, mask, and pass-through |
| 610 | // vectors. |
| 611 | Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| 612 | loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides); |
| 613 | Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| 614 | loc, gatherOp.getMask(), elementOffsets, *targetShape, strides); |
| 615 | Value passThruSubVec = |
| 616 | rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| 617 | loc, gatherOp.getPassThru(), elementOffsets, *targetShape, |
| 618 | strides); |
| 619 | auto slicedGather = rewriter.create<vector::GatherOp>( |
| 620 | loc, targetType, gatherOp.getBase(), gatherOp.getIndices(), |
| 621 | indexSubVec, maskSubVec, passThruSubVec); |
| 622 | |
| 623 | result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| 624 | loc, slicedGather, result, elementOffsets, strides); |
| 625 | } |
| 626 | rewriter.replaceOp(gatherOp, result); |
| 627 | return success(); |
| 628 | } |
| 629 | |
| 630 | private: |
| 631 | vector::UnrollVectorOptions options; |
| 632 | }; |
| 633 | |
/// Unrolls vector.broadcast into broadcasts of `targetShape`, extracting the
/// matching source slice for each result tile (vector source) or reusing the
/// scalar source directly.
struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
  UnrollBroadcastPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::BroadcastOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, broadcastOp);
    if (!targetShape)
      return failure();

    Location loc = broadcastOp.getLoc();
    // Null when the source is a scalar (scalar-to-vector broadcast).
    VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    VectorType resType = broadcastOp.getResultVectorType();
    VectorType targetType =
        resType.cloneWith(*targetShape, resType.getElementType());
    // Accumulate the unrolled broadcasts into a zero-initialized vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));

    SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
    SmallVector<int64_t> strides(originalShape.size(), 1);

    // The offsets range over the *result* space.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape)) {
      Value newSrc;
      if (!srcType) {
        // Scalar to vector broadcast.
        newSrc = broadcastOp.getSource();
      } else {
        // Vector to vector broadcast. The source aligns with the trailing
        // dims of the result, so take the trailing `rank` entries of the
        // result-tile coordinates for the source slice.
        int64_t rank = srcType.getRank();
        SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
        SmallVector<int64_t> srcShape(targetShape->end() - rank,
                                      targetShape->end());
        SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
        // adjust the offset and shape for src if the corresponding dim is 1:
        // a size-1 source dim is stretched along that dim, so every result
        // tile reads the same single source element there.
        for (int64_t i = 0; i < rank; ++i) {
          if (srcType.getDimSize(i) == 1) {
            srcOffsets[i] = 0;
            srcShape[i] = 1;
          }
        }
        newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
      }

      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp,
                                                     newSrc, targetType);

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }

    rewriter.replaceOp(broadcastOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
| 696 | |
| 697 | } // namespace |
| 698 | |
| 699 | void mlir::vector::populateVectorUnrollPatterns( |
| 700 | RewritePatternSet &patterns, const UnrollVectorOptions &options, |
| 701 | PatternBenefit benefit) { |
| 702 | patterns |
| 703 | .add<UnrollTransferReadPattern, UnrollTransferWritePattern, |
| 704 | UnrollContractionPattern, UnrollElementwisePattern, |
| 705 | UnrollReductionPattern, UnrollMultiReductionPattern, |
| 706 | UnrollTransposePattern, UnrollGatherPattern, UnrollBroadcastPattern>( |
| 707 | arg: patterns.getContext(), args: options, args&: benefit); |
| 708 | } |
| 709 | |