//===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to do vector unrolling and vector distribution.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/InterleavedRange.h"
#include <optional>

#define DEBUG_TYPE "vector-unroll"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;
using namespace mlir::vector;

30 | /// Compute the indices of the slice `index` for a transfer op. |
31 | static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets, |
32 | ArrayRef<Value> indices, |
33 | AffineMap permutationMap, |
34 | Location loc, |
35 | OpBuilder &builder) { |
36 | MLIRContext *ctx = builder.getContext(); |
37 | auto isBroadcast = [](AffineExpr expr) { |
38 | if (auto constExpr = dyn_cast<AffineConstantExpr>(Val&: expr)) |
39 | return constExpr.getValue() == 0; |
40 | return false; |
41 | }; |
42 | // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'. |
43 | SmallVector<Value> slicedIndices(indices); |
44 | for (const auto &dim : llvm::enumerate(First: permutationMap.getResults())) { |
45 | if (isBroadcast(dim.value())) |
46 | continue; |
47 | unsigned pos = cast<AffineDimExpr>(Val: dim.value()).getPosition(); |
48 | auto expr = getAffineDimExpr(position: 0, context: builder.getContext()) + |
49 | getAffineConstantExpr(constant: elementOffsets[dim.index()], context: ctx); |
50 | auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, result: expr); |
51 | slicedIndices[pos] = |
52 | builder.create<affine::AffineApplyOp>(loc, map, indices[pos]); |
53 | } |
54 | return slicedIndices; |
55 | } |
56 | |
57 | // Clones `op` into a new operations that takes `operands` and returns |
58 | // `resultTypes`. |
59 | static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, |
60 | Operation *op, |
61 | ArrayRef<Value> operands, |
62 | ArrayRef<Type> resultTypes) { |
63 | return builder.create(loc, op->getName().getIdentifier(), operands, |
64 | resultTypes, op->getAttrs()); |
65 | } |
66 | |
/// Return the target shape for unrolling for the given `op`. Return
/// std::nullopt if the op shouldn't be or cannot be unrolled.
///
/// An op is unrolled only if it (1) passes the user filter, (2) implements
/// VectorUnrollOpInterface, (3) has a native shape configured, and (4) its
/// shape is a non-trivial integral multiple of that native shape.
static std::optional<SmallVector<int64_t>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  LDBG("");
  LDBG("Get unroll shape for op "<< op->getName().getStringRef());
  // Respect the user-provided filter: ops it rejects are never unrolled.
  if (options.filterConstraint && failed(options.filterConstraint(op))) {
    LDBG("--no filter constraint -> BAIL");
    return std::nullopt;
  }
  assert(options.nativeShape &&
         "vector unrolling expects the native shape or native"
         "shape call back function to be set");
  // Only ops implementing VectorUnrollOpInterface expose a shape to unroll.
  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
  if (!unrollableVectorOp) {
    LDBG("--not an unrollable op -> BAIL");
    return std::nullopt;
  }
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape) {
    LDBG("--could not get shape of op "<< *op << " -> BAIL");
    return std::nullopt;
  }
  LDBG("--vector op shape: "<< llvm::interleaved(*maybeUnrollShape));

  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
  if (!targetShape) {
    LDBG("--no unrolling target shape defined "<< *op << "-> SKIP");
    return std::nullopt;
  }
  LDBG("--target shape: "<< llvm::interleaved(*targetShape));

  // The op shape must be an integral multiple of the target shape in every
  // dimension, otherwise unrolling cannot tile it exactly.
  auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
  if (!maybeShapeRatio) {
    LDBG("--could not compute integral shape ratio -> BAIL");
    return std::nullopt;
  }
  // An all-ones ratio means the op already matches the native shape.
  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
    LDBG("--no unrolling needed -> SKIP");
    return std::nullopt;
  }
  LDBG("--found an integral shape ratio to unroll to -> SUCCESS");
  return targetShape;
}

