//===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to do vector unrolling and vector distribution.
//
//===----------------------------------------------------------------------===//
| 12 | |
| 13 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 14 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
| 15 | #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" |
| 16 | #include "mlir/Interfaces/VectorInterfaces.h" |
| 17 | #include "llvm/ADT/MapVector.h" |
| 18 | #include "llvm/ADT/STLExtras.h" |
| 19 | #include "llvm/Support/Debug.h" |
| 20 | #include "llvm/Support/InterleavedRange.h" |
| 21 | #include <optional> |
| 22 | |
| 23 | #define DEBUG_TYPE "vector-unroll" |
| 24 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
| 25 | #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") |
| 26 | |
| 27 | using namespace mlir; |
| 28 | using namespace mlir::vector; |
| 29 | |
| 30 | /// Compute the indices of the slice `index` for a transfer op. |
| 31 | static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets, |
| 32 | ArrayRef<Value> indices, |
| 33 | AffineMap permutationMap, |
| 34 | Location loc, |
| 35 | OpBuilder &builder) { |
| 36 | MLIRContext *ctx = builder.getContext(); |
| 37 | auto isBroadcast = [](AffineExpr expr) { |
| 38 | if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: expr)) |
| 39 | return constExpr.getValue() == 0; |
| 40 | return false; |
| 41 | }; |
| 42 | // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'. |
| 43 | SmallVector<Value> slicedIndices(indices); |
| 44 | for (const auto &dim : llvm::enumerate(First: permutationMap.getResults())) { |
| 45 | if (isBroadcast(dim.value())) |
| 46 | continue; |
| 47 | unsigned pos = cast<AffineDimExpr>(Val: dim.value()).getPosition(); |
| 48 | auto expr = getAffineDimExpr(position: 0, context: builder.getContext()) + |
| 49 | getAffineConstantExpr(constant: elementOffsets[dim.index()], context: ctx); |
| 50 | auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, result: expr); |
| 51 | slicedIndices[pos] = |
| 52 | builder.create<affine::AffineApplyOp>(location: loc, args&: map, args: indices[pos]); |
| 53 | } |
| 54 | return slicedIndices; |
| 55 | } |
| 56 | |
| 57 | // Compute the new indices by adding `offsets` to `originalIndices`. |
| 58 | // If m < n (m = offsets.size(), n = originalIndices.size()), |
| 59 | // then only the trailing m values in `originalIndices` are updated. |
| 60 | static SmallVector<Value> sliceLoadStoreIndices(PatternRewriter &rewriter, |
| 61 | Location loc, |
| 62 | OperandRange originalIndices, |
| 63 | ArrayRef<int64_t> offsets) { |
| 64 | assert(offsets.size() <= originalIndices.size() && |
| 65 | "Offsets should not exceed the number of original indices" ); |
| 66 | SmallVector<Value> indices(originalIndices); |
| 67 | |
| 68 | auto start = indices.size() - offsets.size(); |
| 69 | for (auto [i, offset] : llvm::enumerate(First&: offsets)) { |
| 70 | if (offset != 0) { |
| 71 | indices[start + i] = rewriter.create<arith::AddIOp>( |
| 72 | location: loc, args: originalIndices[start + i], |
| 73 | args: rewriter.create<arith::ConstantIndexOp>(location: loc, args: offset)); |
| 74 | } |
| 75 | } |
| 76 | return indices; |
| 77 | } |
| 78 | |
| 79 | // Clones `op` into a new operations that takes `operands` and returns |
| 80 | // `resultTypes`. |
| 81 | static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, |
| 82 | Operation *op, |
| 83 | ArrayRef<Value> operands, |
| 84 | ArrayRef<Type> resultTypes) { |
| 85 | return builder.create(loc, opName: op->getName().getIdentifier(), operands, |
| 86 | types: resultTypes, attributes: op->getAttrs()); |
| 87 | } |
| 88 | |
| 89 | /// Return the target shape for unrolling for the given `op`. Return |
| 90 | /// std::nullopt if the op shouldn't be or cannot be unrolled. |
| 91 | static std::optional<SmallVector<int64_t>> |
| 92 | getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) { |
| 93 | LDBG("" ); |
| 94 | LDBG("Get unroll shape for op " << op->getName().getStringRef()); |
| 95 | if (options.filterConstraint && failed(Result: options.filterConstraint(op))) { |
| 96 | LDBG("--no filter constraint -> BAIL" ); |
| 97 | return std::nullopt; |
| 98 | } |
| 99 | assert(options.nativeShape && |
| 100 | "vector unrolling expects the native shape or native" |
| 101 | "shape call back function to be set" ); |
| 102 | auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(Val: op); |
| 103 | if (!unrollableVectorOp) { |
| 104 | LDBG("--not an unrollable op -> BAIL" ); |
| 105 | return std::nullopt; |
| 106 | } |
| 107 | auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll(); |
| 108 | if (!maybeUnrollShape) { |
| 109 | LDBG("--could not get shape of op " << *op << " -> BAIL" ); |
| 110 | return std::nullopt; |
| 111 | } |
| 112 | LDBG("--vector op shape: " << llvm::interleaved(*maybeUnrollShape)); |
| 113 | |
| 114 | std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op); |
| 115 | if (!targetShape) { |
| 116 | LDBG("--no unrolling target shape defined " << *op << "-> SKIP" ); |
| 117 | return std::nullopt; |
| 118 | } |
| 119 | LDBG("--target shape: " << llvm::interleaved(*targetShape)); |
| 120 | |
| 121 | auto maybeShapeRatio = computeShapeRatio(shape: *maybeUnrollShape, subShape: *targetShape); |
| 122 | if (!maybeShapeRatio) { |
| 123 | LDBG("--could not compute integral shape ratio -> BAIL" ); |
| 124 | return std::nullopt; |
| 125 | } |
| 126 | if (llvm::all_of(Range&: *maybeShapeRatio, P: [](int64_t v) { return v == 1; })) { |
| 127 | LDBG("--no unrolling needed -> SKIP" ); |
| 128 | return std::nullopt; |
| 129 | } |
| 130 | LDBG("--found an integral shape ratio to unroll to -> SUCCESS" ); |
| 131 | return targetShape; |
| 132 | } |
| 133 | |
| 134 | static SmallVector<int64_t> |
| 135 | getUnrollOrder(unsigned numLoops, Operation *op, |
| 136 | const vector::UnrollVectorOptions &options) { |
| 137 | SmallVector<int64_t> loopOrder = |
| 138 | llvm::to_vector(Range: llvm::seq<int64_t>(Begin: 0, End: static_cast<int64_t>(numLoops))); |
| 139 | if (options.traversalOrderCallback != nullptr) { |
| 140 | std::optional<SmallVector<int64_t>> order = |
| 141 | options.traversalOrderCallback(op); |
| 142 | if (order) { |
| 143 | loopOrder = std::move(*order); |
| 144 | } |
| 145 | } |
| 146 | return loopOrder; |
| 147 | } |
| 148 | |
| 149 | namespace { |
| 150 | |
| 151 | struct UnrollTransferReadPattern |
| 152 | : public OpRewritePattern<vector::TransferReadOp> { |
| 153 | UnrollTransferReadPattern(MLIRContext *context, |
| 154 | const vector::UnrollVectorOptions &options, |
| 155 | PatternBenefit benefit = 1) |
| 156 | : OpRewritePattern<vector::TransferReadOp>(context, benefit), |
| 157 | options(options) {} |
| 158 | |
| 159 | LogicalResult matchAndRewrite(vector::TransferReadOp readOp, |
| 160 | PatternRewriter &rewriter) const override { |
| 161 | // TODO: support 0-d corner case. |
| 162 | if (readOp.getTransferRank() == 0) |
| 163 | return failure(); |
| 164 | if (readOp.getMask()) |
| 165 | return failure(); |
| 166 | auto targetShape = getTargetShape(options, op: readOp); |
| 167 | if (!targetShape) |
| 168 | return failure(); |
| 169 | auto sourceVectorType = readOp.getVectorType(); |
| 170 | SmallVector<int64_t> strides(targetShape->size(), 1); |
| 171 | Location loc = readOp.getLoc(); |
| 172 | ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape(); |
| 173 | |
| 174 | // Prepare the result vector; |
| 175 | Value result = rewriter.create<arith::ConstantOp>( |
| 176 | location: loc, args&: sourceVectorType, args: rewriter.getZeroAttr(type: sourceVectorType)); |
| 177 | auto targetType = |
| 178 | VectorType::get(shape: *targetShape, elementType: sourceVectorType.getElementType()); |
| 179 | SmallVector<Value> originalIndices(readOp.getIndices().begin(), |
| 180 | readOp.getIndices().end()); |
| 181 | SmallVector<int64_t> loopOrder = |
| 182 | getUnrollOrder(numLoops: originalSize.size(), op: readOp, options); |
| 183 | for (SmallVector<int64_t> elementOffsets : |
| 184 | StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { |
| 185 | SmallVector<Value> indices = |
| 186 | sliceTransferIndices(elementOffsets, indices: originalIndices, |
| 187 | permutationMap: readOp.getPermutationMap(), loc, builder&: rewriter); |
| 188 | auto slicedRead = rewriter.create<vector::TransferReadOp>( |
| 189 | location: loc, args&: targetType, args: readOp.getBase(), args&: indices, |
| 190 | args: readOp.getPermutationMapAttr(), args: readOp.getPadding(), args: readOp.getMask(), |
| 191 | args: readOp.getInBoundsAttr()); |
| 192 | |
| 193 | result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| 194 | location: loc, args&: slicedRead, args&: result, args&: elementOffsets, args&: strides); |
| 195 | } |
| 196 | rewriter.replaceOp(op: readOp, newValues: result); |
| 197 | return success(); |
| 198 | } |
| 199 | |
| 200 | private: |
| 201 | vector::UnrollVectorOptions options; |
| 202 | }; |
| 203 | |
| 204 | struct UnrollTransferWritePattern |
| 205 | : public OpRewritePattern<vector::TransferWriteOp> { |
| 206 | UnrollTransferWritePattern(MLIRContext *context, |
| 207 | const vector::UnrollVectorOptions &options, |
| 208 | PatternBenefit benefit = 1) |
| 209 | : OpRewritePattern<vector::TransferWriteOp>(context, benefit), |
| 210 | options(options) {} |
| 211 | |
| 212 | LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, |
| 213 | PatternRewriter &rewriter) const override { |
| 214 | // TODO: support 0-d corner case. |
| 215 | if (writeOp.getTransferRank() == 0) |
| 216 | return failure(); |
| 217 | |
| 218 | if (writeOp.getMask()) |
| 219 | return failure(); |
| 220 | auto targetShape = getTargetShape(options, op: writeOp); |
| 221 | if (!targetShape) |
| 222 | return failure(); |
| 223 | auto sourceVectorType = writeOp.getVectorType(); |
| 224 | SmallVector<int64_t> strides(targetShape->size(), 1); |
| 225 | Location loc = writeOp.getLoc(); |
| 226 | ArrayRef<int64_t> originalSize = sourceVectorType.getShape(); |
| 227 | SmallVector<Value> originalIndices(writeOp.getIndices().begin(), |
| 228 | writeOp.getIndices().end()); |
| 229 | SmallVector<int64_t> loopOrder = |
| 230 | getUnrollOrder(numLoops: originalSize.size(), op: writeOp, options); |
| 231 | Value resultTensor; |
| 232 | for (SmallVector<int64_t> elementOffsets : |
| 233 | StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { |
| 234 | Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| 235 | location: loc, args: writeOp.getVector(), args&: elementOffsets, args&: *targetShape, args&: strides); |
| 236 | SmallVector<Value> indices = |
| 237 | sliceTransferIndices(elementOffsets, indices: originalIndices, |
| 238 | permutationMap: writeOp.getPermutationMap(), loc, builder&: rewriter); |
| 239 | Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>( |
| 240 | location: loc, args&: slicedVector, args: resultTensor ? resultTensor : writeOp.getBase(), |
| 241 | args&: indices, args: writeOp.getPermutationMapAttr(), args: writeOp.getInBoundsAttr()); |
| 242 | // For the tensor case update the destination for the next transfer write. |
| 243 | if (!slicedWrite->getResults().empty()) |
| 244 | resultTensor = slicedWrite->getResult(idx: 0); |
| 245 | } |
| 246 | if (resultTensor) |
| 247 | rewriter.replaceOp(op: writeOp, newValues: resultTensor); |
| 248 | else |
| 249 | rewriter.eraseOp(op: writeOp); |
| 250 | return success(); |
| 251 | } |
| 252 | |
| 253 | private: |
| 254 | vector::UnrollVectorOptions options; |
| 255 | }; |
| 256 | |
| 257 | struct OffsetMapInfo { |
| 258 | static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; } |
| 259 | |
| 260 | static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; } |
| 261 | |
| 262 | static unsigned getHashValue(const SmallVector<int64_t> &v) { |
| 263 | return static_cast<unsigned>(llvm::hash_combine_range(R: v)); |
| 264 | } |
| 265 | |
| 266 | static bool isEqual(const SmallVector<int64_t> &lhs, |
| 267 | const SmallVector<int64_t> &rhs) { |
| 268 | return lhs == rhs; |
| 269 | } |
| 270 | }; |
| 271 | |
| 272 | struct UnrollContractionPattern |
| 273 | : public OpRewritePattern<vector::ContractionOp> { |
| 274 | UnrollContractionPattern(MLIRContext *context, |
| 275 | const vector::UnrollVectorOptions &options, |
| 276 | PatternBenefit benefit = 1) |
| 277 | : OpRewritePattern<vector::ContractionOp>(context, benefit), |
| 278 | options(options) {} |
| 279 | |
| 280 | LogicalResult matchAndRewrite(vector::ContractionOp contractOp, |
| 281 | PatternRewriter &rewriter) const override { |
| 282 | auto targetShape = getTargetShape(options, op: contractOp); |
| 283 | if (!targetShape) |
| 284 | return failure(); |
| 285 | auto dstVecType = cast<VectorType>(Val: contractOp.getResultType()); |
| 286 | SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll(); |
| 287 | |
| 288 | Location loc = contractOp.getLoc(); |
| 289 | unsigned accIndex = vector::ContractionOp::getAccOperandIndex(); |
| 290 | AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex]; |
| 291 | llvm::MapVector< |
| 292 | SmallVector<int64_t>, Value, |
| 293 | llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>> |
| 294 | accCache; |
| 295 | |
| 296 | SmallVector<int64_t> loopOrder = getUnrollOrder( |
| 297 | numLoops: contractOp.getIteratorTypes().size(), op: contractOp, options); |
| 298 | |
| 299 | for (SmallVector<int64_t> offsets : |
| 300 | StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { |
| 301 | SmallVector<Value> slicesOperands(contractOp.getNumOperands()); |
| 302 | |
| 303 | // Helper to compute the new shape of each operand and extract the slice. |
| 304 | auto extractOperand = [&](unsigned index, Value operand, |
| 305 | AffineMap permutationMap, |
| 306 | ArrayRef<int64_t> operandOffets) { |
| 307 | SmallVector<int64_t> operandShape = applyPermutationMap( |
| 308 | map: permutationMap, source: ArrayRef<int64_t>(*targetShape)); |
| 309 | SmallVector<int64_t> operandStrides(operandOffets.size(), 1); |
| 310 | slicesOperands[index] = |
| 311 | rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| 312 | location: loc, args&: operand, args&: operandOffets, args&: operandShape, args&: operandStrides); |
| 313 | }; |
| 314 | |
| 315 | // Extract the new lhs operand. |
| 316 | AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0]; |
| 317 | SmallVector<int64_t> lhsOffets = |
| 318 | applyPermutationMap(map: lhsPermutationMap, source: ArrayRef<int64_t>(offsets)); |
| 319 | extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets); |
| 320 | |
| 321 | // Extract the new rhs operand. |
| 322 | AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1]; |
| 323 | SmallVector<int64_t> rhsOffets = |
| 324 | applyPermutationMap(map: rhsPermutationMap, source: ArrayRef<int64_t>(offsets)); |
| 325 | extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets); |
| 326 | |
| 327 | AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2]; |
| 328 | SmallVector<int64_t> accOffets = |
| 329 | applyPermutationMap(map: accPermutationMap, source: ArrayRef<int64_t>(offsets)); |
| 330 | // If a version of the accumulator has already been computed, use it |
| 331 | // otherwise extract the first version from the original operand. |
| 332 | auto *accIt = accCache.find(Key: accOffets); |
| 333 | if (accIt != accCache.end()) |
| 334 | slicesOperands[2] = accIt->second; |
| 335 | else |
| 336 | extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets); |
| 337 | |
| 338 | SmallVector<int64_t> dstShape = |
| 339 | applyPermutationMap(map: dstAffineMap, source: ArrayRef<int64_t>(*targetShape)); |
| 340 | auto targetType = VectorType::get(shape: dstShape, elementType: dstVecType.getElementType()); |
| 341 | Operation *newOp = cloneOpWithOperandsAndTypes( |
| 342 | builder&: rewriter, loc, op: contractOp, operands: slicesOperands, resultTypes: targetType); |
| 343 | |
| 344 | SmallVector<int64_t> dstOffets = |
| 345 | applyPermutationMap(map: dstAffineMap, source: ArrayRef<int64_t>(offsets)); |
| 346 | // Save the accumulated value untill all the loops are unrolled since |
| 347 | // reduction loop keep updating the accumulator. |
| 348 | accCache[dstOffets] = newOp->getResult(idx: 0); |
| 349 | } |
| 350 | // Assemble back the accumulator into a single vector. |
| 351 | Value result = rewriter.create<arith::ConstantOp>( |
| 352 | location: loc, args&: dstVecType, args: rewriter.getZeroAttr(type: dstVecType)); |
| 353 | for (const auto &it : accCache) { |
| 354 | SmallVector<int64_t> dstStrides(it.first.size(), 1); |
| 355 | result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| 356 | location: loc, args: it.second, args&: result, args: it.first, args&: dstStrides); |
| 357 | } |
| 358 | rewriter.replaceOp(op: contractOp, newValues: result); |
| 359 | return success(); |
| 360 | } |
| 361 | |
| 362 | private: |
| 363 | vector::UnrollVectorOptions options; |
| 364 | }; |
| 365 | |
| 366 | struct UnrollMultiReductionPattern |
| 367 | : public OpRewritePattern<vector::MultiDimReductionOp> { |
| 368 | UnrollMultiReductionPattern(MLIRContext *context, |
| 369 | const vector::UnrollVectorOptions &options, |
| 370 | PatternBenefit benefit = 1) |
| 371 | : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit), |
| 372 | options(options) {} |
| 373 | |
| 374 | LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp, |
| 375 | PatternRewriter &rewriter) const override { |
| 376 | auto resultType = reductionOp->getResult(idx: 0).getType(); |
| 377 | if (resultType.isIntOrFloat()) { |
| 378 | return rewriter.notifyMatchFailure(arg&: reductionOp, |
| 379 | msg: "Unrolling scalars is not supported" ); |
| 380 | } |
| 381 | std::optional<SmallVector<int64_t>> targetShape = |
| 382 | getTargetShape(options, op: reductionOp); |
| 383 | if (!targetShape) |
| 384 | return failure(); |
| 385 | SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll(); |
| 386 | llvm::MapVector< |
| 387 | SmallVector<int64_t>, Value, |
| 388 | llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>> |
| 389 | accCache; |
| 390 | Location loc = reductionOp.getLoc(); |
| 391 | |
| 392 | // Stride of the ratios, this gives us the offsets of sliceCount in a basis |
| 393 | // of multiples of the targetShape. |
| 394 | for (SmallVector<int64_t> offsets : |
| 395 | StaticTileOffsetRange(originalSize, *targetShape)) { |
| 396 | SmallVector<Value> operands; |
| 397 | SmallVector<int64_t> operandStrides(offsets.size(), 1); |
| 398 | Value slicedOperand = |
| 399 | rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| 400 | location: loc, args: reductionOp.getSource(), args&: offsets, args&: *targetShape, |
| 401 | args&: operandStrides); |
| 402 | operands.push_back(Elt: slicedOperand); |
| 403 | SmallVector<int64_t> dstShape; |
| 404 | SmallVector<int64_t> destOffset; |
| 405 | for (size_t i : llvm::seq(Begin: size_t(0), End: targetShape->size())) { |
| 406 | if (!reductionOp.isReducedDim(d: i)) { |
| 407 | destOffset.push_back(Elt: offsets[i]); |
| 408 | dstShape.push_back(Elt: (*targetShape)[i]); |
| 409 | } |
| 410 | } |
| 411 | Value acc; |
| 412 | SmallVector<int64_t> accStrides(destOffset.size(), 1); |
| 413 | // If a version of the accumulator has already been computed, use it |
| 414 | // otherwise extract the first version from the original operand. |
| 415 | auto *accIt = accCache.find(Key: destOffset); |
| 416 | if (accIt != accCache.end()) |
| 417 | acc = accIt->second; |
| 418 | else |
| 419 | acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| 420 | location: loc, args: reductionOp.getAcc(), args&: destOffset, args&: dstShape, args&: accStrides); |
| 421 | operands.push_back(Elt: acc); |
| 422 | auto targetType = VectorType::get( |
| 423 | shape: dstShape, elementType: reductionOp.getSourceVectorType().getElementType()); |
| 424 | Operation *newOp = cloneOpWithOperandsAndTypes(builder&: rewriter, loc, op: reductionOp, |
| 425 | operands, resultTypes: targetType); |
| 426 | Value result = newOp->getResult(idx: 0); |
| 427 | accCache[destOffset] = result; |
| 428 | } |
| 429 | // Assemble back the accumulator into a single vector. |
| 430 | Value result = rewriter.create<arith::ConstantOp>( |
| 431 | location: loc, args: reductionOp.getDestType(), |
| 432 | args: rewriter.getZeroAttr(type: reductionOp.getDestType())); |
| 433 | for (const auto &it : accCache) { |
| 434 | SmallVector<int64_t> dstStrides(it.first.size(), 1); |
| 435 | result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| 436 | location: loc, args: it.second, args&: result, args: it.first, args&: dstStrides); |
| 437 | } |
| 438 | rewriter.replaceOp(op: reductionOp, newValues: result); |
| 439 | return success(); |
| 440 | } |
| 441 | |
| 442 | private: |
| 443 | vector::UnrollVectorOptions options; |
| 444 | }; |
| 445 | |
| 446 | struct UnrollElementwisePattern : public RewritePattern { |
| 447 | UnrollElementwisePattern(MLIRContext *context, |
| 448 | const vector::UnrollVectorOptions &options, |
| 449 | PatternBenefit benefit = 1) |
| 450 | : RewritePattern(MatchAnyOpTypeTag(), benefit, context), |
| 451 | options(options) {} |
| 452 | |
| 453 | LogicalResult matchAndRewrite(Operation *op, |
| 454 | PatternRewriter &rewriter) const override { |
| 455 | if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) |
| 456 | return failure(); |
| 457 | auto targetShape = getTargetShape(options, op); |
| 458 | if (!targetShape) |
| 459 | return failure(); |
| 460 | auto dstVecType = cast<VectorType>(Val: op->getResult(idx: 0).getType()); |
| 461 | SmallVector<int64_t> originalSize = |
| 462 | *cast<VectorUnrollOpInterface>(Val: op).getShapeForUnroll(); |
| 463 | // Bail-out if rank(source) != rank(target). The main limitation here is the |
| 464 | // fact that `ExtractStridedSlice` requires the rank for the input and |
| 465 | // output to match. If needed, we can relax this later. |
| 466 | if (originalSize.size() != targetShape->size()) |
| 467 | return rewriter.notifyMatchFailure( |
| 468 | arg&: op, msg: "expected input vector rank to match target shape rank" ); |
| 469 | Location loc = op->getLoc(); |
| 470 | // Prepare the result vector. |
| 471 | Value result = rewriter.create<arith::ConstantOp>( |
| 472 | location: loc, args&: dstVecType, args: rewriter.getZeroAttr(type: dstVecType)); |
| 473 | SmallVector<int64_t> strides(targetShape->size(), 1); |
| 474 | VectorType newVecType = |
| 475 | VectorType::get(shape: *targetShape, elementType: dstVecType.getElementType()); |
| 476 | |
| 477 | // Create the unrolled computation. |
| 478 | for (SmallVector<int64_t> offsets : |
| 479 | StaticTileOffsetRange(originalSize, *targetShape)) { |
| 480 | SmallVector<Value> extractOperands; |
| 481 | for (OpOperand &operand : op->getOpOperands()) { |
| 482 | auto vecType = dyn_cast<VectorType>(Val: operand.get().getType()); |
| 483 | if (!vecType) { |
| 484 | extractOperands.push_back(Elt: operand.get()); |
| 485 | continue; |
| 486 | } |
| 487 | extractOperands.push_back( |
| 488 | Elt: rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| 489 | location: loc, args: operand.get(), args&: offsets, args&: *targetShape, args&: strides)); |
| 490 | } |
| 491 | Operation *newOp = cloneOpWithOperandsAndTypes( |
| 492 | builder&: rewriter, loc, op, operands: extractOperands, resultTypes: newVecType); |
| 493 | result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| 494 | location: loc, args: newOp->getResult(idx: 0), args&: result, args&: offsets, args&: strides); |
| 495 | } |
| 496 | rewriter.replaceOp(op, newValues: result); |
| 497 | return success(); |
| 498 | } |
| 499 | |
| 500 | private: |
| 501 | vector::UnrollVectorOptions options; |
| 502 | }; |
| 503 | |
| 504 | struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> { |
| 505 | UnrollReductionPattern(MLIRContext *context, |
| 506 | const vector::UnrollVectorOptions &options, |
| 507 | PatternBenefit benefit = 1) |
| 508 | : OpRewritePattern<vector::ReductionOp>(context, benefit), |
| 509 | options(options) {} |
| 510 | |
| 511 | LogicalResult matchAndRewrite(vector::ReductionOp reductionOp, |
| 512 | PatternRewriter &rewriter) const override { |
| 513 | std::optional<SmallVector<int64_t>> targetShape = |
| 514 | getTargetShape(options, op: reductionOp); |
| 515 | if (!targetShape) |
| 516 | return failure(); |
| 517 | SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll(); |
| 518 | |
| 519 | // Create unrolled vector reduction. |
| 520 | Location loc = reductionOp.getLoc(); |
| 521 | Value accumulator = nullptr; |
| 522 | for (SmallVector<int64_t> offsets : |
| 523 | StaticTileOffsetRange(originalSize, *targetShape)) { |
| 524 | SmallVector<int64_t> strides(offsets.size(), 1); |
| 525 | Value slicedOperand = |
| 526 | rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| 527 | location: loc, args: reductionOp.getVector(), args&: offsets, args&: *targetShape, args&: strides); |
| 528 | Operation *newOp = cloneOpWithOperandsAndTypes( |
| 529 | builder&: rewriter, loc, op: reductionOp, operands: slicedOperand, resultTypes: reductionOp.getType()); |
| 530 | Value result = newOp->getResult(idx: 0); |
| 531 | |
| 532 | if (!accumulator) { |
| 533 | // This is the first reduction. |
| 534 | accumulator = result; |
| 535 | } else { |
| 536 | // On subsequent reduction, combine with the accumulator. |
| 537 | accumulator = makeArithReduction(b&: rewriter, loc, kind: reductionOp.getKind(), |
| 538 | v1: accumulator, acc: result); |
| 539 | } |
| 540 | } |
| 541 | |
| 542 | rewriter.replaceOp(op: reductionOp, newValues: accumulator); |
| 543 | return success(); |
| 544 | } |
| 545 | |
| 546 | private: |
| 547 | const vector::UnrollVectorOptions options; |
| 548 | }; |
| 549 | |
| 550 | struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> { |
| 551 | UnrollTransposePattern(MLIRContext *context, |
| 552 | const vector::UnrollVectorOptions &options, |
| 553 | PatternBenefit benefit = 1) |
| 554 | : OpRewritePattern<vector::TransposeOp>(context, benefit), |
| 555 | options(options) {} |
| 556 | |
| 557 | LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, |
| 558 | PatternRewriter &rewriter) const override { |
| 559 | if (transposeOp.getResultVectorType().getRank() == 0) |
| 560 | return failure(); |
| 561 | auto targetShape = getTargetShape(options, op: transposeOp); |
| 562 | if (!targetShape) |
| 563 | return failure(); |
| 564 | auto originalVectorType = transposeOp.getResultVectorType(); |
| 565 | SmallVector<int64_t> strides(targetShape->size(), 1); |
| 566 | Location loc = transposeOp.getLoc(); |
| 567 | ArrayRef<int64_t> originalSize = originalVectorType.getShape(); |
| 568 | |
| 569 | // Prepare the result vector; |
| 570 | Value result = rewriter.create<arith::ConstantOp>( |
| 571 | location: loc, args&: originalVectorType, args: rewriter.getZeroAttr(type: originalVectorType)); |
| 572 | ArrayRef<int64_t> permutation = transposeOp.getPermutation(); |
| 573 | |
| 574 | // Unroll the computation. |
| 575 | for (SmallVector<int64_t> elementOffsets : |
| 576 | StaticTileOffsetRange(originalSize, *targetShape)) { |
| 577 | SmallVector<int64_t> permutedOffsets(elementOffsets.size()); |
| 578 | SmallVector<int64_t> permutedShape(elementOffsets.size()); |
| 579 | // Compute the source offsets and shape. |
| 580 | for (auto indices : llvm::enumerate(First&: permutation)) { |
| 581 | permutedOffsets[indices.value()] = elementOffsets[indices.index()]; |
| 582 | permutedShape[indices.value()] = (*targetShape)[indices.index()]; |
| 583 | } |
| 584 | Value slicedOperand = |
| 585 | rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| 586 | location: loc, args: transposeOp.getVector(), args&: permutedOffsets, args&: permutedShape, |
| 587 | args&: strides); |
| 588 | Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>( |
| 589 | location: loc, args&: slicedOperand, args&: permutation); |
| 590 | result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| 591 | location: loc, args&: transposedSlice, args&: result, args&: elementOffsets, args&: strides); |
| 592 | } |
| 593 | rewriter.replaceOp(op: transposeOp, newValues: result); |
| 594 | return success(); |
| 595 | } |
| 596 | |
| 597 | private: |
| 598 | vector::UnrollVectorOptions options; |
| 599 | }; |
| 600 | |
| 601 | struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> { |
| 602 | UnrollGatherPattern(MLIRContext *context, |
| 603 | const vector::UnrollVectorOptions &options, |
| 604 | PatternBenefit benefit = 1) |
| 605 | : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) { |
| 606 | } |
| 607 | |
| 608 | LogicalResult matchAndRewrite(vector::GatherOp gatherOp, |
| 609 | PatternRewriter &rewriter) const override { |
| 610 | VectorType sourceVectorType = gatherOp.getVectorType(); |
| 611 | if (sourceVectorType.getRank() == 0) |
| 612 | return failure(); |
| 613 | auto targetShape = getTargetShape(options, op: gatherOp); |
| 614 | if (!targetShape) |
| 615 | return failure(); |
| 616 | SmallVector<int64_t> strides(targetShape->size(), 1); |
| 617 | Location loc = gatherOp.getLoc(); |
| 618 | ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape(); |
| 619 | |
| 620 | // Prepare the result vector; |
| 621 | Value result = rewriter.create<arith::ConstantOp>( |
| 622 | location: loc, args&: sourceVectorType, args: rewriter.getZeroAttr(type: sourceVectorType)); |
| 623 | auto targetType = |
| 624 | VectorType::get(shape: *targetShape, elementType: sourceVectorType.getElementType()); |
| 625 | |
| 626 | SmallVector<int64_t> loopOrder = |
| 627 | getUnrollOrder(numLoops: originalSize.size(), op: gatherOp, options); |
| 628 | for (SmallVector<int64_t> elementOffsets : |
| 629 | StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { |
| 630 | // To get the unrolled gather, extract the same slice based on the |
| 631 | // decomposed shape from each of the index, mask, and pass-through |
| 632 | // vectors. |
| 633 | Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| 634 | location: loc, args: gatherOp.getIndexVec(), args&: elementOffsets, args&: *targetShape, args&: strides); |
| 635 | Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| 636 | location: loc, args: gatherOp.getMask(), args&: elementOffsets, args&: *targetShape, args&: strides); |
| 637 | Value passThruSubVec = |
| 638 | rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| 639 | location: loc, args: gatherOp.getPassThru(), args&: elementOffsets, args&: *targetShape, |
| 640 | args&: strides); |
| 641 | auto slicedGather = rewriter.create<vector::GatherOp>( |
| 642 | location: loc, args&: targetType, args: gatherOp.getBase(), args: gatherOp.getIndices(), |
| 643 | args&: indexSubVec, args&: maskSubVec, args&: passThruSubVec); |
| 644 | |
| 645 | result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| 646 | location: loc, args&: slicedGather, args&: result, args&: elementOffsets, args&: strides); |
| 647 | } |
| 648 | rewriter.replaceOp(op: gatherOp, newValues: result); |
| 649 | return success(); |
| 650 | } |
| 651 | |
| 652 | private: |
| 653 | vector::UnrollVectorOptions options; |
| 654 | }; |
| 655 | |
| 656 | struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> { |
| 657 | UnrollLoadPattern(MLIRContext *context, |
| 658 | const vector::UnrollVectorOptions &options, |
| 659 | PatternBenefit benefit = 1) |
| 660 | : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {} |
| 661 | |
| 662 | LogicalResult matchAndRewrite(vector::LoadOp loadOp, |
| 663 | PatternRewriter &rewriter) const override { |
| 664 | VectorType vecType = loadOp.getVectorType(); |
| 665 | |
| 666 | auto targetShape = getTargetShape(options, op: loadOp); |
| 667 | if (!targetShape) |
| 668 | return failure(); |
| 669 | |
| 670 | Location loc = loadOp.getLoc(); |
| 671 | ArrayRef<int64_t> originalShape = vecType.getShape(); |
| 672 | SmallVector<int64_t> strides(targetShape->size(), 1); |
| 673 | |
| 674 | Value result = rewriter.create<arith::ConstantOp>( |
| 675 | location: loc, args&: vecType, args: rewriter.getZeroAttr(type: vecType)); |
| 676 | |
| 677 | SmallVector<int64_t> loopOrder = |
| 678 | getUnrollOrder(numLoops: originalShape.size(), op: loadOp, options); |
| 679 | |
| 680 | auto targetVecType = |
| 681 | VectorType::get(shape: *targetShape, elementType: vecType.getElementType()); |
| 682 | |
| 683 | for (SmallVector<int64_t> offsets : |
| 684 | StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) { |
| 685 | SmallVector<Value> indices = |
| 686 | sliceLoadStoreIndices(rewriter, loc, originalIndices: loadOp.getIndices(), offsets); |
| 687 | Value slicedLoad = rewriter.create<vector::LoadOp>( |
| 688 | location: loc, args&: targetVecType, args: loadOp.getBase(), args&: indices); |
| 689 | result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
| 690 | location: loc, args&: slicedLoad, args&: result, args&: offsets, args&: strides); |
| 691 | } |
| 692 | rewriter.replaceOp(op: loadOp, newValues: result); |
| 693 | return success(); |
| 694 | } |
| 695 | |
| 696 | private: |
| 697 | vector::UnrollVectorOptions options; |
| 698 | }; |
| 699 | |
| 700 | struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> { |
| 701 | UnrollStorePattern(MLIRContext *context, |
| 702 | const vector::UnrollVectorOptions &options, |
| 703 | PatternBenefit benefit = 1) |
| 704 | : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {} |
| 705 | |
| 706 | LogicalResult matchAndRewrite(vector::StoreOp storeOp, |
| 707 | PatternRewriter &rewriter) const override { |
| 708 | VectorType vecType = storeOp.getVectorType(); |
| 709 | |
| 710 | auto targetShape = getTargetShape(options, op: storeOp); |
| 711 | if (!targetShape) |
| 712 | return failure(); |
| 713 | |
| 714 | Location loc = storeOp.getLoc(); |
| 715 | ArrayRef<int64_t> originalShape = vecType.getShape(); |
| 716 | SmallVector<int64_t> strides(targetShape->size(), 1); |
| 717 | |
| 718 | Value base = storeOp.getBase(); |
| 719 | Value vector = storeOp.getValueToStore(); |
| 720 | |
| 721 | SmallVector<int64_t> loopOrder = |
| 722 | getUnrollOrder(numLoops: originalShape.size(), op: storeOp, options); |
| 723 | |
| 724 | for (SmallVector<int64_t> offsets : |
| 725 | StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) { |
| 726 | SmallVector<Value> indices = |
| 727 | sliceLoadStoreIndices(rewriter, loc, originalIndices: storeOp.getIndices(), offsets); |
| 728 | Value slice = rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
| 729 | location: loc, args&: vector, args&: offsets, args&: *targetShape, args&: strides); |
| 730 | rewriter.create<vector::StoreOp>(location: loc, args&: slice, args&: base, args&: indices); |
| 731 | } |
| 732 | rewriter.eraseOp(op: storeOp); |
| 733 | return success(); |
| 734 | } |
| 735 | |
| 736 | private: |
| 737 | vector::UnrollVectorOptions options; |
| 738 | }; |
| 739 | |
/// Unrolls a vector.broadcast into tile-sized broadcasts: for each tile of the
/// result, a matching slice of the source (or the scalar source itself) is
/// broadcast to the tile shape and inserted into an accumulator vector.
struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
  UnrollBroadcastPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::BroadcastOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    // Bail out unless the unrolling options provide a tile shape for this op.
    auto targetShape = getTargetShape(options, op: broadcastOp);
    if (!targetShape)
      return failure();

    Location loc = broadcastOp.getLoc();
    // Null when the broadcast source is a scalar rather than a vector.
    VectorType srcType = dyn_cast<VectorType>(Val: broadcastOp.getSourceType());
    VectorType resType = broadcastOp.getResultVectorType();
    VectorType targetType =
        resType.cloneWith(shape: *targetShape, elementType: resType.getElementType());
    // Accumulator initialized to zero; each tile is overwritten below.
    Value result = rewriter.create<arith::ConstantOp>(
        location: loc, args&: resType, args: rewriter.getZeroAttr(type: resType));

    SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
    SmallVector<int64_t> strides(originalShape.size(), 1);

    // Iterate over all tile offsets of the full result shape.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape)) {
      Value newSrc;
      if (!srcType) {
        // Scalar to vector broadcast.
        newSrc = broadcastOp.getSource();
      } else {
        // Vector to vector broadcast. The source rank may be smaller than the
        // result rank; its dims line up with the trailing result dims, so take
        // the trailing `rank` entries of the offsets/shape/strides.
        int64_t rank = srcType.getRank();
        SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
        SmallVector<int64_t> srcShape(targetShape->end() - rank,
                                      targetShape->end());
        SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
        // adjust the offset and shape for src if the corresponding dim is 1.
        // (A size-1 source dim is stretched across the result, so every tile
        // reads the same single element from that dim.)
        for (int64_t i = 0; i < rank; ++i) {
          if (srcType.getDimSize(idx: i) == 1) {
            srcOffsets[i] = 0;
            srcShape[i] = 1;
          }
        }
        newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            location: loc, args: broadcastOp.getSource(), args&: srcOffsets, args&: srcShape, args&: srcStrides);
      }

      // Re-create the broadcast on the sliced source with the tile result type.
      Operation *newOp = cloneOpWithOperandsAndTypes(builder&: rewriter, loc, op: broadcastOp,
                                                     operands: newSrc, resultTypes: targetType);

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          location: loc, args: newOp->getResult(idx: 0), args&: result, args&: offsets, args&: strides);
    }

    rewriter.replaceOp(op: broadcastOp, newValues: result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
| 802 | |
| 803 | } // namespace |
| 804 | |
| 805 | void mlir::vector::populateVectorUnrollPatterns( |
| 806 | RewritePatternSet &patterns, const UnrollVectorOptions &options, |
| 807 | PatternBenefit benefit) { |
| 808 | patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern, |
| 809 | UnrollContractionPattern, UnrollElementwisePattern, |
| 810 | UnrollReductionPattern, UnrollMultiReductionPattern, |
| 811 | UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern, |
| 812 | UnrollStorePattern, UnrollBroadcastPattern>( |
| 813 | arg: patterns.getContext(), args: options, args&: benefit); |
| 814 | } |
| 815 | |