//===- VectorUnroll.cpp - patterns to do vector unrolling ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to do vector unrolling.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include <numeric>
#include <optional>

#define DEBUG_TYPE "vector-unroll"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;
using namespace mlir::vector;

/// Compute the indices of the slice at `elementOffsets` for a transfer op.
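/// For example, with indices (%i, %j), elementOffsets = [2, 8], and a minor
/// identity permutation map, the returned indices are
/// (apply(d0 -> d0 + 2)(%i), apply(d0 -> d0 + 8)(%j)). Indices matching a
/// broadcast (constant 0) result of the permutation map are left unchanged.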
static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
                                               ArrayRef<Value> indices,
                                               AffineMap permutationMap,
                                               Location loc,
                                               OpBuilder &builder) {
  MLIRContext *ctx = builder.getContext();
  auto isBroadcast = [](AffineExpr expr) {
    if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
      return constExpr.getValue() == 0;
    return false;
  };
  // Compute 'slicedIndices' by adding 'elementOffsets[i]' to 'indices[i]'.
  SmallVector<Value> slicedIndices(indices.begin(), indices.end());
  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
    if (isBroadcast(dim.value()))
      continue;
    unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
    auto expr = getAffineDimExpr(0, builder.getContext()) +
                getAffineConstantExpr(elementOffsets[dim.index()], ctx);
    auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
    slicedIndices[pos] =
        builder.create<affine::AffineApplyOp>(loc, map, indices[pos]);
  }
  return slicedIndices;
}

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
                                              Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  return builder.create(loc, op->getName().getIdentifier(), operands,
                        resultTypes, op->getAttrs());
}

/// Return the target shape for unrolling for the given `op`. Return
/// std::nullopt if the op shouldn't be or cannot be unrolled.
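/// For example, an op operating on vector<4x6xf32> with a native shape of
/// [2, 3] has an integral shape ratio of [2, 2] and will be unrolled into four
/// vector<2x3xf32> slices; a ratio of all ones means the op is already at the
/// native size and std::nullopt is returned.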
static std::optional<SmallVector<int64_t>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  LDBG("");
  LDBG("Get unroll shape for op " << op->getName().getStringRef());
  if (options.filterConstraint && failed(options.filterConstraint(op))) {
    LDBG("--did not pass the filter constraint -> BAIL");
    return std::nullopt;
  }
  assert(options.nativeShape &&
         "vector unrolling expects the native shape or native "
         "shape callback function to be set");
  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
  if (!unrollableVectorOp) {
    LDBG("--not an unrollable op -> BAIL");
    return std::nullopt;
  }
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape) {
    LDBG("--could not get shape of op " << *op << " -> BAIL");
    return std::nullopt;
  }
  LLVM_DEBUG(
      llvm::interleaveComma(*maybeUnrollShape, DBGS() << "--vector op shape: ");
      llvm::dbgs() << "\n";);

  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
  if (!targetShape) {
    LDBG("--no unrolling target shape defined " << *op << " -> SKIP");
    return std::nullopt;
  }
  LLVM_DEBUG(llvm::interleaveComma(*targetShape, DBGS() << "--target shape: ");
             llvm::dbgs() << "\n";);

  auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
  if (!maybeShapeRatio) {
    LDBG("--could not compute integral shape ratio -> BAIL");
    return std::nullopt;
  }
  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
    LDBG("--no unrolling needed -> SKIP");
    return std::nullopt;
  }
  LDBG("--found an integral shape ratio to unroll to -> SUCCESS");
  return targetShape;
}

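/// Return the order in which the unrolled tiles of `op` should be traversed.
/// Defaults to the identity order [0, 1, ..., numLoops - 1] unless `options`
/// provides a traversal order callback that returns a custom order.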
static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops, Operation *op,
               const vector::UnrollVectorOptions &options) {
  SmallVector<int64_t> loopOrder =
      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
  if (options.traversalOrderCallback != nullptr) {
    std::optional<SmallVector<int64_t>> order =
        options.traversalOrderCallback(op);
    if (order) {
      loopOrder = std::move(*order);
    }
  }
  return loopOrder;
}

namespace {

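/// Unrolls a vector.transfer_read into transfer_reads of the target shape and
/// reassembles the slices with vector.insert_strided_slice. As a rough,
/// illustrative sketch (constants and attributes omitted), unrolling a read of
/// vector<4xf32> with a native shape of [2] produces IR of the form:
///
///   %r0 = vector.transfer_read %src[%i], %pad : memref<?xf32>, vector<2xf32>
///   %i2 = affine.apply affine_map<(d0) -> (d0 + 2)>(%i)
///   %r1 = vector.transfer_read %src[%i2], %pad : memref<?xf32>, vector<2xf32>
///   %v0 = vector.insert_strided_slice %r0, %zero
///       {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
///   %v1 = vector.insert_strided_slice %r1, %v0
///       {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
///
/// Masked reads are not supported.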
struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    if (readOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, readOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value> originalIndices(readOp.getIndices().begin(),
                                       readOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), readOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = rewriter.create<vector::TransferReadOp>(
          loc, targetType, readOp.getSource(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
          readOp.getInBoundsAttr());

      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

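/// Unrolls a vector.transfer_write into transfer_writes of the target shape:
/// each slice is extracted with vector.extract_strided_slice and written at
/// the corresponding adjusted indices. For writes into tensors, each sliced
/// write consumes the result of the previous one so that the final tensor
/// value replaces the original op.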
struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  UnrollTransferWritePattern(MLIRContext *context,
                             const vector::UnrollVectorOptions &options,
                             PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    if (writeOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, writeOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = writeOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
    SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
                                       writeOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), writeOp, options);
    Value resultTensor;
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
          loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
          indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
      // For the tensor case, update the destination for the next transfer
      // write.
      if (!slicedWrite->getResults().empty())
        resultTensor = slicedWrite->getResult(0);
    }
    if (resultTensor)
      rewriter.replaceOp(writeOp, resultTensor);
    else
      rewriter.eraseOp(writeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

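/// DenseMapInfo for SmallVector<int64_t> so that accumulator slices can be
/// keyed by their offset vectors in the llvm::MapVector caches used below.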
struct OffsetMapInfo {
  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }

  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }

  static unsigned getHashValue(const SmallVector<int64_t> &v) {
    return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
  }

  static bool isEqual(const SmallVector<int64_t> &lhs,
                      const SmallVector<int64_t> &rhs) {
    return lhs == rhs;
  }
};

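/// Unrolls a vector.contract into contractions of the target shape. For each
/// tile of the iteration space, slices of the lhs, rhs, and accumulator are
/// extracted according to the respective indexing maps, a smaller contraction
/// is created, and the partial accumulators are cached per destination offset
/// (tiles that only differ in reduction dimensions update the same accumulator
/// slice). The cached slices are finally reassembled with
/// vector.insert_strided_slice. For example, a contraction with unroll shape
/// [4, 4, 4] (i, j, k) and native shape [2, 2, 2] is rewritten into eight
/// vector.contract ops on 2x2 operands, where each of the four (i, j)
/// accumulator tiles is updated twice along k.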
struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(contractOp.getResultType());
    SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();

    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;

    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIteratorTypes().size(), contractOp, options);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffsets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
        slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, operand, operandOffsets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffsets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffsets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffsets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffsets);

      AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffsets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(accOffsets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffsets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffsets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Save the accumulated value until all the loops are unrolled since
      // reduction loops keep updating the accumulator.
      accCache[dstOffsets] = newOp->getResult(0);
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

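/// Unrolls a vector.multi_reduction into reductions of the target shape.
/// Slices of the source are reduced into accumulator slices keyed by the
/// offsets of the non-reduced dimensions; slices that only differ in reduced
/// dimensions feed the same accumulator. The accumulator slices are then
/// reassembled into the destination vector.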
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options,
                              PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    Location loc = reductionOp.getLoc();

    // Iterate over all the tile offsets of the original shape in steps of the
    // target shape.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> operands;
      SmallVector<int64_t> operandStrides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
      operands.push_back(slicedOperand);
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      Value acc;
      SmallVector<int64_t> accStrides(destOffset.size(), 1);
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        acc = accIt->second;
      else
        acc = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
      operands.push_back(acc);
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     operands, targetType);
      Value result = newOp->getResult(0);
      accCache[destOffset] = result;
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

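/// Unrolls any elementwise-mappable operation with a single vector result by
/// applying the op to extracted slices of its vector operands (scalar operands
/// are passed through unchanged) and inserting the partial results back into a
/// zero-initialized result vector.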
struct UnrollElementwisePattern : public RewritePattern {
  UnrollElementwisePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
        options(options) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(op->getResult(0).getType());
    SmallVector<int64_t> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    Location loc = op->getLoc();
    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t> strides(targetShape->size(), 1);
    VectorType newVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());

    // Create the unrolled computation.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> extractOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        auto vecType = dyn_cast<VectorType>(operand.get().getType());
        if (!vecType) {
          extractOperands.push_back(operand.get());
          continue;
        }
        extractOperands.push_back(
            rewriter.create<vector::ExtractStridedSliceOp>(
                loc, operand.get(), offsets, *targetShape, strides));
      }
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, op, extractOperands, newVecType);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

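/// Unrolls a vector.reduction by reducing each slice of the source vector
/// separately and combining the partial scalar results with the arithmetic
/// operation matching the reduction kind (via makeArithReduction).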
struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
  UnrollReductionPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();

    // Create unrolled vector reduction.
    Location loc = reductionOp.getLoc();
    Value accumulator = nullptr;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> strides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getVector(), offsets, *targetShape, strides);
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
      Value result = newOp->getResult(0);

      if (!accumulator) {
        // This is the first reduction.
        accumulator = result;
      } else {
        // On subsequent reduction, combine with the accumulator.
        accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                         accumulator, result);
      }
    }

    rewriter.replaceOp(reductionOp, accumulator);
    return success();
  }

private:
  const vector::UnrollVectorOptions options;
};

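/// Unrolls a vector.transpose by transposing each tile of the result iteration
/// space: the source offsets and shape of each tile are obtained by applying
/// the permutation, the tile is extracted and transposed, and the result is
/// inserted at the original (unpermuted) offsets.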
struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
  UnrollTransposePattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (transposeOp.getResultVectorType().getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, transposeOp);
    if (!targetShape)
      return failure();
    auto originalVectorType = transposeOp.getResultVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = transposeOp.getLoc();
    ArrayRef<int64_t> originalSize = originalVectorType.getShape();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
    ArrayRef<int64_t> permutation = transposeOp.getPermutation();

    // Unroll the computation.
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t> permutedShape(elementOffsets.size());
      // Compute the source offsets and shape.
      for (auto indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
        permutedShape[indices.value()] = (*targetShape)[indices.index()];
      }
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, transposeOp.getVector(), permutedOffsets, permutedShape,
          strides);
      Value transposedSlice =
          rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, transposedSlice, result, elementOffsets, strides);
    }
    rewriter.replaceOp(transposeOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

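/// Unrolls a vector.gather by extracting matching slices of the index, mask,
/// and pass-through vectors, issuing a smaller gather for each slice, and
/// reassembling the gathered slices into the full result vector.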
struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
  UnrollGatherPattern(MLIRContext *context,
                      const vector::UnrollVectorOptions &options,
                      PatternBenefit benefit = 1)
      : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
  }

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    VectorType sourceVectorType = gatherOp.getVectorType();
    if (sourceVectorType.getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, gatherOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = gatherOp.getLoc();
    ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), gatherOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      // To get the unrolled gather, extract the same slice based on the
      // decomposed shape from each of the index, mask, and pass-through
      // vectors.
      Value indexSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
      Value maskSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
      Value passThruSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides);
      auto slicedGather = rewriter.create<vector::GatherOp>(
          loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
          indexSubVec, maskSubVec, passThruSubVec);

      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, slicedGather, result, elementOffsets, strides);
    }
    rewriter.replaceOp(gatherOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

} // namespace

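// A typical use of these patterns looks roughly as follows (sketch only; the
// enclosing pass, the chosen native shape, and the greedy driver call are
// assumptions for illustration, not part of this file):
//
//   RewritePatternSet patterns(ctx);
//   vector::UnrollVectorOptions unrollOptions;
//   unrollOptions.setNativeShape({2, 2});
//   vector::populateVectorUnrollPatterns(patterns, unrollOptions);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));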
void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options,
    PatternBenefit benefit) {
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern,
               UnrollReductionPattern, UnrollMultiReductionPattern,
               UnrollTransposePattern, UnrollGatherPattern>(
      patterns.getContext(), options, benefit);
}