112 | static SmallVector<int64_t> |
113 | getUnrollOrder(unsigned numLoops, Operation *op, |
114 | const vector::UnrollVectorOptions &options) { |
115 | SmallVector<int64_t> loopOrder = |
116 | llvm::to_vector(Range: llvm::seq<int64_t>(Begin: 0, End: static_cast<int64_t>(numLoops))); |
117 | if (options.traversalOrderCallback != nullptr) { |
118 | std::optional<SmallVector<int64_t>> order = |
119 | options.traversalOrderCallback(op); |
120 | if (order) { |
121 | loopOrder = std::move(*order); |
122 | } |
123 | } |
124 | return loopOrder; |
125 | } |
126 | |
namespace {

129 | struct UnrollTransferReadPattern |
130 | : public OpRewritePattern<vector::TransferReadOp> { |
131 | UnrollTransferReadPattern(MLIRContext *context, |
132 | const vector::UnrollVectorOptions &options, |
133 | PatternBenefit benefit = 1) |
134 | : OpRewritePattern<vector::TransferReadOp>(context, benefit), |
135 | options(options) {} |
136 | |
137 | LogicalResult matchAndRewrite(vector::TransferReadOp readOp, |
138 | PatternRewriter &rewriter) const override { |
139 | // TODO: support 0-d corner case. |
140 | if (readOp.getTransferRank() == 0) |
141 | return failure(); |
142 | if (readOp.getMask()) |
143 | return failure(); |
144 | auto targetShape = getTargetShape(options, readOp); |
145 | if (!targetShape) |
146 | return failure(); |
147 | auto sourceVectorType = readOp.getVectorType(); |
148 | SmallVector<int64_t> strides(targetShape->size(), 1); |
149 | Location loc = readOp.getLoc(); |
150 | ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape(); |
151 | |
152 | // Prepare the result vector; |
153 | Value result = rewriter.create<arith::ConstantOp>( |
154 | loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType)); |
155 | auto targetType = |
156 | VectorType::get(*targetShape, sourceVectorType.getElementType()); |
157 | SmallVector<Value> originalIndices(readOp.getIndices().begin(), |
158 | readOp.getIndices().end()); |
159 | SmallVector<int64_t> loopOrder = |
160 | getUnrollOrder(originalSize.size(), readOp, options); |
161 | for (SmallVector<int64_t> elementOffsets : |
162 | StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { |
163 | SmallVector<Value> indices = |
164 | sliceTransferIndices(elementOffsets, originalIndices, |
165 | readOp.getPermutationMap(), loc, rewriter); |
166 | auto slicedRead = rewriter.create<vector::TransferReadOp>( |
167 | loc, targetType, readOp.getBase(), indices, |
168 | readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(), |
169 | readOp.getInBoundsAttr()); |
170 | |
171 | result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
172 | loc, slicedRead, result, elementOffsets, strides); |
173 | } |
174 | rewriter.replaceOp(readOp, result); |
175 | return success(); |
176 | } |
177 | |
178 | private: |
179 | vector::UnrollVectorOptions options; |
180 | }; |
181 | |
182 | struct UnrollTransferWritePattern |
183 | : public OpRewritePattern<vector::TransferWriteOp> { |
184 | UnrollTransferWritePattern(MLIRContext *context, |
185 | const vector::UnrollVectorOptions &options, |
186 | PatternBenefit benefit = 1) |
187 | : OpRewritePattern<vector::TransferWriteOp>(context, benefit), |
188 | options(options) {} |
189 | |
190 | LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, |
191 | PatternRewriter &rewriter) const override { |
192 | // TODO: support 0-d corner case. |
193 | if (writeOp.getTransferRank() == 0) |
194 | return failure(); |
195 | |
196 | if (writeOp.getMask()) |
197 | return failure(); |
198 | auto targetShape = getTargetShape(options, writeOp); |
199 | if (!targetShape) |
200 | return failure(); |
201 | auto sourceVectorType = writeOp.getVectorType(); |
202 | SmallVector<int64_t> strides(targetShape->size(), 1); |
203 | Location loc = writeOp.getLoc(); |
204 | ArrayRef<int64_t> originalSize = sourceVectorType.getShape(); |
205 | SmallVector<Value> originalIndices(writeOp.getIndices().begin(), |
206 | writeOp.getIndices().end()); |
207 | SmallVector<int64_t> loopOrder = |
208 | getUnrollOrder(originalSize.size(), writeOp, options); |
209 | Value resultTensor; |
210 | for (SmallVector<int64_t> elementOffsets : |
211 | StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { |
212 | Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
213 | loc, writeOp.getVector(), elementOffsets, *targetShape, strides); |
214 | SmallVector<Value> indices = |
215 | sliceTransferIndices(elementOffsets, originalIndices, |
216 | writeOp.getPermutationMap(), loc, rewriter); |
217 | Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>( |
218 | loc, slicedVector, resultTensor ? resultTensor : writeOp.getBase(), |
219 | indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr()); |
220 | // For the tensor case update the destination for the next transfer write. |
221 | if (!slicedWrite->getResults().empty()) |
222 | resultTensor = slicedWrite->getResult(0); |
223 | } |
224 | if (resultTensor) |
225 | rewriter.replaceOp(writeOp, resultTensor); |
226 | else |
227 | rewriter.eraseOp(op: writeOp); |
228 | return success(); |
229 | } |
230 | |
231 | private: |
232 | vector::UnrollVectorOptions options; |
233 | }; |
234 | |
235 | struct OffsetMapInfo { |
236 | static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; } |
237 | |
238 | static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; } |
239 | |
240 | static unsigned getHashValue(const SmallVector<int64_t> &v) { |
241 | return static_cast<unsigned>(llvm::hash_combine_range(R: v)); |
242 | } |
243 | |
244 | static bool isEqual(const SmallVector<int64_t> &lhs, |
245 | const SmallVector<int64_t> &rhs) { |
246 | return lhs == rhs; |
247 | } |
248 | }; |
249 | |
/// Unroll a vector.contract into contracts of the native shape. Partial
/// results sharing the same destination offsets are chained through an
/// accumulator cache and re-assembled into the full result at the end.
struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(contractOp.getResultType());
    SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();

    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
    // Cache of accumulator slices keyed by their destination offsets.
    // MapVector preserves insertion order so the re-assembly loop below is
    // deterministic.
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;

    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIteratorTypes().size(), contractOp, options);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
        slicesOperands[index] =
            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
                loc, operand, operandOffets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);

      AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(accOffets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Save the accumulated value untill all the loops are unrolled since
      // reduction loop keep updating the accumulator.
      accCache[dstOffets] = newOp->getResult(0);
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// Unroll a vector.multi_reduction into reductions of the native shape.
/// Slices that reduce into the same destination offsets share an accumulator
/// via a cache; the final result is re-assembled from the cached slices.
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options,
                              PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    // Scalar (fully-reduced) results have no vector destination to slice.
    auto resultType = reductionOp->getResult(0).getType();
    if (resultType.isIntOrFloat()) {
      return rewriter.notifyMatchFailure(reductionOp,
                                         "Unrolling scalars is not supported");
    }
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    // Cache of accumulator slices keyed by their destination offsets
    // (the offsets along non-reduced dims only).
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    Location loc = reductionOp.getLoc();

    // Stride of the ratios, this gives us the offsets of sliceCount in a basis
    // of multiples of the targetShape.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> operands;
      SmallVector<int64_t> operandStrides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getSource(), offsets, *targetShape,
              operandStrides);
      operands.push_back(slicedOperand);
      // Project the tile offsets/shape onto the non-reduced dims to obtain
      // the destination slice this tile accumulates into.
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      Value acc;
      SmallVector<int64_t> accStrides(destOffset.size(), 1);
      // If a version of the accumulator has already been computed, use it
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        acc = accIt->second;
      else
        acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
      operands.push_back(acc);
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     operands, targetType);
      Value result = newOp->getResult(0);
      // Later tiles reducing into the same destination pick this value up as
      // their accumulator.
      accCache[destOffset] = result;
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

