//===- VectorUnroll.cpp - patterns to do vector unrolling -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to do vector unrolling and vector distribution.
//
//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Affine/IR/AffineOps.h"
14#include "mlir/Dialect/Utils/IndexingUtils.h"
15#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
16#include "mlir/Interfaces/VectorInterfaces.h"
17#include "llvm/ADT/MapVector.h"
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/Support/Debug.h"
20#include "llvm/Support/InterleavedRange.h"
21#include <optional>
22
// Debug-output helpers: DBGS() prints a "[vector-unroll]: " prefix, and
// LDBG(X) emits X as a single debug line (only in LLVM_DEBUG builds).
#define DEBUG_TYPE "vector-unroll"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
26
27using namespace mlir;
28using namespace mlir::vector;
29
30/// Compute the indices of the slice `index` for a transfer op.
31static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
32 ArrayRef<Value> indices,
33 AffineMap permutationMap,
34 Location loc,
35 OpBuilder &builder) {
36 MLIRContext *ctx = builder.getContext();
37 auto isBroadcast = [](AffineExpr expr) {
38 if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: expr))
39 return constExpr.getValue() == 0;
40 return false;
41 };
42 // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
43 SmallVector<Value> slicedIndices(indices);
44 for (const auto &dim : llvm::enumerate(First: permutationMap.getResults())) {
45 if (isBroadcast(dim.value()))
46 continue;
47 unsigned pos = cast<AffineDimExpr>(Val: dim.value()).getPosition();
48 auto expr = getAffineDimExpr(position: 0, context: builder.getContext()) +
49 getAffineConstantExpr(constant: elementOffsets[dim.index()], context: ctx);
50 auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, result: expr);
51 slicedIndices[pos] =
52 builder.create<affine::AffineApplyOp>(location: loc, args&: map, args: indices[pos]);
53 }
54 return slicedIndices;
55}
56
57// Compute the new indices by adding `offsets` to `originalIndices`.
58// If m < n (m = offsets.size(), n = originalIndices.size()),
59// then only the trailing m values in `originalIndices` are updated.
60static SmallVector<Value> sliceLoadStoreIndices(PatternRewriter &rewriter,
61 Location loc,
62 OperandRange originalIndices,
63 ArrayRef<int64_t> offsets) {
64 assert(offsets.size() <= originalIndices.size() &&
65 "Offsets should not exceed the number of original indices");
66 SmallVector<Value> indices(originalIndices);
67
68 auto start = indices.size() - offsets.size();
69 for (auto [i, offset] : llvm::enumerate(First&: offsets)) {
70 if (offset != 0) {
71 indices[start + i] = rewriter.create<arith::AddIOp>(
72 location: loc, args: originalIndices[start + i],
73 args: rewriter.create<arith::ConstantIndexOp>(location: loc, args: offset));
74 }
75 }
76 return indices;
77}
78
79// Clones `op` into a new operations that takes `operands` and returns
80// `resultTypes`.
81static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
82 Operation *op,
83 ArrayRef<Value> operands,
84 ArrayRef<Type> resultTypes) {
85 return builder.create(loc, opName: op->getName().getIdentifier(), operands,
86 types: resultTypes, attributes: op->getAttrs());
87}
88
89/// Return the target shape for unrolling for the given `op`. Return
90/// std::nullopt if the op shouldn't be or cannot be unrolled.
91static std::optional<SmallVector<int64_t>>
92getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
93 LDBG("");
94 LDBG("Get unroll shape for op " << op->getName().getStringRef());
95 if (options.filterConstraint && failed(Result: options.filterConstraint(op))) {
96 LDBG("--no filter constraint -> BAIL");
97 return std::nullopt;
98 }
99 assert(options.nativeShape &&
100 "vector unrolling expects the native shape or native"
101 "shape call back function to be set");
102 auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(Val: op);
103 if (!unrollableVectorOp) {
104 LDBG("--not an unrollable op -> BAIL");
105 return std::nullopt;
106 }
107 auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
108 if (!maybeUnrollShape) {
109 LDBG("--could not get shape of op " << *op << " -> BAIL");
110 return std::nullopt;
111 }
112 LDBG("--vector op shape: " << llvm::interleaved(*maybeUnrollShape));
113
114 std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
115 if (!targetShape) {
116 LDBG("--no unrolling target shape defined " << *op << "-> SKIP");
117 return std::nullopt;
118 }
119 LDBG("--target shape: " << llvm::interleaved(*targetShape));
120
121 auto maybeShapeRatio = computeShapeRatio(shape: *maybeUnrollShape, subShape: *targetShape);
122 if (!maybeShapeRatio) {
123 LDBG("--could not compute integral shape ratio -> BAIL");
124 return std::nullopt;
125 }
126 if (llvm::all_of(Range&: *maybeShapeRatio, P: [](int64_t v) { return v == 1; })) {
127 LDBG("--no unrolling needed -> SKIP");
128 return std::nullopt;
129 }
130 LDBG("--found an integral shape ratio to unroll to -> SUCCESS");
131 return targetShape;
132}
133
134static SmallVector<int64_t>
135getUnrollOrder(unsigned numLoops, Operation *op,
136 const vector::UnrollVectorOptions &options) {
137 SmallVector<int64_t> loopOrder =
138 llvm::to_vector(Range: llvm::seq<int64_t>(Begin: 0, End: static_cast<int64_t>(numLoops)));
139 if (options.traversalOrderCallback != nullptr) {
140 std::optional<SmallVector<int64_t>> order =
141 options.traversalOrderCallback(op);
142 if (order) {
143 loopOrder = std::move(*order);
144 }
145 }
146 return loopOrder;
147}
148
149namespace {
150
151struct UnrollTransferReadPattern
152 : public OpRewritePattern<vector::TransferReadOp> {
153 UnrollTransferReadPattern(MLIRContext *context,
154 const vector::UnrollVectorOptions &options,
155 PatternBenefit benefit = 1)
156 : OpRewritePattern<vector::TransferReadOp>(context, benefit),
157 options(options) {}
158
159 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
160 PatternRewriter &rewriter) const override {
161 // TODO: support 0-d corner case.
162 if (readOp.getTransferRank() == 0)
163 return failure();
164 if (readOp.getMask())
165 return failure();
166 auto targetShape = getTargetShape(options, op: readOp);
167 if (!targetShape)
168 return failure();
169 auto sourceVectorType = readOp.getVectorType();
170 SmallVector<int64_t> strides(targetShape->size(), 1);
171 Location loc = readOp.getLoc();
172 ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
173
174 // Prepare the result vector;
175 Value result = rewriter.create<arith::ConstantOp>(
176 location: loc, args&: sourceVectorType, args: rewriter.getZeroAttr(type: sourceVectorType));
177 auto targetType =
178 VectorType::get(shape: *targetShape, elementType: sourceVectorType.getElementType());
179 SmallVector<Value> originalIndices(readOp.getIndices().begin(),
180 readOp.getIndices().end());
181 SmallVector<int64_t> loopOrder =
182 getUnrollOrder(numLoops: originalSize.size(), op: readOp, options);
183 for (SmallVector<int64_t> elementOffsets :
184 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
185 SmallVector<Value> indices =
186 sliceTransferIndices(elementOffsets, indices: originalIndices,
187 permutationMap: readOp.getPermutationMap(), loc, builder&: rewriter);
188 auto slicedRead = rewriter.create<vector::TransferReadOp>(
189 location: loc, args&: targetType, args: readOp.getBase(), args&: indices,
190 args: readOp.getPermutationMapAttr(), args: readOp.getPadding(), args: readOp.getMask(),
191 args: readOp.getInBoundsAttr());
192
193 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
194 location: loc, args&: slicedRead, args&: result, args&: elementOffsets, args&: strides);
195 }
196 rewriter.replaceOp(op: readOp, newValues: result);
197 return success();
198 }
199
200private:
201 vector::UnrollVectorOptions options;
202};
203
204struct UnrollTransferWritePattern
205 : public OpRewritePattern<vector::TransferWriteOp> {
206 UnrollTransferWritePattern(MLIRContext *context,
207 const vector::UnrollVectorOptions &options,
208 PatternBenefit benefit = 1)
209 : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
210 options(options) {}
211
212 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
213 PatternRewriter &rewriter) const override {
214 // TODO: support 0-d corner case.
215 if (writeOp.getTransferRank() == 0)
216 return failure();
217
218 if (writeOp.getMask())
219 return failure();
220 auto targetShape = getTargetShape(options, op: writeOp);
221 if (!targetShape)
222 return failure();
223 auto sourceVectorType = writeOp.getVectorType();
224 SmallVector<int64_t> strides(targetShape->size(), 1);
225 Location loc = writeOp.getLoc();
226 ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
227 SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
228 writeOp.getIndices().end());
229 SmallVector<int64_t> loopOrder =
230 getUnrollOrder(numLoops: originalSize.size(), op: writeOp, options);
231 Value resultTensor;
232 for (SmallVector<int64_t> elementOffsets :
233 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
234 Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
235 location: loc, args: writeOp.getVector(), args&: elementOffsets, args&: *targetShape, args&: strides);
236 SmallVector<Value> indices =
237 sliceTransferIndices(elementOffsets, indices: originalIndices,
238 permutationMap: writeOp.getPermutationMap(), loc, builder&: rewriter);
239 Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
240 location: loc, args&: slicedVector, args: resultTensor ? resultTensor : writeOp.getBase(),
241 args&: indices, args: writeOp.getPermutationMapAttr(), args: writeOp.getInBoundsAttr());
242 // For the tensor case update the destination for the next transfer write.
243 if (!slicedWrite->getResults().empty())
244 resultTensor = slicedWrite->getResult(idx: 0);
245 }
246 if (resultTensor)
247 rewriter.replaceOp(op: writeOp, newValues: resultTensor);
248 else
249 rewriter.eraseOp(op: writeOp);
250 return success();
251 }
252
253private:
254 vector::UnrollVectorOptions options;
255};
256
257struct OffsetMapInfo {
258 static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }
259
260 static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }
261
262 static unsigned getHashValue(const SmallVector<int64_t> &v) {
263 return static_cast<unsigned>(llvm::hash_combine_range(R: v));
264 }
265
266 static bool isEqual(const SmallVector<int64_t> &lhs,
267 const SmallVector<int64_t> &rhs) {
268 return lhs == rhs;
269 }
270};
271
272struct UnrollContractionPattern
273 : public OpRewritePattern<vector::ContractionOp> {
274 UnrollContractionPattern(MLIRContext *context,
275 const vector::UnrollVectorOptions &options,
276 PatternBenefit benefit = 1)
277 : OpRewritePattern<vector::ContractionOp>(context, benefit),
278 options(options) {}
279
280 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
281 PatternRewriter &rewriter) const override {
282 auto targetShape = getTargetShape(options, op: contractOp);
283 if (!targetShape)
284 return failure();
285 auto dstVecType = cast<VectorType>(Val: contractOp.getResultType());
286 SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();
287
288 Location loc = contractOp.getLoc();
289 unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
290 AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
291 llvm::MapVector<
292 SmallVector<int64_t>, Value,
293 llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
294 accCache;
295
296 SmallVector<int64_t> loopOrder = getUnrollOrder(
297 numLoops: contractOp.getIteratorTypes().size(), op: contractOp, options);
298
299 for (SmallVector<int64_t> offsets :
300 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
301 SmallVector<Value> slicesOperands(contractOp.getNumOperands());
302
303 // Helper to compute the new shape of each operand and extract the slice.
304 auto extractOperand = [&](unsigned index, Value operand,
305 AffineMap permutationMap,
306 ArrayRef<int64_t> operandOffets) {
307 SmallVector<int64_t> operandShape = applyPermutationMap(
308 map: permutationMap, source: ArrayRef<int64_t>(*targetShape));
309 SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
310 slicesOperands[index] =
311 rewriter.createOrFold<vector::ExtractStridedSliceOp>(
312 location: loc, args&: operand, args&: operandOffets, args&: operandShape, args&: operandStrides);
313 };
314
315 // Extract the new lhs operand.
316 AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
317 SmallVector<int64_t> lhsOffets =
318 applyPermutationMap(map: lhsPermutationMap, source: ArrayRef<int64_t>(offsets));
319 extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
320
321 // Extract the new rhs operand.
322 AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
323 SmallVector<int64_t> rhsOffets =
324 applyPermutationMap(map: rhsPermutationMap, source: ArrayRef<int64_t>(offsets));
325 extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
326
327 AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
328 SmallVector<int64_t> accOffets =
329 applyPermutationMap(map: accPermutationMap, source: ArrayRef<int64_t>(offsets));
330 // If a version of the accumulator has already been computed, use it
331 // otherwise extract the first version from the original operand.
332 auto *accIt = accCache.find(Key: accOffets);
333 if (accIt != accCache.end())
334 slicesOperands[2] = accIt->second;
335 else
336 extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);
337
338 SmallVector<int64_t> dstShape =
339 applyPermutationMap(map: dstAffineMap, source: ArrayRef<int64_t>(*targetShape));
340 auto targetType = VectorType::get(shape: dstShape, elementType: dstVecType.getElementType());
341 Operation *newOp = cloneOpWithOperandsAndTypes(
342 builder&: rewriter, loc, op: contractOp, operands: slicesOperands, resultTypes: targetType);
343
344 SmallVector<int64_t> dstOffets =
345 applyPermutationMap(map: dstAffineMap, source: ArrayRef<int64_t>(offsets));
346 // Save the accumulated value untill all the loops are unrolled since
347 // reduction loop keep updating the accumulator.
348 accCache[dstOffets] = newOp->getResult(idx: 0);
349 }
350 // Assemble back the accumulator into a single vector.
351 Value result = rewriter.create<arith::ConstantOp>(
352 location: loc, args&: dstVecType, args: rewriter.getZeroAttr(type: dstVecType));
353 for (const auto &it : accCache) {
354 SmallVector<int64_t> dstStrides(it.first.size(), 1);
355 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
356 location: loc, args: it.second, args&: result, args: it.first, args&: dstStrides);
357 }
358 rewriter.replaceOp(op: contractOp, newValues: result);
359 return success();
360 }
361
362private:
363 vector::UnrollVectorOptions options;
364};
365
366struct UnrollMultiReductionPattern
367 : public OpRewritePattern<vector::MultiDimReductionOp> {
368 UnrollMultiReductionPattern(MLIRContext *context,
369 const vector::UnrollVectorOptions &options,
370 PatternBenefit benefit = 1)
371 : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
372 options(options) {}
373
374 LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
375 PatternRewriter &rewriter) const override {
376 auto resultType = reductionOp->getResult(idx: 0).getType();
377 if (resultType.isIntOrFloat()) {
378 return rewriter.notifyMatchFailure(arg&: reductionOp,
379 msg: "Unrolling scalars is not supported");
380 }
381 std::optional<SmallVector<int64_t>> targetShape =
382 getTargetShape(options, op: reductionOp);
383 if (!targetShape)
384 return failure();
385 SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
386 llvm::MapVector<
387 SmallVector<int64_t>, Value,
388 llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
389 accCache;
390 Location loc = reductionOp.getLoc();
391
392 // Stride of the ratios, this gives us the offsets of sliceCount in a basis
393 // of multiples of the targetShape.
394 for (SmallVector<int64_t> offsets :
395 StaticTileOffsetRange(originalSize, *targetShape)) {
396 SmallVector<Value> operands;
397 SmallVector<int64_t> operandStrides(offsets.size(), 1);
398 Value slicedOperand =
399 rewriter.createOrFold<vector::ExtractStridedSliceOp>(
400 location: loc, args: reductionOp.getSource(), args&: offsets, args&: *targetShape,
401 args&: operandStrides);
402 operands.push_back(Elt: slicedOperand);
403 SmallVector<int64_t> dstShape;
404 SmallVector<int64_t> destOffset;
405 for (size_t i : llvm::seq(Begin: size_t(0), End: targetShape->size())) {
406 if (!reductionOp.isReducedDim(d: i)) {
407 destOffset.push_back(Elt: offsets[i]);
408 dstShape.push_back(Elt: (*targetShape)[i]);
409 }
410 }
411 Value acc;
412 SmallVector<int64_t> accStrides(destOffset.size(), 1);
413 // If a version of the accumulator has already been computed, use it
414 // otherwise extract the first version from the original operand.
415 auto *accIt = accCache.find(Key: destOffset);
416 if (accIt != accCache.end())
417 acc = accIt->second;
418 else
419 acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
420 location: loc, args: reductionOp.getAcc(), args&: destOffset, args&: dstShape, args&: accStrides);
421 operands.push_back(Elt: acc);
422 auto targetType = VectorType::get(
423 shape: dstShape, elementType: reductionOp.getSourceVectorType().getElementType());
424 Operation *newOp = cloneOpWithOperandsAndTypes(builder&: rewriter, loc, op: reductionOp,
425 operands, resultTypes: targetType);
426 Value result = newOp->getResult(idx: 0);
427 accCache[destOffset] = result;
428 }
429 // Assemble back the accumulator into a single vector.
430 Value result = rewriter.create<arith::ConstantOp>(
431 location: loc, args: reductionOp.getDestType(),
432 args: rewriter.getZeroAttr(type: reductionOp.getDestType()));
433 for (const auto &it : accCache) {
434 SmallVector<int64_t> dstStrides(it.first.size(), 1);
435 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
436 location: loc, args: it.second, args&: result, args: it.first, args&: dstStrides);
437 }
438 rewriter.replaceOp(op: reductionOp, newValues: result);
439 return success();
440 }
441
442private:
443 vector::UnrollVectorOptions options;
444};
445
446struct UnrollElementwisePattern : public RewritePattern {
447 UnrollElementwisePattern(MLIRContext *context,
448 const vector::UnrollVectorOptions &options,
449 PatternBenefit benefit = 1)
450 : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
451 options(options) {}
452
453 LogicalResult matchAndRewrite(Operation *op,
454 PatternRewriter &rewriter) const override {
455 if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
456 return failure();
457 auto targetShape = getTargetShape(options, op);
458 if (!targetShape)
459 return failure();
460 auto dstVecType = cast<VectorType>(Val: op->getResult(idx: 0).getType());
461 SmallVector<int64_t> originalSize =
462 *cast<VectorUnrollOpInterface>(Val: op).getShapeForUnroll();
463 // Bail-out if rank(source) != rank(target). The main limitation here is the
464 // fact that `ExtractStridedSlice` requires the rank for the input and
465 // output to match. If needed, we can relax this later.
466 if (originalSize.size() != targetShape->size())
467 return rewriter.notifyMatchFailure(
468 arg&: op, msg: "expected input vector rank to match target shape rank");
469 Location loc = op->getLoc();
470 // Prepare the result vector.
471 Value result = rewriter.create<arith::ConstantOp>(
472 location: loc, args&: dstVecType, args: rewriter.getZeroAttr(type: dstVecType));
473 SmallVector<int64_t> strides(targetShape->size(), 1);
474 VectorType newVecType =
475 VectorType::get(shape: *targetShape, elementType: dstVecType.getElementType());
476
477 // Create the unrolled computation.
478 for (SmallVector<int64_t> offsets :
479 StaticTileOffsetRange(originalSize, *targetShape)) {
480 SmallVector<Value> extractOperands;
481 for (OpOperand &operand : op->getOpOperands()) {
482 auto vecType = dyn_cast<VectorType>(Val: operand.get().getType());
483 if (!vecType) {
484 extractOperands.push_back(Elt: operand.get());
485 continue;
486 }
487 extractOperands.push_back(
488 Elt: rewriter.createOrFold<vector::ExtractStridedSliceOp>(
489 location: loc, args: operand.get(), args&: offsets, args&: *targetShape, args&: strides));
490 }
491 Operation *newOp = cloneOpWithOperandsAndTypes(
492 builder&: rewriter, loc, op, operands: extractOperands, resultTypes: newVecType);
493 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
494 location: loc, args: newOp->getResult(idx: 0), args&: result, args&: offsets, args&: strides);
495 }
496 rewriter.replaceOp(op, newValues: result);
497 return success();
498 }
499
500private:
501 vector::UnrollVectorOptions options;
502};
503
504struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
505 UnrollReductionPattern(MLIRContext *context,
506 const vector::UnrollVectorOptions &options,
507 PatternBenefit benefit = 1)
508 : OpRewritePattern<vector::ReductionOp>(context, benefit),
509 options(options) {}
510
511 LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
512 PatternRewriter &rewriter) const override {
513 std::optional<SmallVector<int64_t>> targetShape =
514 getTargetShape(options, op: reductionOp);
515 if (!targetShape)
516 return failure();
517 SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
518
519 // Create unrolled vector reduction.
520 Location loc = reductionOp.getLoc();
521 Value accumulator = nullptr;
522 for (SmallVector<int64_t> offsets :
523 StaticTileOffsetRange(originalSize, *targetShape)) {
524 SmallVector<int64_t> strides(offsets.size(), 1);
525 Value slicedOperand =
526 rewriter.createOrFold<vector::ExtractStridedSliceOp>(
527 location: loc, args: reductionOp.getVector(), args&: offsets, args&: *targetShape, args&: strides);
528 Operation *newOp = cloneOpWithOperandsAndTypes(
529 builder&: rewriter, loc, op: reductionOp, operands: slicedOperand, resultTypes: reductionOp.getType());
530 Value result = newOp->getResult(idx: 0);
531
532 if (!accumulator) {
533 // This is the first reduction.
534 accumulator = result;
535 } else {
536 // On subsequent reduction, combine with the accumulator.
537 accumulator = makeArithReduction(b&: rewriter, loc, kind: reductionOp.getKind(),
538 v1: accumulator, acc: result);
539 }
540 }
541
542 rewriter.replaceOp(op: reductionOp, newValues: accumulator);
543 return success();
544 }
545
546private:
547 const vector::UnrollVectorOptions options;
548};
549
550struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
551 UnrollTransposePattern(MLIRContext *context,
552 const vector::UnrollVectorOptions &options,
553 PatternBenefit benefit = 1)
554 : OpRewritePattern<vector::TransposeOp>(context, benefit),
555 options(options) {}
556
557 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
558 PatternRewriter &rewriter) const override {
559 if (transposeOp.getResultVectorType().getRank() == 0)
560 return failure();
561 auto targetShape = getTargetShape(options, op: transposeOp);
562 if (!targetShape)
563 return failure();
564 auto originalVectorType = transposeOp.getResultVectorType();
565 SmallVector<int64_t> strides(targetShape->size(), 1);
566 Location loc = transposeOp.getLoc();
567 ArrayRef<int64_t> originalSize = originalVectorType.getShape();
568
569 // Prepare the result vector;
570 Value result = rewriter.create<arith::ConstantOp>(
571 location: loc, args&: originalVectorType, args: rewriter.getZeroAttr(type: originalVectorType));
572 ArrayRef<int64_t> permutation = transposeOp.getPermutation();
573
574 // Unroll the computation.
575 for (SmallVector<int64_t> elementOffsets :
576 StaticTileOffsetRange(originalSize, *targetShape)) {
577 SmallVector<int64_t> permutedOffsets(elementOffsets.size());
578 SmallVector<int64_t> permutedShape(elementOffsets.size());
579 // Compute the source offsets and shape.
580 for (auto indices : llvm::enumerate(First&: permutation)) {
581 permutedOffsets[indices.value()] = elementOffsets[indices.index()];
582 permutedShape[indices.value()] = (*targetShape)[indices.index()];
583 }
584 Value slicedOperand =
585 rewriter.createOrFold<vector::ExtractStridedSliceOp>(
586 location: loc, args: transposeOp.getVector(), args&: permutedOffsets, args&: permutedShape,
587 args&: strides);
588 Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>(
589 location: loc, args&: slicedOperand, args&: permutation);
590 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
591 location: loc, args&: transposedSlice, args&: result, args&: elementOffsets, args&: strides);
592 }
593 rewriter.replaceOp(op: transposeOp, newValues: result);
594 return success();
595 }
596
597private:
598 vector::UnrollVectorOptions options;
599};
600
601struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
602 UnrollGatherPattern(MLIRContext *context,
603 const vector::UnrollVectorOptions &options,
604 PatternBenefit benefit = 1)
605 : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
606 }
607
608 LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
609 PatternRewriter &rewriter) const override {
610 VectorType sourceVectorType = gatherOp.getVectorType();
611 if (sourceVectorType.getRank() == 0)
612 return failure();
613 auto targetShape = getTargetShape(options, op: gatherOp);
614 if (!targetShape)
615 return failure();
616 SmallVector<int64_t> strides(targetShape->size(), 1);
617 Location loc = gatherOp.getLoc();
618 ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();
619
620 // Prepare the result vector;
621 Value result = rewriter.create<arith::ConstantOp>(
622 location: loc, args&: sourceVectorType, args: rewriter.getZeroAttr(type: sourceVectorType));
623 auto targetType =
624 VectorType::get(shape: *targetShape, elementType: sourceVectorType.getElementType());
625
626 SmallVector<int64_t> loopOrder =
627 getUnrollOrder(numLoops: originalSize.size(), op: gatherOp, options);
628 for (SmallVector<int64_t> elementOffsets :
629 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
630 // To get the unrolled gather, extract the same slice based on the
631 // decomposed shape from each of the index, mask, and pass-through
632 // vectors.
633 Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
634 location: loc, args: gatherOp.getIndexVec(), args&: elementOffsets, args&: *targetShape, args&: strides);
635 Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
636 location: loc, args: gatherOp.getMask(), args&: elementOffsets, args&: *targetShape, args&: strides);
637 Value passThruSubVec =
638 rewriter.createOrFold<vector::ExtractStridedSliceOp>(
639 location: loc, args: gatherOp.getPassThru(), args&: elementOffsets, args&: *targetShape,
640 args&: strides);
641 auto slicedGather = rewriter.create<vector::GatherOp>(
642 location: loc, args&: targetType, args: gatherOp.getBase(), args: gatherOp.getIndices(),
643 args&: indexSubVec, args&: maskSubVec, args&: passThruSubVec);
644
645 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
646 location: loc, args&: slicedGather, args&: result, args&: elementOffsets, args&: strides);
647 }
648 rewriter.replaceOp(op: gatherOp, newValues: result);
649 return success();
650 }
651
652private:
653 vector::UnrollVectorOptions options;
654};
655
656struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
657 UnrollLoadPattern(MLIRContext *context,
658 const vector::UnrollVectorOptions &options,
659 PatternBenefit benefit = 1)
660 : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}
661
662 LogicalResult matchAndRewrite(vector::LoadOp loadOp,
663 PatternRewriter &rewriter) const override {
664 VectorType vecType = loadOp.getVectorType();
665
666 auto targetShape = getTargetShape(options, op: loadOp);
667 if (!targetShape)
668 return failure();
669
670 Location loc = loadOp.getLoc();
671 ArrayRef<int64_t> originalShape = vecType.getShape();
672 SmallVector<int64_t> strides(targetShape->size(), 1);
673
674 Value result = rewriter.create<arith::ConstantOp>(
675 location: loc, args&: vecType, args: rewriter.getZeroAttr(type: vecType));
676
677 SmallVector<int64_t> loopOrder =
678 getUnrollOrder(numLoops: originalShape.size(), op: loadOp, options);
679
680 auto targetVecType =
681 VectorType::get(shape: *targetShape, elementType: vecType.getElementType());
682
683 for (SmallVector<int64_t> offsets :
684 StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
685 SmallVector<Value> indices =
686 sliceLoadStoreIndices(rewriter, loc, originalIndices: loadOp.getIndices(), offsets);
687 Value slicedLoad = rewriter.create<vector::LoadOp>(
688 location: loc, args&: targetVecType, args: loadOp.getBase(), args&: indices);
689 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
690 location: loc, args&: slicedLoad, args&: result, args&: offsets, args&: strides);
691 }
692 rewriter.replaceOp(op: loadOp, newValues: result);
693 return success();
694 }
695
696private:
697 vector::UnrollVectorOptions options;
698};
699
700struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
701 UnrollStorePattern(MLIRContext *context,
702 const vector::UnrollVectorOptions &options,
703 PatternBenefit benefit = 1)
704 : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}
705
706 LogicalResult matchAndRewrite(vector::StoreOp storeOp,
707 PatternRewriter &rewriter) const override {
708 VectorType vecType = storeOp.getVectorType();
709
710 auto targetShape = getTargetShape(options, op: storeOp);
711 if (!targetShape)
712 return failure();
713
714 Location loc = storeOp.getLoc();
715 ArrayRef<int64_t> originalShape = vecType.getShape();
716 SmallVector<int64_t> strides(targetShape->size(), 1);
717
718 Value base = storeOp.getBase();
719 Value vector = storeOp.getValueToStore();
720
721 SmallVector<int64_t> loopOrder =
722 getUnrollOrder(numLoops: originalShape.size(), op: storeOp, options);
723
724 for (SmallVector<int64_t> offsets :
725 StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
726 SmallVector<Value> indices =
727 sliceLoadStoreIndices(rewriter, loc, originalIndices: storeOp.getIndices(), offsets);
728 Value slice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
729 location: loc, args&: vector, args&: offsets, args&: *targetShape, args&: strides);
730 rewriter.create<vector::StoreOp>(location: loc, args&: slice, args&: base, args&: indices);
731 }
732 rewriter.eraseOp(op: storeOp);
733 return success();
734 }
735
736private:
737 vector::UnrollVectorOptions options;
738};
739
740struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
741 UnrollBroadcastPattern(MLIRContext *context,
742 const vector::UnrollVectorOptions &options,
743 PatternBenefit benefit = 1)
744 : OpRewritePattern<vector::BroadcastOp>(context, benefit),
745 options(options) {}
746
747 LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
748 PatternRewriter &rewriter) const override {
749 auto targetShape = getTargetShape(options, op: broadcastOp);
750 if (!targetShape)
751 return failure();
752
753 Location loc = broadcastOp.getLoc();
754 VectorType srcType = dyn_cast<VectorType>(Val: broadcastOp.getSourceType());
755 VectorType resType = broadcastOp.getResultVectorType();
756 VectorType targetType =
757 resType.cloneWith(shape: *targetShape, elementType: resType.getElementType());
758 Value result = rewriter.create<arith::ConstantOp>(
759 location: loc, args&: resType, args: rewriter.getZeroAttr(type: resType));
760
761 SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
762 SmallVector<int64_t> strides(originalShape.size(), 1);
763
764 for (SmallVector<int64_t> offsets :
765 StaticTileOffsetRange(originalShape, *targetShape)) {
766 Value newSrc;
767 if (!srcType) {
768 // Scalar to vector broadcast.
769 newSrc = broadcastOp.getSource();
770 } else {
771 // Vector to vector broadcast.
772 int64_t rank = srcType.getRank();
773 SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
774 SmallVector<int64_t> srcShape(targetShape->end() - rank,
775 targetShape->end());
776 SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
777 // adjust the offset and shape for src if the corresponding dim is 1.
778 for (int64_t i = 0; i < rank; ++i) {
779 if (srcType.getDimSize(idx: i) == 1) {
780 srcOffsets[i] = 0;
781 srcShape[i] = 1;
782 }
783 }
784 newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
785 location: loc, args: broadcastOp.getSource(), args&: srcOffsets, args&: srcShape, args&: srcStrides);
786 }
787
788 Operation *newOp = cloneOpWithOperandsAndTypes(builder&: rewriter, loc, op: broadcastOp,
789 operands: newSrc, resultTypes: targetType);
790
791 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
792 location: loc, args: newOp->getResult(idx: 0), args&: result, args&: offsets, args&: strides);
793 }
794
795 rewriter.replaceOp(op: broadcastOp, newValues: result);
796 return success();
797 }
798
799private:
800 vector::UnrollVectorOptions options;
801};
802
803} // namespace
804
805void mlir::vector::populateVectorUnrollPatterns(
806 RewritePatternSet &patterns, const UnrollVectorOptions &options,
807 PatternBenefit benefit) {
808 patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
809 UnrollContractionPattern, UnrollElementwisePattern,
810 UnrollReductionPattern, UnrollMultiReductionPattern,
811 UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
812 UnrollStorePattern, UnrollBroadcastPattern>(
813 arg: patterns.getContext(), args: options, args&: benefit);
814}
815

source code of mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp