1 | //===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements patterns to do vector unrolling and vector distribution. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
14 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
15 | #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" |
16 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
17 | #include "mlir/Interfaces/VectorInterfaces.h" |
18 | #include "mlir/Support/MathExtras.h" |
19 | #include "llvm/ADT/MapVector.h" |
20 | #include "llvm/ADT/STLExtras.h" |
21 | #include "llvm/Support/Debug.h" |
22 | #include <numeric> |
23 | #include <optional> |
24 | |
25 | #define DEBUG_TYPE "vector-unroll" |
26 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
27 | #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") |
28 | |
29 | using namespace mlir; |
30 | using namespace mlir::vector; |
31 | |
32 | /// Compute the indices of the slice `index` for a tranfer op. |
33 | static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets, |
34 | ArrayRef<Value> indices, |
35 | AffineMap permutationMap, |
36 | Location loc, |
37 | OpBuilder &builder) { |
38 | MLIRContext *ctx = builder.getContext(); |
39 | auto isBroadcast = [](AffineExpr expr) { |
40 | if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: expr)) |
41 | return constExpr.getValue() == 0; |
42 | return false; |
43 | }; |
44 | // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'. |
45 | SmallVector<Value> slicedIndices(indices.begin(), indices.end()); |
46 | for (const auto &dim : llvm::enumerate(First: permutationMap.getResults())) { |
47 | if (isBroadcast(dim.value())) |
48 | continue; |
49 | unsigned pos = cast<AffineDimExpr>(Val: dim.value()).getPosition(); |
50 | auto expr = getAffineDimExpr(position: 0, context: builder.getContext()) + |
51 | getAffineConstantExpr(constant: elementOffsets[dim.index()], context: ctx); |
52 | auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, result: expr); |
53 | slicedIndices[pos] = |
54 | builder.create<affine::AffineApplyOp>(loc, map, indices[pos]); |
55 | } |
56 | return slicedIndices; |
57 | } |
58 | |
59 | // Clones `op` into a new operations that takes `operands` and returns |
60 | // `resultTypes`. |
61 | static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, |
62 | Operation *op, |
63 | ArrayRef<Value> operands, |
64 | ArrayRef<Type> resultTypes) { |
65 | return builder.create(loc, op->getName().getIdentifier(), operands, |
66 | resultTypes, op->getAttrs()); |
67 | } |
68 | |
69 | /// Return the target shape for unrolling for the given `op`. Return |
70 | /// std::nullopt if the op shouldn't be or cannot be unrolled. |
71 | static std::optional<SmallVector<int64_t>> |
72 | getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) { |
73 | LDBG("" ); |
74 | LDBG("Get unroll shape for op " << op->getName().getStringRef()); |
75 | if (options.filterConstraint && failed(options.filterConstraint(op))) { |
76 | LDBG("--no filter constraint -> BAIL" ); |
77 | return std::nullopt; |
78 | } |
79 | assert(options.nativeShape && |
80 | "vector unrolling expects the native shape or native" |
81 | "shape call back function to be set" ); |
82 | auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op); |
83 | if (!unrollableVectorOp) { |
84 | LDBG("--not an unrollable op -> BAIL" ); |
85 | return std::nullopt; |
86 | } |
87 | auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll(); |
88 | if (!maybeUnrollShape) { |
89 | LDBG("--could not get shape of op " << *op << " -> BAIL" ); |
90 | return std::nullopt; |
91 | } |
92 | LLVM_DEBUG( |
93 | llvm::interleaveComma(*maybeUnrollShape, DBGS() << "--vector op shape: " ); |
94 | llvm::dbgs() << "\n" ;); |
95 | |
96 | std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op); |
97 | if (!targetShape) { |
98 | LDBG("--no unrolling target shape defined " << *op << "-> SKIP" ); |
99 | return std::nullopt; |
100 | } |
101 | LLVM_DEBUG(llvm::interleaveComma(*targetShape, DBGS() << "--target shape: " ); |
102 | llvm::dbgs() << "\n" ;); |
103 | |
104 | auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape); |
105 | if (!maybeShapeRatio) { |
106 | LDBG("--could not compute integral shape ratio -> BAIL" ); |
107 | return std::nullopt; |
108 | } |
109 | if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) { |
110 | LDBG("--no unrolling needed -> SKIP" ); |
111 | return std::nullopt; |
112 | } |
113 | LDBG("--found an integral shape ratio to unroll to -> SUCCESS" ); |
114 | return targetShape; |
115 | } |
116 | |
117 | static SmallVector<int64_t> |
118 | getUnrollOrder(unsigned numLoops, Operation *op, |
119 | const vector::UnrollVectorOptions &options) { |
120 | SmallVector<int64_t> loopOrder = |
121 | llvm::to_vector(Range: llvm::seq<int64_t>(Begin: 0, End: static_cast<int64_t>(numLoops))); |
122 | if (options.traversalOrderCallback != nullptr) { |
123 | std::optional<SmallVector<int64_t>> order = |
124 | options.traversalOrderCallback(op); |
125 | if (order) { |
126 | loopOrder = std::move(*order); |
127 | } |
128 | } |
129 | return loopOrder; |
130 | } |
131 | |
132 | namespace { |
133 | |
/// Unrolls a vector.transfer_read into multiple transfer_reads of the target
/// shape, reassembling the pieces with vector.insert_strided_slice ops.
struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    // Masked transfers are not handled by this unrolling.
    if (readOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, readOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();

    // Prepare the result vector: each sliced read is inserted into this
    // zero-initialized constant.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value> originalIndices(readOp.getIndices().begin(),
                                       readOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), readOp, options);
    // Iterate over all tiles of the original shape and emit one small
    // transfer_read per tile at the shifted indices.
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = rewriter.create<vector::TransferReadOp>(
          loc, targetType, readOp.getSource(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
          readOp.getInBoundsAttr());

      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
186 | |
187 | struct UnrollTransferWritePattern |
188 | : public OpRewritePattern<vector::TransferWriteOp> { |
189 | UnrollTransferWritePattern(MLIRContext *context, |
190 | const vector::UnrollVectorOptions &options, |
191 | PatternBenefit benefit = 1) |
192 | : OpRewritePattern<vector::TransferWriteOp>(context, benefit), |
193 | options(options) {} |
194 | |
195 | LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, |
196 | PatternRewriter &rewriter) const override { |
197 | // TODO: support 0-d corner case. |
198 | if (writeOp.getTransferRank() == 0) |
199 | return failure(); |
200 | |
201 | if (writeOp.getMask()) |
202 | return failure(); |
203 | auto targetShape = getTargetShape(options, writeOp); |
204 | if (!targetShape) |
205 | return failure(); |
206 | auto sourceVectorType = writeOp.getVectorType(); |
207 | SmallVector<int64_t> strides(targetShape->size(), 1); |
208 | Location loc = writeOp.getLoc(); |
209 | ArrayRef<int64_t> originalSize = sourceVectorType.getShape(); |
210 | SmallVector<Value> originalIndices(writeOp.getIndices().begin(), |
211 | writeOp.getIndices().end()); |
212 | SmallVector<int64_t> loopOrder = |
213 | getUnrollOrder(originalSize.size(), writeOp, options); |
214 | Value resultTensor; |
215 | for (SmallVector<int64_t> elementOffsets : |
216 | StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { |
217 | Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>( |
218 | loc, writeOp.getVector(), elementOffsets, *targetShape, strides); |
219 | SmallVector<Value> indices = |
220 | sliceTransferIndices(elementOffsets, originalIndices, |
221 | writeOp.getPermutationMap(), loc, rewriter); |
222 | Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>( |
223 | loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(), |
224 | indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr()); |
225 | // For the tensor case update the destination for the next transfer write. |
226 | if (!slicedWrite->getResults().empty()) |
227 | resultTensor = slicedWrite->getResult(0); |
228 | } |
229 | if (resultTensor) |
230 | rewriter.replaceOp(writeOp, resultTensor); |
231 | else |
232 | rewriter.eraseOp(op: writeOp); |
233 | return success(); |
234 | } |
235 | |
236 | private: |
237 | vector::UnrollVectorOptions options; |
238 | }; |
239 | |
240 | struct OffsetMapInfo { |
241 | static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; } |
242 | |
243 | static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; } |
244 | |
245 | static unsigned getHashValue(const SmallVector<int64_t> &v) { |
246 | return static_cast<unsigned>(llvm::hash_combine_range(first: v.begin(), last: v.end())); |
247 | } |
248 | |
249 | static bool isEqual(const SmallVector<int64_t> &lhs, |
250 | const SmallVector<int64_t> &rhs) { |
251 | return lhs == rhs; |
252 | } |
253 | }; |
254 | |
/// Unrolls a vector.contract into contractions of the target shape. Partial
/// accumulator tiles are cached across reduction iterations and reassembled
/// into the full result vector at the end.
struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(contractOp.getResultType());
    SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();

    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
    // Maps destination offsets to the latest partial accumulator computed for
    // that tile; MapVector keeps deterministic iteration order for reassembly.
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;

    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIteratorTypes().size(), contractOp, options);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
        slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, operand, operandOffets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);

      AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(accOffets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Save the accumulated value until all the loops are unrolled since
      // the reduction loop keeps updating the accumulator.
      accCache[dstOffets] = newOp->getResult(0);
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
347 | |
/// Unrolls a vector.multi_reduction into reductions of the target shape.
/// Partial accumulators keyed by the non-reduced (destination) offsets are
/// chained across reduced tiles, then reassembled into the result vector.
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options,
                              PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    // Maps destination offsets (non-reduced dims only) to the latest partial
    // accumulator for that tile.
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    Location loc = reductionOp.getLoc();

    // Stride of the ratios, this gives us the offsets of sliceCount in a basis
    // of multiples of the targetShape.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> operands;
      SmallVector<int64_t> operandStrides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
      operands.push_back(slicedOperand);
      // Project the tile's offsets/shape onto the non-reduced dimensions to
      // get the destination (accumulator) tile.
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      Value acc;
      SmallVector<int64_t> accStrides(destOffset.size(), 1);
      // If a version of the accumulator has already been computed, use it
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        acc = accIt->second;
      else
        acc = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
      operands.push_back(acc);
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     operands, targetType);
      Value result = newOp->getResult(0);
      accCache[destOffset] = result;
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
420 | |
421 | struct UnrollElementwisePattern : public RewritePattern { |
422 | UnrollElementwisePattern(MLIRContext *context, |
423 | const vector::UnrollVectorOptions &options, |
424 | PatternBenefit benefit = 1) |
425 | : RewritePattern(MatchAnyOpTypeTag(), benefit, context), |
426 | options(options) {} |
427 | |
428 | LogicalResult matchAndRewrite(Operation *op, |
429 | PatternRewriter &rewriter) const override { |
430 | if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) |
431 | return failure(); |
432 | auto targetShape = getTargetShape(options, op); |
433 | if (!targetShape) |
434 | return failure(); |
435 | auto dstVecType = cast<VectorType>(op->getResult(idx: 0).getType()); |
436 | SmallVector<int64_t> originalSize = |
437 | *cast<VectorUnrollOpInterface>(op).getShapeForUnroll(); |
438 | Location loc = op->getLoc(); |
439 | // Prepare the result vector. |
440 | Value result = rewriter.create<arith::ConstantOp>( |
441 | loc, dstVecType, rewriter.getZeroAttr(dstVecType)); |
442 | SmallVector<int64_t> strides(targetShape->size(), 1); |
443 | VectorType newVecType = |
444 | VectorType::get(*targetShape, dstVecType.getElementType()); |
445 | |
446 | // Create the unrolled computation. |
447 | for (SmallVector<int64_t> offsets : |
448 | StaticTileOffsetRange(originalSize, *targetShape)) { |
449 | SmallVector<Value> extractOperands; |
450 | for (OpOperand &operand : op->getOpOperands()) { |
451 | auto vecType = dyn_cast<VectorType>(operand.get().getType()); |
452 | if (!vecType) { |
453 | extractOperands.push_back(operand.get()); |
454 | continue; |
455 | } |
456 | extractOperands.push_back( |
457 | rewriter.create<vector::ExtractStridedSliceOp>( |
458 | loc, operand.get(), offsets, *targetShape, strides)); |
459 | } |
460 | Operation *newOp = cloneOpWithOperandsAndTypes( |
461 | rewriter, loc, op, extractOperands, newVecType); |
462 | result = rewriter.create<vector::InsertStridedSliceOp>( |
463 | loc, newOp->getResult(0), result, offsets, strides); |
464 | } |
465 | rewriter.replaceOp(op, newValues: result); |
466 | return success(); |
467 | } |
468 | |
469 | private: |
470 | vector::UnrollVectorOptions options; |
471 | }; |
472 | |
473 | struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> { |
474 | UnrollReductionPattern(MLIRContext *context, |
475 | const vector::UnrollVectorOptions &options, |
476 | PatternBenefit benefit = 1) |
477 | : OpRewritePattern<vector::ReductionOp>(context, benefit), |
478 | options(options) {} |
479 | |
480 | LogicalResult matchAndRewrite(vector::ReductionOp reductionOp, |
481 | PatternRewriter &rewriter) const override { |
482 | std::optional<SmallVector<int64_t>> targetShape = |
483 | getTargetShape(options, reductionOp); |
484 | if (!targetShape) |
485 | return failure(); |
486 | SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll(); |
487 | |
488 | // Create unrolled vector reduction. |
489 | Location loc = reductionOp.getLoc(); |
490 | Value accumulator = nullptr; |
491 | for (SmallVector<int64_t> offsets : |
492 | StaticTileOffsetRange(originalSize, *targetShape)) { |
493 | SmallVector<int64_t> strides(offsets.size(), 1); |
494 | Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>( |
495 | loc, reductionOp.getVector(), offsets, *targetShape, strides); |
496 | Operation *newOp = cloneOpWithOperandsAndTypes( |
497 | rewriter, loc, reductionOp, slicedOperand, reductionOp.getType()); |
498 | Value result = newOp->getResult(0); |
499 | |
500 | if (!accumulator) { |
501 | // This is the first reduction. |
502 | accumulator = result; |
503 | } else { |
504 | // On subsequent reduction, combine with the accumulator. |
505 | accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(), |
506 | accumulator, result); |
507 | } |
508 | } |
509 | |
510 | rewriter.replaceOp(reductionOp, accumulator); |
511 | return success(); |
512 | } |
513 | |
514 | private: |
515 | const vector::UnrollVectorOptions options; |
516 | }; |
517 | |
/// Unrolls a vector.transpose into transposes of the target shape, extracting
/// each source slice at permuted offsets and inserting the transposed slice
/// into the result.
struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
  UnrollTransposePattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // 0-d transposes have nothing to unroll.
    if (transposeOp.getResultVectorType().getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, transposeOp);
    if (!targetShape)
      return failure();
    auto originalVectorType = transposeOp.getResultVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = transposeOp.getLoc();
    ArrayRef<int64_t> originalSize = originalVectorType.getShape();

    // Prepare the result vector: each transposed slice is inserted into this
    // zero-initialized constant.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
    ArrayRef<int64_t> permutation = transposeOp.getPermutation();

    // Unroll the computation. The tiles iterate over the *result* shape, so
    // source offsets/shape must be mapped back through the permutation.
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t> permutedShape(elementOffsets.size());
      // Compute the source offsets and shape.
      for (auto indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
        permutedShape[indices.value()] = (*targetShape)[indices.index()];
      }
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, transposeOp.getVector(), permutedOffsets, permutedShape,
          strides);
      Value transposedSlice =
          rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, transposedSlice, result, elementOffsets, strides);
    }
    rewriter.replaceOp(transposeOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
567 | |
/// Unrolls a vector.gather into gathers of the target shape, slicing the
/// index, mask, and pass-through vectors identically for each tile.
struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
  UnrollGatherPattern(MLIRContext *context,
                      const vector::UnrollVectorOptions &options,
                      PatternBenefit benefit = 1)
      : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
  }

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    VectorType sourceVectorType = gatherOp.getVectorType();
    // 0-d gathers have nothing to unroll.
    if (sourceVectorType.getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, gatherOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = gatherOp.getLoc();
    ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();

    // Prepare the result vector: each sliced gather is inserted into this
    // zero-initialized constant.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), gatherOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      // To get the unrolled gather, extract the same slice based on the
      // decomposed shape from each of the index, mask, and pass-through
      // vectors.
      Value indexSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
      Value maskSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
      Value passThruSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides);
      auto slicedGather = rewriter.create<vector::GatherOp>(
          loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
          indexSubVec, maskSubVec, passThruSubVec);

      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, slicedGather, result, elementOffsets, strides);
    }
    rewriter.replaceOp(gatherOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
620 | |
621 | } // namespace |
622 | |
623 | void mlir::vector::populateVectorUnrollPatterns( |
624 | RewritePatternSet &patterns, const UnrollVectorOptions &options, |
625 | PatternBenefit benefit) { |
626 | patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern, |
627 | UnrollContractionPattern, UnrollElementwisePattern, |
628 | UnrollReductionPattern, UnrollMultiReductionPattern, |
629 | UnrollTransposePattern, UnrollGatherPattern>( |
630 | arg: patterns.getContext(), args: options, args&: benefit); |
631 | } |
632 | |