424 | struct UnrollElementwisePattern : public RewritePattern { |
425 | UnrollElementwisePattern(MLIRContext *context, |
426 | const vector::UnrollVectorOptions &options, |
427 | PatternBenefit benefit = 1) |
428 | : RewritePattern(MatchAnyOpTypeTag(), benefit, context), |
429 | options(options) {} |
430 | |
431 | LogicalResult matchAndRewrite(Operation *op, |
432 | PatternRewriter &rewriter) const override { |
433 | if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) |
434 | return failure(); |
435 | auto targetShape = getTargetShape(options, op); |
436 | if (!targetShape) |
437 | return failure(); |
438 | auto dstVecType = cast<VectorType>(op->getResult(idx: 0).getType()); |
439 | SmallVector<int64_t> originalSize = |
440 | *cast<VectorUnrollOpInterface>(op).getShapeForUnroll(); |
441 | // Bail-out if rank(source) != rank(target). The main limitation here is the |
442 | // fact that `ExtractStridedSlice` requires the rank for the input and |
443 | // output to match. If needed, we can relax this later. |
444 | if (originalSize.size() != targetShape->size()) |
445 | return rewriter.notifyMatchFailure( |
446 | arg&: op, msg: "expected input vector rank to match target shape rank"); |
447 | Location loc = op->getLoc(); |
448 | // Prepare the result vector. |
449 | Value result = rewriter.create<arith::ConstantOp>( |
450 | loc, dstVecType, rewriter.getZeroAttr(dstVecType)); |
451 | SmallVector<int64_t> strides(targetShape->size(), 1); |
452 | VectorType newVecType = |
453 | VectorType::get(*targetShape, dstVecType.getElementType()); |
454 | |
455 | // Create the unrolled computation. |
456 | for (SmallVector<int64_t> offsets : |
457 | StaticTileOffsetRange(originalSize, *targetShape)) { |
458 | SmallVector<Value> extractOperands; |
459 | for (OpOperand &operand : op->getOpOperands()) { |
460 | auto vecType = dyn_cast<VectorType>(operand.get().getType()); |
461 | if (!vecType) { |
462 | extractOperands.push_back(operand.get()); |
463 | continue; |
464 | } |
465 | extractOperands.push_back( |
466 | rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
467 | loc, operand.get(), offsets, *targetShape, strides)); |
468 | } |
469 | Operation *newOp = cloneOpWithOperandsAndTypes( |
470 | rewriter, loc, op, extractOperands, newVecType); |
471 | result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
472 | loc, newOp->getResult(0), result, offsets, strides); |
473 | } |
474 | rewriter.replaceOp(op, newValues: result); |
475 | return success(); |
476 | } |
477 | |
478 | private: |
479 | vector::UnrollVectorOptions options; |
480 | }; |
481 | |
482 | struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> { |
483 | UnrollReductionPattern(MLIRContext *context, |
484 | const vector::UnrollVectorOptions &options, |
485 | PatternBenefit benefit = 1) |
486 | : OpRewritePattern<vector::ReductionOp>(context, benefit), |
487 | options(options) {} |
488 | |
489 | LogicalResult matchAndRewrite(vector::ReductionOp reductionOp, |
490 | PatternRewriter &rewriter) const override { |
491 | std::optional<SmallVector<int64_t>> targetShape = |
492 | getTargetShape(options, reductionOp); |
493 | if (!targetShape) |
494 | return failure(); |
495 | SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll(); |
496 | |
497 | // Create unrolled vector reduction. |
498 | Location loc = reductionOp.getLoc(); |
499 | Value accumulator = nullptr; |
500 | for (SmallVector<int64_t> offsets : |
501 | StaticTileOffsetRange(originalSize, *targetShape)) { |
502 | SmallVector<int64_t> strides(offsets.size(), 1); |
503 | Value slicedOperand = |
504 | rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
505 | loc, reductionOp.getVector(), offsets, *targetShape, strides); |
506 | Operation *newOp = cloneOpWithOperandsAndTypes( |
507 | rewriter, loc, reductionOp, slicedOperand, reductionOp.getType()); |
508 | Value result = newOp->getResult(0); |
509 | |
510 | if (!accumulator) { |
511 | // This is the first reduction. |
512 | accumulator = result; |
513 | } else { |
514 | // On subsequent reduction, combine with the accumulator. |
515 | accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(), |
516 | accumulator, result); |
517 | } |
518 | } |
519 | |
520 | rewriter.replaceOp(reductionOp, accumulator); |
521 | return success(); |
522 | } |
523 | |
524 | private: |
525 | const vector::UnrollVectorOptions options; |
526 | }; |
527 | |
528 | struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> { |
529 | UnrollTransposePattern(MLIRContext *context, |
530 | const vector::UnrollVectorOptions &options, |
531 | PatternBenefit benefit = 1) |
532 | : OpRewritePattern<vector::TransposeOp>(context, benefit), |
533 | options(options) {} |
534 | |
535 | LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, |
536 | PatternRewriter &rewriter) const override { |
537 | if (transposeOp.getResultVectorType().getRank() == 0) |
538 | return failure(); |
539 | auto targetShape = getTargetShape(options, transposeOp); |
540 | if (!targetShape) |
541 | return failure(); |
542 | auto originalVectorType = transposeOp.getResultVectorType(); |
543 | SmallVector<int64_t> strides(targetShape->size(), 1); |
544 | Location loc = transposeOp.getLoc(); |
545 | ArrayRef<int64_t> originalSize = originalVectorType.getShape(); |
546 | |
547 | // Prepare the result vector; |
548 | Value result = rewriter.create<arith::ConstantOp>( |
549 | loc, originalVectorType, rewriter.getZeroAttr(originalVectorType)); |
550 | ArrayRef<int64_t> permutation = transposeOp.getPermutation(); |
551 | |
552 | // Unroll the computation. |
553 | for (SmallVector<int64_t> elementOffsets : |
554 | StaticTileOffsetRange(originalSize, *targetShape)) { |
555 | SmallVector<int64_t> permutedOffsets(elementOffsets.size()); |
556 | SmallVector<int64_t> permutedShape(elementOffsets.size()); |
557 | // Compute the source offsets and shape. |
558 | for (auto indices : llvm::enumerate(permutation)) { |
559 | permutedOffsets[indices.value()] = elementOffsets[indices.index()]; |
560 | permutedShape[indices.value()] = (*targetShape)[indices.index()]; |
561 | } |
562 | Value slicedOperand = |
563 | rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
564 | loc, transposeOp.getVector(), permutedOffsets, permutedShape, |
565 | strides); |
566 | Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>( |
567 | loc, slicedOperand, permutation); |
568 | result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
569 | loc, transposedSlice, result, elementOffsets, strides); |
570 | } |
571 | rewriter.replaceOp(transposeOp, result); |
572 | return success(); |
573 | } |
574 | |
575 | private: |
576 | vector::UnrollVectorOptions options; |
577 | }; |
578 | |
579 | struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> { |
580 | UnrollGatherPattern(MLIRContext *context, |
581 | const vector::UnrollVectorOptions &options, |
582 | PatternBenefit benefit = 1) |
583 | : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) { |
584 | } |
585 | |
586 | LogicalResult matchAndRewrite(vector::GatherOp gatherOp, |
587 | PatternRewriter &rewriter) const override { |
588 | VectorType sourceVectorType = gatherOp.getVectorType(); |
589 | if (sourceVectorType.getRank() == 0) |
590 | return failure(); |
591 | auto targetShape = getTargetShape(options, gatherOp); |
592 | if (!targetShape) |
593 | return failure(); |
594 | SmallVector<int64_t> strides(targetShape->size(), 1); |
595 | Location loc = gatherOp.getLoc(); |
596 | ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape(); |
597 | |
598 | // Prepare the result vector; |
599 | Value result = rewriter.create<arith::ConstantOp>( |
600 | loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType)); |
601 | auto targetType = |
602 | VectorType::get(*targetShape, sourceVectorType.getElementType()); |
603 | |
604 | SmallVector<int64_t> loopOrder = |
605 | getUnrollOrder(originalSize.size(), gatherOp, options); |
606 | for (SmallVector<int64_t> elementOffsets : |
607 | StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) { |
608 | // To get the unrolled gather, extract the same slice based on the |
609 | // decomposed shape from each of the index, mask, and pass-through |
610 | // vectors. |
611 | Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
612 | loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides); |
613 | Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
614 | loc, gatherOp.getMask(), elementOffsets, *targetShape, strides); |
615 | Value passThruSubVec = |
616 | rewriter.createOrFold<vector::ExtractStridedSliceOp>( |
617 | loc, gatherOp.getPassThru(), elementOffsets, *targetShape, |
618 | strides); |
619 | auto slicedGather = rewriter.create<vector::GatherOp>( |
620 | loc, targetType, gatherOp.getBase(), gatherOp.getIndices(), |
621 | indexSubVec, maskSubVec, passThruSubVec); |
622 | |
623 | result = rewriter.createOrFold<vector::InsertStridedSliceOp>( |
624 | loc, slicedGather, result, elementOffsets, strides); |
625 | } |
626 | rewriter.replaceOp(gatherOp, result); |
627 | return success(); |
628 | } |
629 | |
630 | private: |
631 | vector::UnrollVectorOptions options; |
632 | }; |
633 | |
/// Unroll a vector.broadcast into broadcasts producing tiles of the native
/// shape, inserted into a zero-initialized result vector.
struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
  UnrollBroadcastPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::BroadcastOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, broadcastOp);
    if (!targetShape)
      return failure();

    Location loc = broadcastOp.getLoc();
    // Null when broadcasting from a scalar source.
    VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    VectorType resType = broadcastOp.getResultVectorType();
    VectorType targetType =
        resType.cloneWith(*targetShape, resType.getElementType());
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));

    SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
    SmallVector<int64_t> strides(originalShape.size(), 1);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape)) {
      Value newSrc;
      if (!srcType) {
        // Scalar to vector broadcast.
        newSrc = broadcastOp.getSource();
      } else {
        // Vector to vector broadcast.
        // The source aligns with the trailing `rank` dims of the result, so
        // slice offsets/shape/strides from the tail.
        int64_t rank = srcType.getRank();
        SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
        SmallVector<int64_t> srcShape(targetShape->end() - rank,
                                      targetShape->end());
        SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
        // adjust the offset and shape for src if the corresponding dim is 1.
        for (int64_t i = 0; i < rank; ++i) {
          if (srcType.getDimSize(i) == 1) {
            srcOffsets[i] = 0;
            srcShape[i] = 1;
          }
        }
        newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
      }

      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp,
                                                     newSrc, targetType);

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }

    rewriter.replaceOp(broadcastOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

} // namespace

