//===- VectorUnroll.cpp - patterns to do vector unrolling ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to do vector unrolling.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/InterleavedRange.h"
#include <optional>

#define DEBUG_TYPE "vector-unroll"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;
using namespace mlir::vector;

/// Compute the indices for a slice of a transfer op: offset the base
/// `indices` by `elementOffsets`, mapped through `permutationMap` (broadcast
/// dimensions are skipped).
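/// For example (illustrative): with a 2-D identity permutation map, base
/// indices (%i, %j) and element offsets [0, 8], the slice indices are
/// (%i + 0, %j + 8), each addition materialized as an affine.apply op.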
static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
                                               ArrayRef<Value> indices,
                                               AffineMap permutationMap,
                                               Location loc,
                                               OpBuilder &builder) {
  MLIRContext *ctx = builder.getContext();
  auto isBroadcast = [](AffineExpr expr) {
    if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
      return constExpr.getValue() == 0;
    return false;
  };
  // Compute 'slicedIndices' by adding 'elementOffsets[i]' to 'indices[i]'.
  SmallVector<Value> slicedIndices(indices);
  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
    if (isBroadcast(dim.value()))
      continue;
    unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
    auto expr = getAffineDimExpr(0, ctx) +
                getAffineConstantExpr(elementOffsets[dim.index()], ctx);
    auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
    slicedIndices[pos] =
        builder.create<affine::AffineApplyOp>(loc, map, indices[pos]);
  }
  return slicedIndices;
}

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
                                              Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  return builder.create(loc, op->getName().getIdentifier(), operands,
                        resultTypes, op->getAttrs());
}

/// Return the target shape for unrolling for the given `op`. Return
/// std::nullopt if the op shouldn't be or cannot be unrolled.
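/// For example (illustrative): for an op whose shape for unroll is [4, 8]
/// and whose native shape is [2, 4], the target shape is [2, 4] and the
/// integral shape ratio is [2, 2], i.e. the op is rewritten into 2 * 2 = 4
/// smaller ops.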
static std::optional<SmallVector<int64_t>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  LDBG("");
  LDBG("Get unroll shape for op " << op->getName().getStringRef());
  if (options.filterConstraint && failed(options.filterConstraint(op))) {
    LDBG("--did not pass the filter constraint -> BAIL");
    return std::nullopt;
  }
  assert(options.nativeShape &&
         "vector unrolling expects the native shape or native "
         "shape callback function to be set");
  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
  if (!unrollableVectorOp) {
    LDBG("--not an unrollable op -> BAIL");
    return std::nullopt;
  }
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape) {
    LDBG("--could not get shape of op " << *op << " -> BAIL");
    return std::nullopt;
  }
  LDBG("--vector op shape: " << llvm::interleaved(*maybeUnrollShape));

  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
  if (!targetShape) {
    LDBG("--no unrolling target shape defined " << *op << " -> SKIP");
    return std::nullopt;
  }
  LDBG("--target shape: " << llvm::interleaved(*targetShape));

  auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
  if (!maybeShapeRatio) {
    LDBG("--could not compute integral shape ratio -> BAIL");
    return std::nullopt;
  }
  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
    LDBG("--no unrolling needed -> SKIP");
    return std::nullopt;
  }
  LDBG("--found an integral shape ratio to unroll to -> SUCCESS");
  return targetShape;
}

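/// Returns the order in which the tile loops are traversed during unrolling:
/// identity order (0, 1, ..., numLoops - 1) unless overridden by the
/// `traversalOrderCallback` in `options`.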
static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops, Operation *op,
               const vector::UnrollVectorOptions &options) {
  SmallVector<int64_t> loopOrder =
      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
  if (options.traversalOrderCallback != nullptr) {
    std::optional<SmallVector<int64_t>> order =
        options.traversalOrderCallback(op);
    if (order) {
      loopOrder = std::move(*order);
    }
  }
  return loopOrder;
}

namespace {

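/// Unrolls a vector.transfer_read into reads of the target shape that are
/// reassembled with vector.insert_strided_slice. For example (illustrative),
/// with a native shape of [2, 2]:
///
///   %v = vector.transfer_read %src[%i, %j], %pad
///       : memref<?x?xf32>, vector<4x4xf32>
///
/// becomes four 2x2 transfer reads at element offsets [0, 0], [0, 2], [2, 0],
/// and [2, 2], inserted into a zero-initialized vector<4x4xf32>.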
struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    if (readOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, readOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value> originalIndices(readOp.getIndices().begin(),
                                       readOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), readOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = rewriter.create<vector::TransferReadOp>(
          loc, targetType, readOp.getBase(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
          readOp.getInBoundsAttr());

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

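/// Unrolls a vector.transfer_write into writes of the target shape. Each
/// slice is taken with vector.extract_strided_slice and written at the
/// corresponding offset indices. For writes into tensors, each new write
/// consumes the result of the previous one, keeping the chain in
/// destination-passing style.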
struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  UnrollTransferWritePattern(MLIRContext *context,
                             const vector::UnrollVectorOptions &options,
                             PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    if (writeOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, writeOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = writeOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
    SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
                                       writeOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), writeOp, options);
    Value resultTensor;
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
          loc, slicedVector, resultTensor ? resultTensor : writeOp.getBase(),
          indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
      // For the tensor case, update the destination used by the next
      // transfer write.
      if (!slicedWrite->getResults().empty())
        resultTensor = slicedWrite->getResult(0);
    }
    if (resultTensor)
      rewriter.replaceOp(writeOp, resultTensor);
    else
      rewriter.eraseOp(writeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

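/// DenseMapInfo for using SmallVector<int64_t> tile offsets as keys in the
/// accumulator caches below. The empty and tombstone keys use offsets -1 and
/// -2, which never occur as real (non-negative) tile offsets.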
struct OffsetMapInfo {
  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }

  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }

  static unsigned getHashValue(const SmallVector<int64_t> &v) {
    return static_cast<unsigned>(llvm::hash_combine_range(v));
  }

  static bool isEqual(const SmallVector<int64_t> &lhs,
                      const SmallVector<int64_t> &rhs) {
    return lhs == rhs;
  }
};

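/// Unrolls a vector.contract into contractions of the target shape. Slices
/// of the lhs, rhs, and accumulator are extracted according to each
/// operand's indexing map, and partial accumulator tiles are kept in
/// `accCache` across reduction iterations before being reassembled into the
/// full result.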
struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(contractOp.getResultType());
    SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();

    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;

    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIteratorTypes().size(), contractOp, options);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffsets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
        slicesOperands[index] =
            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
                loc, operand, operandOffsets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffsets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffsets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffsets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffsets);

      AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffsets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(accOffsets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffsets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffsets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Save the accumulated value until all the loops are unrolled since the
      // reduction loop keeps updating the accumulator.
      accCache[dstOffsets] = newOp->getResult(0);
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

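/// Unrolls a vector.multi_reduction into reductions of the target shape. The
/// destination offsets keep only the non-reduced dimensions, so source slices
/// that reduce into the same destination tile share one cached accumulator.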
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options,
                              PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    auto resultType = reductionOp->getResult(0).getType();
    if (resultType.isIntOrFloat()) {
      return rewriter.notifyMatchFailure(reductionOp,
                                         "Unrolling scalars is not supported");
    }
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    Location loc = reductionOp.getLoc();

    // Iterate over all tiles of the original shape, in steps of the target
    // shape.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> operands;
      SmallVector<int64_t> operandStrides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getSource(), offsets, *targetShape,
              operandStrides);
      operands.push_back(slicedOperand);
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      Value acc;
      SmallVector<int64_t> accStrides(destOffset.size(), 1);
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        acc = accIt->second;
      else
        acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
      operands.push_back(acc);
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     operands, targetType);
      Value result = newOp->getResult(0);
      accCache[destOffset] = result;
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

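/// Unrolls any operation with elementwise-mappable traits and a single vector
/// result: the op is applied to extracted slices of its vector operands
/// (scalar operands are passed through unchanged) and each partial result is
/// reinserted with vector.insert_strided_slice.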
struct UnrollElementwisePattern : public RewritePattern {
  UnrollElementwisePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
        options(options) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(op->getResult(0).getType());
    SmallVector<int64_t> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    // Bail out if rank(source) != rank(target). The main limitation here is
    // that `ExtractStridedSlice` requires the input and output ranks to
    // match. If needed, we can relax this later.
    if (originalSize.size() != targetShape->size())
      return rewriter.notifyMatchFailure(
          op, "expected input vector rank to match target shape rank");
    Location loc = op->getLoc();
    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t> strides(targetShape->size(), 1);
    VectorType newVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());

    // Create the unrolled computation.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> extractOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        auto vecType = dyn_cast<VectorType>(operand.get().getType());
        if (!vecType) {
          extractOperands.push_back(operand.get());
          continue;
        }
        extractOperands.push_back(
            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
                loc, operand.get(), offsets, *targetShape, strides));
      }
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, op, extractOperands, newVecType);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

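/// Unrolls a vector.reduction into reductions over slices of the source
/// vector; the partial results are combined with makeArithReduction according
/// to the reduction kind.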
struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
  UnrollReductionPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();

    // Create unrolled vector reduction.
    Location loc = reductionOp.getLoc();
    Value accumulator = nullptr;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> strides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getVector(), offsets, *targetShape, strides);
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
      Value result = newOp->getResult(0);

      if (!accumulator) {
        // This is the first reduction.
        accumulator = result;
      } else {
        // On subsequent reductions, combine with the accumulator.
        accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                         accumulator, result);
      }
    }

    rewriter.replaceOp(reductionOp, accumulator);
    return success();
  }

private:
  const vector::UnrollVectorOptions options;
};

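/// Unrolls a vector.transpose by transposing slices of the target shape. The
/// source offsets and shape of each slice are derived by applying the
/// transpose permutation to the destination offsets and target shape.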
struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
  UnrollTransposePattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (transposeOp.getResultVectorType().getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, transposeOp);
    if (!targetShape)
      return failure();
    auto originalVectorType = transposeOp.getResultVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = transposeOp.getLoc();
    ArrayRef<int64_t> originalSize = originalVectorType.getShape();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
    ArrayRef<int64_t> permutation = transposeOp.getPermutation();

    // Unroll the computation.
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t> permutedShape(elementOffsets.size());
      // Compute the source offsets and shape.
      for (auto indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
        permutedShape[indices.value()] = (*targetShape)[indices.index()];
      }
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, transposeOp.getVector(), permutedOffsets, permutedShape,
              strides);
      Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>(
          loc, slicedOperand, permutation);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, transposedSlice, result, elementOffsets, strides);
    }
    rewriter.replaceOp(transposeOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

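/// Unrolls a vector.gather by extracting the same slice from the index, mask,
/// and pass-through vectors and gathering each slice separately.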
struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
  UnrollGatherPattern(MLIRContext *context,
                      const vector::UnrollVectorOptions &options,
                      PatternBenefit benefit = 1)
      : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
  }

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    VectorType sourceVectorType = gatherOp.getVectorType();
    if (sourceVectorType.getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, gatherOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = gatherOp.getLoc();
    ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), gatherOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      // To get the unrolled gather, extract the same slice based on the
      // decomposed shape from each of the index, mask, and pass-through
      // vectors.
      Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
      Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
      Value passThruSubVec =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
              strides);
      auto slicedGather = rewriter.create<vector::GatherOp>(
          loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
          indexSubVec, maskSubVec, passThruSubVec);

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedGather, result, elementOffsets, strides);
    }
    rewriter.replaceOp(gatherOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

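/// Unrolls a vector.broadcast into broadcasts of the target shape. For a
/// vector source, the trailing dimensions of each tile's offsets select the
/// source slice; dimensions being stretched by the broadcast (source dim of
/// size 1) are clamped to offset 0 and size 1.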
struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
  UnrollBroadcastPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::BroadcastOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, broadcastOp);
    if (!targetShape)
      return failure();

    Location loc = broadcastOp.getLoc();
    VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    VectorType resType = broadcastOp.getResultVectorType();
    VectorType targetType =
        resType.cloneWith(*targetShape, resType.getElementType());
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));

    SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
    SmallVector<int64_t> strides(originalShape.size(), 1);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape)) {
      Value newSrc;
      if (!srcType) {
        // Scalar to vector broadcast.
        newSrc = broadcastOp.getSource();
      } else {
        // Vector to vector broadcast.
        int64_t rank = srcType.getRank();
        SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
        SmallVector<int64_t> srcShape(targetShape->end() - rank,
                                      targetShape->end());
        SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
        // Adjust the offset and shape for the source if the corresponding dim
        // is 1.
        for (int64_t i = 0; i < rank; ++i) {
          if (srcType.getDimSize(i) == 1) {
            srcOffsets[i] = 0;
            srcShape[i] = 1;
          }
        }
        newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
      }

      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp,
                                                     newSrc, targetType);

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }

    rewriter.replaceOp(broadcastOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

} // namespace

void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options,
    PatternBenefit benefit) {
  patterns
      .add<UnrollTransferReadPattern, UnrollTransferWritePattern,
           UnrollContractionPattern, UnrollElementwisePattern,
           UnrollReductionPattern, UnrollMultiReductionPattern,
           UnrollTransposePattern, UnrollGatherPattern, UnrollBroadcastPattern>(
          patterns.getContext(), options, benefit);
}