699 | void mlir::vector::populateVectorUnrollPatterns( |
700 | RewritePatternSet &patterns, const UnrollVectorOptions &options, |
701 | PatternBenefit benefit) { |
702 | patterns |
703 | .add<UnrollTransferReadPattern, UnrollTransferWritePattern, |
704 | UnrollContractionPattern, UnrollElementwisePattern, |
705 | UnrollReductionPattern, UnrollMultiReductionPattern, |
706 | UnrollTransposePattern, UnrollGatherPattern, UnrollBroadcastPattern>( |
707 | arg: patterns.getContext(), args: options, args&: benefit); |
708 | } |
709 |
Definitions
- sliceTransferIndices
- cloneOpWithOperandsAndTypes
- getTargetShape
- getUnrollOrder
- UnrollTransferReadPattern
- UnrollTransferReadPattern
- matchAndRewrite
- UnrollTransferWritePattern
- UnrollTransferWritePattern
- matchAndRewrite
- OffsetMapInfo
- getEmptyKey
- getTombstoneKey
- getHashValue
- isEqual
- UnrollContractionPattern
- UnrollContractionPattern
- matchAndRewrite
- UnrollMultiReductionPattern
- UnrollMultiReductionPattern
- matchAndRewrite
- UnrollElementwisePattern
- UnrollElementwisePattern
- matchAndRewrite
- UnrollReductionPattern
- UnrollReductionPattern
- matchAndRewrite
- UnrollTransposePattern
- UnrollTransposePattern
- matchAndRewrite
- UnrollGatherPattern
- UnrollGatherPattern
- matchAndRewrite
- UnrollBroadcastPattern
- UnrollBroadcastPattern
- matchAndRewrite
Improve your Profiling and Debugging skills
Find out more