1 | //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the linalg dialect Vectorization transformations. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | #include "mlir/Dialect/Affine/Utils.h" |
13 | |
14 | #include "mlir/Analysis/SliceAnalysis.h" |
15 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
17 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
18 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
19 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
20 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
21 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
22 | #include "mlir/Dialect/Tensor/Utils/Utils.h" |
23 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
24 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
25 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
26 | #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" |
27 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h" |
28 | #include "mlir/IR/AffineExpr.h" |
29 | #include "mlir/IR/Builders.h" |
30 | #include "mlir/IR/BuiltinTypeInterfaces.h" |
31 | #include "mlir/IR/BuiltinTypes.h" |
32 | #include "mlir/IR/OpDefinition.h" |
33 | #include "mlir/IR/PatternMatch.h" |
34 | #include "mlir/Support/LLVM.h" |
35 | #include "mlir/Transforms/RegionUtils.h" |
36 | #include "llvm/ADT/STLExtras.h" |
37 | #include "llvm/ADT/Sequence.h" |
38 | #include "llvm/ADT/SmallVector.h" |
39 | #include "llvm/ADT/TypeSwitch.h" |
40 | #include "llvm/ADT/iterator_range.h" |
41 | #include "llvm/Support/Debug.h" |
42 | #include "llvm/Support/MathExtras.h" |
43 | #include "llvm/Support/raw_ostream.h" |
44 | #include <optional> |
45 | #include <type_traits> |
46 | |
47 | using namespace mlir; |
48 | using namespace mlir::linalg; |
49 | |
50 | #define DEBUG_TYPE "linalg-vectorization" |
51 | |
52 | #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") |
53 | #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") |
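// Illustrative usage (not part of the original file): `LDBG("shape: " << ty);`
// prints "[linalg-vectorization] shape: ..." only in assert-enabled builds and
// only when `-debug-only=linalg-vectorization` is passed.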
54 | |
55 | /// Try to vectorize `convOp` as a convolution. |
56 | static FailureOr<Operation *> |
57 | vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, |
58 | ArrayRef<int64_t> inputVecSizes = {}, |
59 | ArrayRef<bool> inputVecScalableFlags = {}, |
60 | bool flatten1DDepthwiseConv = false); |
61 | |
62 | /// Return the unique instance of OpType in `block` if it is indeed unique. |
63 | /// Return null if there are zero or multiple instances. |
64 | template <typename OpType> |
65 | static OpType getSingleOpOfType(Block &block) { |
66 | OpType res; |
67 | block.walk([&](OpType op) { |
68 | if (res) { |
69 | res = nullptr; |
70 | return WalkResult::interrupt(); |
71 | } |
72 | res = op; |
73 | return WalkResult::advance(); |
74 | }); |
75 | return res; |
76 | } |
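// Usage sketch (hypothetical call site): `getSingleOpOfType<linalg::YieldOp>(block)`
// returns the block's unique linalg.yield, or a null op when the block holds
// zero or several linalg.yield ops.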
77 | |
78 | /// Helper function to extract the input slices after filter is unrolled along |
79 | /// kw. |
80 | static SmallVector<Value> |
81 | extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, |
82 | int64_t nSize, int64_t wSize, int64_t cSize, |
83 | int64_t kwSize, int strideW, int dilationW, |
84 | int64_t wSizeStep, bool isSingleChanneled) { |
85 | SmallVector<Value> result; |
86 | if (isSingleChanneled) { |
87 | // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled |
88 | // convolution. |
89 | SmallVector<int64_t> sizes{wSizeStep}; |
90 | SmallVector<int64_t> strides{1}; |
91 | for (int64_t kw = 0; kw < kwSize; ++kw) { |
92 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
93 | result.push_back(rewriter.create<vector::ExtractStridedSliceOp>( |
94 | loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides)); |
95 | } |
96 | } |
97 | } else { |
98 | // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0] |
99 | // for channeled convolution. |
100 | SmallVector<int64_t> sizes{nSize, wSizeStep, cSize}; |
101 | SmallVector<int64_t> strides{1, 1, 1}; |
102 | for (int64_t kw = 0; kw < kwSize; ++kw) { |
103 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
104 | result.push_back(rewriter.create<vector::ExtractStridedSliceOp>( |
105 | loc, input, |
106 | /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0}, |
107 | sizes, strides)); |
108 | } |
109 | } |
110 | } |
111 | return result; |
112 | } |
113 | |
114 | /// Helper function to extract the filter slices after filter is unrolled along |
115 | /// kw. |
116 | static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter, |
117 |                                                    Location loc, Value filter, |
118 |                                                    int64_t kwSize) { |
119 | SmallVector<Value> result; |
120 | // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for |
121 |   // non-channeled convolution] @ [kw]. |
122 | for (int64_t kw = 0; kw < kwSize; ++kw) { |
123 | result.push_back(rewriter.create<vector::ExtractOp>( |
124 | loc, filter, /*offsets=*/ArrayRef<int64_t>{kw})); |
125 | } |
126 | return result; |
127 | } |
128 | |
129 | /// Helper function to extract the result slices after filter is unrolled along |
130 | /// kw. |
131 | static SmallVector<Value> |
132 | extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, |
133 | int64_t nSize, int64_t wSize, int64_t fSize, |
134 | int64_t wSizeStep, bool isSingleChanneled) { |
135 | SmallVector<Value> result; |
136 | if (isSingleChanneled) { |
137 | // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution. |
138 | SmallVector<int64_t> sizes{wSizeStep}; |
139 | SmallVector<int64_t> strides{1}; |
140 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
141 | result.push_back(rewriter.create<vector::ExtractStridedSliceOp>( |
142 | loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides)); |
143 | } |
144 | } else { |
145 | // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled |
146 | // convolution. |
147 | SmallVector<int64_t> sizes{nSize, wSizeStep, fSize}; |
148 | SmallVector<int64_t> strides{1, 1, 1}; |
149 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
150 | result.push_back(rewriter.create<vector::ExtractStridedSliceOp>( |
151 | loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides)); |
152 | } |
153 | } |
154 | return result; |
155 | } |
156 | |
157 | /// Helper function to insert the computed result slices. |
158 | static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, |
159 | Value res, int64_t wSize, int64_t wSizeStep, |
160 | SmallVectorImpl<Value> &resVals, |
161 | bool isSingleChanneled) { |
162 | |
163 | if (isSingleChanneled) { |
164 | // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution. |
165 | // This does not depend on kw. |
166 | SmallVector<int64_t> strides{1}; |
167 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
168 | res = rewriter.create<vector::InsertStridedSliceOp>( |
169 | loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides); |
170 | } |
171 | } else { |
172 | // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled |
173 | // convolution. This does not depend on kw. |
174 | SmallVector<int64_t> strides{1, 1, 1}; |
175 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
176 | res = rewriter.create<vector::InsertStridedSliceOp>( |
177 | loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, |
178 | strides); |
179 | } |
180 | } |
181 | return res; |
182 | } |
183 | |
184 | /// Contains the vectorization state and related methods used across the |
185 | /// vectorization process of a given operation. |
186 | struct VectorizationState { |
187 | VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {} |
188 | |
189 | /// Initializes the vectorization state, including the computation of the |
190 | /// canonical vector shape for vectorization. |
191 | LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, |
192 | ArrayRef<int64_t> inputVectorSizes, |
193 | ArrayRef<bool> inputScalableVecDims); |
194 | |
195 | /// Returns the canonical vector shape used to vectorize the iteration space. |
196 | ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; } |
197 | |
198 | /// Returns a vector type of the provided `elementType` with the canonical |
199 | /// vector shape and the corresponding fixed/scalable dimensions bit. If |
200 | /// `dimPermutation` is provided, the canonical vector dimensions are permuted |
201 | /// accordingly. |
202 | VectorType getCanonicalVecType( |
203 | Type elementType, |
204 | std::optional<AffineMap> dimPermutation = std::nullopt) const { |
205 | SmallVector<int64_t> vectorShape; |
206 | SmallVector<bool> scalableDims; |
207 | if (dimPermutation.has_value()) { |
208 | vectorShape = |
209 |           applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape); |
210 |       scalableDims = |
211 |           applyPermutationMap<bool>(*dimPermutation, scalableVecDims); |
212 |     } else { |
213 |       vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end()); |
214 |       scalableDims.append(scalableVecDims.begin(), scalableVecDims.end()); |
215 | } |
216 | |
217 | return VectorType::get(vectorShape, elementType, scalableDims); |
218 | } |
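  // Illustrative example (values assumed, not from the source): with
  // canonicalVecShape = [8, 16, 4] and dimPermutation = (d0, d1, d2) -> (d2, d0),
  // applyPermutationMap selects shape[2] and shape[0], so the returned type is
  // vector<4x8xElementType>.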
219 | |
220 | /// Masks an operation with the canonical vector mask if the operation needs |
221 | /// masking. Returns the masked operation or the original operation if masking |
222 | /// is not needed. If provided, the canonical mask for this operation is |
223 | /// permuted using `maybeMaskingMap`. |
224 | Operation * |
225 | maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp, |
226 | std::optional<AffineMap> maybeMaskingMap = std::nullopt); |
227 | |
228 | private: |
229 | /// Initializes the iteration space static sizes using the Linalg op |
230 | /// information. This may become more complicated in the future. |
231 | void initIterSpaceStaticSizes(LinalgOp linalgOp) { |
232 | iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges()); |
233 | } |
234 | |
235 | /// Generates 'arith.constant' and 'tensor/memref.dim' operations for |
236 | /// all the static and dynamic dimensions of the iteration space to be |
237 | /// vectorized and store them in `iterSpaceValueSizes`. |
238 | LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter, |
239 | LinalgOp linalgOp); |
240 | |
241 | /// Create or retrieve an existing mask value to mask `opToMask` in the |
242 | /// canonical vector iteration space. If `maybeMaskingMap` the mask is |
243 | /// permuted using that permutation map. If a new mask is created, it will be |
244 | /// cached for future users. |
245 | Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask, |
246 | LinalgOp linalgOp, |
247 | std::optional<AffineMap> maybeMaskingMap); |
248 | |
249 | // Holds the compile-time static sizes of the iteration space to vectorize. |
250 | // Dynamic dimensions are represented using ShapedType::kDynamic. |
251 | SmallVector<int64_t> iterSpaceStaticSizes; |
252 | |
253 | /// Holds the value sizes of the iteration space to vectorize. Static |
254 | /// dimensions are represented by 'arith.constant' and dynamic |
255 | /// dimensions by 'tensor/memref.dim'. |
256 | SmallVector<Value> iterSpaceValueSizes; |
257 | |
258 | /// Holds the canonical vector shape used to vectorize the iteration space. |
259 | SmallVector<int64_t> canonicalVecShape; |
260 | |
261 | /// Holds the vector dimensions that are scalable in the canonical vector |
262 | /// shape. |
263 | SmallVector<bool> scalableVecDims; |
264 | |
265 | /// Holds the active masks for permutations of the canonical vector iteration |
266 | /// space. |
267 | DenseMap<AffineMap, Value> activeMaskCache; |
268 | |
269 | /// Global vectorization guard for the incoming rewriter. It's initialized |
270 | /// when the vectorization state is initialized. |
271 | OpBuilder::InsertionGuard rewriterGuard; |
272 | }; |
273 | |
274 | LogicalResult |
275 | VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter, |
276 | LinalgOp linalgOp) { |
277 | // TODO: Support 0-d vectors. |
278 | for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) { |
279 | if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) { |
280 | // Create constant index op for static dimensions. |
281 |       iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>( |
282 | linalgOp.getLoc(), iterSpaceStaticSizes[vecDim])); |
283 | continue; |
284 | } |
285 | |
286 | // Find an operand defined on this dimension of the iteration space to |
287 | // extract the runtime dimension size. |
288 | Value operand; |
289 | unsigned operandDimPos; |
290 | if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand, |
291 | operandDimPos))) |
292 | return failure(); |
293 | |
294 | Value dynamicDim = linalgOp.hasPureTensorSemantics() |
295 | ? (Value)rewriter.create<tensor::DimOp>( |
296 | linalgOp.getLoc(), operand, operandDimPos) |
297 | : (Value)rewriter.create<memref::DimOp>( |
298 | linalgOp.getLoc(), operand, operandDimPos); |
299 |     iterSpaceValueSizes.push_back(dynamicDim); |
300 | } |
301 | |
302 | return success(); |
303 | } |
304 | |
305 | /// Initializes the vectorization state, including the computation of the |
306 | /// canonical vector shape for vectorization. |
307 | // TODO: Move this to the constructor when we can remove the failure cases. |
308 | LogicalResult |
309 | VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp, |
310 | ArrayRef<int64_t> inputVectorSizes, |
311 | ArrayRef<bool> inputScalableVecDims) { |
312 | // Initialize the insertion point. |
313 | rewriter.setInsertionPoint(linalgOp); |
314 | |
315 | if (!inputVectorSizes.empty()) { |
316 | // Get the canonical vector shape from the input vector sizes provided. This |
317 | // path should be taken to vectorize code with dynamic shapes and when using |
318 | // vector sizes greater than the iteration space sizes. |
319 |     canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end()); |
320 |     scalableVecDims.append(inputScalableVecDims.begin(), |
321 |                            inputScalableVecDims.end()); |
322 | } else { |
323 | // Compute the canonical vector shape from the operation shape. If there are |
324 | // dynamic shapes, the operation won't be vectorized. We assume all the |
325 | // vector dimensions are fixed. |
326 | canonicalVecShape = linalgOp.getStaticLoopRanges(); |
327 | scalableVecDims.append(linalgOp.getNumLoops(), false); |
328 | } |
329 | |
330 | LDBG("Canonical vector shape: " ); |
331 | LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs())); |
332 | LLVM_DEBUG(llvm::dbgs() << "\n" ); |
333 | LDBG("Scalable vector dims: " ); |
334 | LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs())); |
335 | LLVM_DEBUG(llvm::dbgs() << "\n" ); |
336 | |
337 | if (ShapedType::isDynamicShape(canonicalVecShape)) |
338 | return failure(); |
339 | |
340 | // Initialize iteration space static sizes. |
341 |   initIterSpaceStaticSizes(linalgOp); |
342 | |
343 | // Generate 'arith.constant' and 'tensor/memref.dim' operations for |
344 | // all the static and dynamic dimensions of the iteration space, needed to |
345 | // compute a mask during vectorization. |
346 |   if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp))) |
347 | return failure(); |
348 | |
349 | return success(); |
350 | } |
351 | |
352 | /// Create or retrieve an existing mask value to mask `opToMask` in the |
353 | /// canonical vector iteration space. If `maybeMaskingMap` the mask is permuted |
354 | /// using that permutation map. If a new mask is created, it will be cached for |
355 | /// future users. |
356 | Value VectorizationState::getOrCreateMaskFor( |
357 | RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp, |
358 | std::optional<AffineMap> maybeMaskingMap) { |
359 | // No mask is needed if the operation is not maskable. |
360 | auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask); |
361 | if (!maskableOp) |
362 | return Value(); |
363 | |
364 | assert(!maskableOp.isMasked() && |
365 | "Masking an operation that is already masked" ); |
366 | |
367 | // If no masking map was provided, use an identity map with the loop dims. |
368 | assert((!maybeMaskingMap || *maybeMaskingMap) && |
369 | "Unexpected null mask permutation map" ); |
370 | AffineMap maskingMap = |
371 | maybeMaskingMap ? *maybeMaskingMap |
372 | : AffineMap::getMultiDimIdentityMap( |
373 |                             linalgOp.getNumLoops(), rewriter.getContext()); |
374 | |
375 | LDBG("Masking map: " << maskingMap << "\n" ); |
376 | |
377 | // Return the active mask for the masking map of this operation if it was |
378 | // already created. |
379 |   auto activeMaskIt = activeMaskCache.find(maskingMap); |
380 | if (activeMaskIt != activeMaskCache.end()) { |
381 | Value mask = activeMaskIt->second; |
382 | LDBG("Reusing mask: " << mask << "\n" ); |
383 | return mask; |
384 | } |
385 | |
386 | // Compute permuted projection of the iteration space to be masked and the |
387 | // corresponding mask shape. If the resulting iteration space dimensions are |
388 | // static and identical to the mask shape, masking is not needed for this |
389 | // operation. |
390 | // TODO: Improve this check. Only projected permutation indexing maps are |
391 | // supported. |
392 | SmallVector<int64_t> permutedStaticSizes = |
393 |       applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes); |
394 | auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap); |
395 | auto maskShape = maskType.getShape(); |
396 | |
397 | LDBG("Mask shape: " ); |
398 | LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs())); |
399 | LLVM_DEBUG(llvm::dbgs() << "\n" ); |
400 | |
401 | if (permutedStaticSizes == maskShape) { |
402 | LDBG("Masking is not needed for masking map: " << maskingMap << "\n" ); |
403 | activeMaskCache[maskingMap] = Value(); |
404 | return Value(); |
405 | } |
406 | |
407 | // Permute the iteration space value sizes to compute the mask upper bounds. |
408 | SmallVector<Value> upperBounds = |
409 |       applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes)); |
410 | assert(!maskShape.empty() && !upperBounds.empty() && |
411 | "Masked 0-d vectors are not supported yet" ); |
412 | |
413 | // Create the mask based on the dimension values. |
414 | Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(), |
415 | maskType, upperBounds); |
416 | LDBG("Creating new mask: " << mask << "\n" ); |
417 | activeMaskCache[maskingMap] = mask; |
418 | return mask; |
419 | } |
420 | |
421 | /// Masks an operation with the canonical vector mask if the operation needs |
422 | /// masking. Returns the masked operation or the original operation if masking |
423 | /// is not needed. If provided, the canonical mask for this operation is |
424 | /// permuted using `maybeMaskingMap`. |
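/// Illustrative example (shapes assumed, not from the source): for an op with
/// iteration space (8, ?) vectorized with canonical shape [8, 32], a transfer
/// op indexed by (d0, d1) is rewritten as
///   %res = vector.mask %m { vector.transfer_read ... } : vector<8x32xi1> -> ...
/// where %m = vector.create_mask %c8, %dyn_size : vector<8x32xi1> bounds the
/// dynamic dimension.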
425 | Operation * |
426 | VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask, |
427 | LinalgOp linalgOp, |
428 | std::optional<AffineMap> maybeMaskingMap) { |
429 | LDBG("Trying to mask: " << *opToMask << "\n" ); |
430 | |
431 | // Create or retrieve mask for this operation. |
432 | Value mask = |
433 |       getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap); |
434 | |
435 | if (!mask) { |
436 | LDBG("No mask required\n" ); |
437 | return opToMask; |
438 | } |
439 | |
440 | // Wrap the operation with a new `vector.mask` and update D-U chain. |
441 | assert(opToMask && "Expected a valid operation to mask" ); |
442 | auto maskOp = cast<vector::MaskOp>( |
443 | mlir::vector::maskOperation(rewriter, opToMask, mask)); |
444 | Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back(); |
445 | |
446 |   for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults())) |
447 | rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx), |
448 | maskOpTerminator); |
449 | |
450 | LDBG("Masked operation: " << *maskOp << "\n" ); |
451 | return maskOp; |
452 | } |
453 | |
454 | /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a |
455 | /// projectedPermutation, compress the unused dimensions to serve as a |
456 | /// permutation_map for a vector transfer operation. |
457 | /// For example, given a linalg op such as: |
458 | /// |
459 | /// ``` |
460 | /// %0 = linalg.generic { |
461 | /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>, |
462 | /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)> |
463 | /// } |
464 | /// ins(%0 : tensor<2x3x4xf32>) |
465 | /// outs(%1 : tensor<5x6xf32>) |
466 | /// ``` |
467 | /// |
468 | /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine |
469 | /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second |
470 | /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`. |
471 | static AffineMap reindexIndexingMap(AffineMap map) { |
472 | assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) && |
473 | "expected projected permutation" ); |
474 | auto res = compressUnusedDims(map); |
475 | assert(res.getNumDims() == res.getNumResults() && |
476 | "expected reindexed map with same number of dims and results" ); |
477 | return res; |
478 | } |
479 | |
480 | /// Helper enum to represent conv1d input traversal order. |
481 | enum class Conv1DOpOrder { |
482 | W, // Corresponds to non-channeled 1D convolution operation. |
483 | Ncw, // Corresponds to operation that traverses the input in (n, c, w) order. |
484 | Nwc // Corresponds to operation that traverses the input in (n, w, c) order. |
485 | }; |
486 | |
487 | /// Helper data structure to represent the result of vectorization. |
488 | /// In certain specific cases, like terminators, we do not want to propagate. |
489 | enum VectorizationStatus { |
490 | /// Op failed to vectorize. |
491 | Failure = 0, |
492 | /// Op vectorized and custom function took care of replacement logic |
493 | NoReplace, |
494 | /// Op vectorized into a new Op whose results will replace original Op's |
495 | /// results. |
496 | NewOp |
497 | // TODO: support values if Op vectorized to Many-Ops whose results we need to |
498 | // aggregate for replacement. |
499 | }; |
500 | struct VectorizationResult { |
501 | /// Return status from vectorizing the current op. |
502 | enum VectorizationStatus status = VectorizationStatus::Failure; |
503 | /// New vectorized operation to replace the current op. |
504 | /// Replacement behavior is specified by `status`. |
505 | Operation *newOp; |
506 | }; |
507 | |
508 | std::optional<vector::CombiningKind> |
509 | mlir::linalg::getCombinerOpKind(Operation *combinerOp) { |
510 | using ::mlir::vector::CombiningKind; |
511 | |
512 | if (!combinerOp) |
513 | return std::nullopt; |
514 | return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp) |
515 | .Case<arith::AddIOp, arith::AddFOp>( |
516 | [&](auto op) { return CombiningKind::ADD; }) |
517 | .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; }) |
518 | .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; }) |
519 | .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; }) |
520 | .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; }) |
521 | .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; }) |
522 | .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; }) |
523 | .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; }) |
524 | .Case<arith::MulIOp, arith::MulFOp>( |
525 | [&](auto op) { return CombiningKind::MUL; }) |
526 | .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; }) |
527 | .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; }) |
528 | .Default([&](auto op) { return std::nullopt; }); |
529 | } |
530 | |
531 | /// Check whether `outputOperand` is a reduction with a single combiner |
532 | /// operation. Return the combiner operation of the reduction. Return |
533 | /// nullptr otherwise. Multiple reduction operations would impose an |
534 | /// ordering between reduction dimensions and is currently unsupported in |
535 | /// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) != |
536 | /// max(min(X)) |
537 | // TODO: use in LinalgOp verification, there is a circular dependency atm. |
538 | static Operation *matchLinalgReduction(OpOperand *outputOperand) { |
539 | auto linalgOp = cast<LinalgOp>(outputOperand->getOwner()); |
540 | unsigned outputPos = |
541 | outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs(); |
542 | // Only single combiner operations are supported for now. |
543 | SmallVector<Operation *, 4> combinerOps; |
544 | if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) || |
545 | combinerOps.size() != 1) |
546 | return nullptr; |
547 | |
548 | // Return the combiner operation. |
549 | return combinerOps[0]; |
550 | } |
551 | |
552 | /// Broadcast `value` to a vector of `shape` if possible. Return value |
553 | /// otherwise. |
554 | static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) { |
555 | auto dstVecType = dyn_cast<VectorType>(dstType); |
556 | // If no shape to broadcast to, just return `value`. |
557 | if (dstVecType.getRank() == 0) |
558 | return value; |
559 |   if (vector::isBroadcastableTo(value.getType(), dstVecType) != |
560 | vector::BroadcastableToResult::Success) |
561 | return value; |
562 | Location loc = b.getInsertionPoint()->getLoc(); |
563 | return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value); |
564 | } |
565 | |
566 | /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This |
567 | /// assumes that `reductionOp` has two operands and one of them is the reduction |
568 | /// initial value. |
569 | // Note: this is a true builder that notifies the OpBuilder listener. |
570 | // TODO: Consider moving as a static helper on the ReduceOp. |
571 | static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, |
572 | Value valueToReduce, Value acc, |
573 | ArrayRef<bool> dimsToMask) { |
574 | auto maybeKind = getCombinerOpKind(reduceOp); |
575 | assert(maybeKind && "Failed precondition: could not get reduction kind" ); |
576 | return b.create<vector::MultiDimReductionOp>( |
577 | reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind); |
578 | } |
579 | |
580 | static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) { |
581 | return llvm::to_vector( |
582 | llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator)); |
583 | } |
584 | |
585 | /// Build a vector.transfer_write of `value` into `outputOperand` at indices set |
586 | /// to all `0`; where `outputOperand` is an output operand of the LinalgOp |
587 | /// currently being vectorized. If `dest` has null rank, build a memref.store. |
588 | /// Return the produced value or null if no value is produced. |
589 | // Note: this is a true builder that notifies the OpBuilder listener. |
590 | // TODO: Consider moving as a static helper on the ReduceOp. |
591 | static Value buildVectorWrite(RewriterBase &rewriter, Value value, |
592 | OpOperand *outputOperand, |
593 | VectorizationState &state) { |
594 | Location loc = value.getLoc(); |
595 | auto linalgOp = cast<LinalgOp>(outputOperand->getOwner()); |
596 | AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand); |
597 | |
598 | // Compute the vector type of the value to store. This type should be an |
599 | // identity or projection of the canonical vector type without any permutation |
600 | // applied, given that any permutation in a transfer write happens as part of |
601 | // the write itself. |
602 | AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap( |
603 |       opOperandMap.getContext(), opOperandMap.getNumInputs(), |
604 |       [&](AffineDimExpr dimExpr) -> bool { |
605 |         return llvm::is_contained(opOperandMap.getResults(), dimExpr); |
606 | }); |
607 | auto vectorType = state.getCanonicalVecType( |
608 |       getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap); |
609 | |
610 | Operation *write; |
611 | if (vectorType.getRank() > 0) { |
612 |     AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap)); |
613 |     SmallVector<Value> indices(linalgOp.getRank(outputOperand), |
614 |                                rewriter.create<arith::ConstantIndexOp>(loc, 0)); |
615 | value = broadcastIfNeeded(rewriter, value, vectorType); |
616 | assert(value.getType() == vectorType && "Incorrect type" ); |
617 | write = rewriter.create<vector::TransferWriteOp>( |
618 | loc, value, outputOperand->get(), indices, writeMap); |
619 | } else { |
620 | // 0-d case is still special: do not invert the reindexing writeMap. |
621 | if (!isa<VectorType>(value.getType())) |
622 | value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value); |
623 | assert(value.getType() == vectorType && "Incorrect type" ); |
624 | write = rewriter.create<vector::TransferWriteOp>( |
625 | loc, value, outputOperand->get(), ValueRange{}); |
626 | } |
627 | |
628 |   write = state.maskOperation(rewriter, write, linalgOp, opOperandMap); |
629 | |
630 | // If masked, set in-bounds to true. Masking guarantees that the access will |
631 | // be in-bounds. |
632 | if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) { |
633 | auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp()); |
634 | SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true); |
635 | maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds)); |
636 | } |
637 | |
638 | LDBG("vectorized op: " << *write << "\n" ); |
639 | if (!write->getResults().empty()) |
640 |     return write->getResult(0); |
641 | return Value(); |
642 | } |
643 | |
644 | // Custom vectorization precondition function type. This is intended to be used |
645 | // with CustomVectorizationHook. Returns success if the corresponding custom |
646 | // hook can vectorize the op. |
647 | using CustomVectorizationPrecondition = |
648 | std::function<LogicalResult(Operation *, bool)>; |
649 | |
650 | // Custom vectorization function type. Produce a vector form of Operation* |
651 | // assuming all its vectorized operands are already in the IRMapping. |
652 | // Return nullptr if the Operation cannot be vectorized. |
653 | using CustomVectorizationHook = |
654 | std::function<VectorizationResult(Operation *, const IRMapping &)>; |
655 | |
656 | /// Helper function to vectorize the terminator of a `linalgOp`. New result |
657 | /// vector values are appended to `newResults`. Return |
658 | /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it |
659 | /// should not try to map produced operations and instead return the results |
660 | /// using the `newResults` vector making them available to the vectorization |
661 | /// algorithm for RAUW. This function is meant to be used as a |
662 | /// CustomVectorizationHook. |
663 | static VectorizationResult |
664 | vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, |
665 | const IRMapping &bvm, VectorizationState &state, |
666 | LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) { |
667 | auto yieldOp = dyn_cast<linalg::YieldOp>(op); |
668 | if (!yieldOp) |
669 |     return VectorizationResult{VectorizationStatus::Failure, nullptr}; |
670 | for (const auto &output : llvm::enumerate(yieldOp.getValues())) { |
671 | // TODO: Scan for an opportunity for reuse. |
672 | // TODO: use a map. |
673 | Value vectorValue = bvm.lookup(output.value()); |
674 | Value newResult = |
675 | buildVectorWrite(rewriter, vectorValue, |
676 | linalgOp.getDpsInitOperand(output.index()), state); |
677 | if (newResult) |
678 | newResults.push_back(newResult); |
679 | } |
680 | |
681 |   return VectorizationResult{VectorizationStatus::NoReplace, nullptr}; |
682 | } |
683 | |
684 | /// Helper function to vectorize the index operations of a `linalgOp`. Return |
685 | /// VectorizationStatus::NewOp to signal the vectorization algorithm that it |
686 | /// should map the produced operations. This function is meant to be used as a |
687 | /// CustomVectorizationHook. |
688 | static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter, |
689 | VectorizationState &state, |
690 | Operation *op, |
691 | LinalgOp linalgOp) { |
692 | IndexOp indexOp = dyn_cast<linalg::IndexOp>(op); |
693 | if (!indexOp) |
694 |     return VectorizationResult{VectorizationStatus::Failure, nullptr}; |
695 | auto loc = indexOp.getLoc(); |
696 | // Compute the static loop sizes of the index op. |
697 | auto targetShape = state.getCanonicalVecShape(); |
698 | // Compute a one-dimensional index vector for the index op dimension. |
699 | auto constantSeq = |
700 | llvm::to_vector(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()])); |
701 | auto indexSteps = rewriter.create<arith::ConstantOp>( |
702 | loc, rewriter.getIndexVectorAttr(constantSeq)); |
703 | // Return the one-dimensional index vector if it lives in the trailing |
704 | // dimension of the iteration space since the vectorization algorithm in this |
705 | // case can handle the broadcast. |
706 | if (indexOp.getDim() == targetShape.size() - 1) |
707 | return VectorizationResult{VectorizationStatus::NewOp, indexSteps}; |
708 | // Otherwise permute the targetShape to move the index dimension last, |
709 | // broadcast the one-dimensional index vector to the permuted shape, and |
710 | // finally transpose the broadcasted index vector to undo the permutation. |
711 | auto permPattern = |
712 |       llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size())); |
713 | std::swap(permPattern[indexOp.getDim()], permPattern.back()); |
714 | auto permMap = |
715 | AffineMap::getPermutationMap(permPattern, linalgOp.getContext()); |
716 | |
717 | auto broadCastOp = rewriter.create<vector::BroadcastOp>( |
718 | loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap), |
719 | indexSteps); |
720 | SmallVector<int64_t> transposition = |
721 | llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops())); |
722 | std::swap(transposition.back(), transposition[indexOp.getDim()]); |
723 | auto transposeOp = |
724 | rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition); |
725 | return VectorizationResult{VectorizationStatus::NewOp, transposeOp}; |
726 | } |
727 | |
728 | /// Helper function to check if the tensor.extract can be vectorized by the |
729 | /// custom hook vectorizeTensorExtract. |
730 | static LogicalResult |
731 | tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) { |
732 |   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op); |
733 | if (!extractOp) |
734 | return failure(); |
735 | |
736 | if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract) |
737 | return failure(); |
738 | |
739 | // Check the index type, but only for non 0-d tensors (for which we do need |
740 | // access indices). |
741 | if (not extractOp.getIndices().empty()) { |
742 | if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType())) |
743 | return failure(); |
744 | } |
745 | |
746 | if (llvm::any_of(extractOp->getResultTypes(), [](Type type) { |
747 | return !VectorType::isValidElementType(type); |
748 | })) { |
749 | return failure(); |
750 | } |
751 | |
752 | return success(); |
753 | } |
754 | |
755 | /// Calculates the offsets (`$index_vec`) for `vector.gather` operations |
756 | /// generated from `tensor.extract`. The offset is calculated as follows |
757 | /// (example using scalar values): |
758 | /// |
759 | /// offset = extractOp.indices[0] |
760 | /// for (i = 1; i < numIndices; i++) |
761 | /// offset = extractOp.dimSize[i] * offset + extractOp.indices[i]; |
762 | /// |
763 | /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to: |
764 | /// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3 |
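///        = ( 82 ) * 15 + 3
///        = 1233 (worked out here for illustration)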
765 | static Value calculateGatherOffset(RewriterBase &rewriter, |
766 | VectorizationState &state, |
767 |                                    tensor::ExtractOp extractOp, |
768 | const IRMapping &bvm) { |
769 | // The vector of indices for GatherOp should be shaped as the output vector. |
770 | auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType()); |
771 | auto loc = extractOp.getLoc(); |
772 | |
773 | Value offset = broadcastIfNeeded( |
774 | rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType); |
775 | |
776 | const size_t numIndices = extractOp.getIndices().size(); |
777 | for (size_t i = 1; i < numIndices; i++) { |
778 | Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i); |
779 | |
780 | auto dimSize = broadcastIfNeeded( |
781 | rewriter, |
782 | rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx), |
783 | indexVecType); |
784 | |
785 | offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize); |
786 | |
787 |     auto extractOpIndex = broadcastIfNeeded( |
788 | rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType); |
789 | |
790 | offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset); |
791 | } |
792 | |
793 | return offset; |
794 | } |
795 | |
796 | enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather }; |
797 | |
798 | /// Checks whether \p val can be used for calculating a loop invariant index. |
799 | static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) { |
800 | |
801 | auto targetShape = linalgOp.getStaticLoopRanges(); |
802 | assert(((llvm::count_if(targetShape, |
803 | [](int64_t dimSize) { return dimSize > 1; }) == 1)) && |
804 | "n-D vectors are not yet supported" ); |
805 | assert(targetShape.back() != 1 && |
806 |          "1-D vectors with the trailing dim equal 1 are not yet supported"); |
807 | |
808 | // Blocks outside _this_ linalg.generic are effectively loop invariant. |
809 | // However, analysing block arguments for _this_ linalg.generic Op is a bit |
810 | // tricky. Just bail out in the latter case. |
811 | // TODO: We could try analysing the corresponding affine map here. |
812 | auto *block = linalgOp.getBlock(); |
813 |   if (isa<BlockArgument>(val)) |
814 | return llvm::all_of(block->getArguments(), |
815 | [&val](Value v) { return (v != val); }); |
816 | |
817 | Operation *defOp = val.getDefiningOp(); |
818 | assert(defOp && "This is neither a block argument nor an operation result" ); |
819 | |
820 | // IndexOp is loop invariant as long as its result remains constant across |
821 | // iterations. Given the assumptions on the loop ranges above, only the |
822 | // trailing loop dim ever changes. |
823 | auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1; |
824 | if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) |
825 | return (indexOp.getDim() != trailingLoopDim); |
826 | |
827 | auto *ancestor = block->findAncestorOpInBlock(*defOp); |
828 | |
829 |   // Values defined outside `linalgOp` are loop invariant. |
830 | if (!ancestor) |
831 | return true; |
832 | |
833 | // Values defined inside `linalgOp`, which are constant, are loop invariant. |
834 | if (isa<arith::ConstantOp>(ancestor)) |
835 | return true; |
836 | |
837 | bool result = true; |
838 | for (auto op : ancestor->getOperands()) |
839 | result &= isLoopInvariantIdx(linalgOp, op); |
840 | |
841 | return result; |
842 | } |
843 | |
844 | /// Check whether \p val could be used for calculating the trailing index for a |
845 | /// contiguous load operation. |
846 | /// |
847 | /// There are currently 3 types of values that are allowed here: |
848 | /// 1. loop-invariant values, |
849 | /// 2. values that increment by 1 with every loop iteration, |
850 | /// 3. results of basic arithmetic operations (linear and continuous) |
851 | /// involving 1., 2. and 3. |
852 | /// This method returns True if indeed only such values are used in calculating |
853 | /// \p val. |
854 | /// |
855 | /// Additionally, the trailing index for a contiguous load operation should |
856 | /// increment by 1 with every loop iteration, i.e. be based on: |
857 | /// * `linalg.index <dim>` , |
858 | /// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is |
859 | /// updated to `true` when such an op is found. |
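/// Example (hypothetical IR): `%idx = arith.addi %outside_def, %lin`, where
/// `%lin = linalg.index <trailing dim>` and `%outside_def` is defined above
/// the linalg op, sets \p foundIndexOp and is accepted as a contiguous index.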
860 | static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, |
861 | bool &foundIndexOp) { |
862 | |
863 | auto targetShape = linalgOp.getStaticLoopRanges(); |
864 | assert(((llvm::count_if(targetShape, |
865 | [](int64_t dimSize) { return dimSize > 1; }) == 1)) && |
866 | "n-D vectors are not yet supported" ); |
867 | assert(targetShape.back() != 1 && |
868 | "1-D vectors with the trailing dim 1 are not yet supported" ); |
869 | |
870 | // Blocks outside _this_ linalg.generic are effectively loop invariant. |
871 | // However, analysing block arguments for _this_ linalg.generic Op is a bit |
872 | // tricky. Just bail out in the latter case. |
873 | // TODO: We could try analysing the corresponding affine map here. |
874 | auto *block = linalgOp.getBlock(); |
875 |   if (isa<BlockArgument>(val)) |
876 | return llvm::all_of(block->getArguments(), |
877 | [&val](Value v) { return (v != val); }); |
878 | |
879 | Operation *defOp = val.getDefiningOp(); |
880 | assert(defOp && "This is neither a block argument nor an operation result" ); |
881 | |
882 | // Given the assumption on the loop ranges above, only the trailing loop |
883 | // index is not constant. |
884 | auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1; |
885 | if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) { |
886 | foundIndexOp = (indexOp.getDim() == trailingLoopDim); |
887 | return true; |
888 | } |
889 | |
890 | auto *ancestor = block->findAncestorOpInBlock(*defOp); |
891 | |
892 | if (!ancestor) |
893 | return false; |
894 | |
895 | // Conservatively reject Ops that could lead to indices with stride other |
896 | // than 1. |
897 | if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor)) |
898 | return false; |
899 | |
900 | bool result = false; |
901 | for (auto op : ancestor->getOperands()) |
902 | result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp); |
903 | |
904 | return result; |
905 | } |
906 | |
907 | /// Infer the memory access pattern for the input ExtractOp |
908 | /// |
909 | /// Based on the operation shapes and indices (usually based on the iteration |
910 | /// space of the parent `linalgOp` operation), decides whether the input |
911 | /// ExtractOp is a contiguous load (including a broadcast of a scalar) or a |
912 | /// gather load. |
913 | /// |
914 | /// Note that it is always safe to use gather load operations for contiguous |
915 | /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume |
916 | /// that `extractOp` is a gather load. |
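/// Illustrative classification (hypothetical IR, 1-D iteration space assumed):
///   * `%t[%c0, %j]` with `%j = linalg.index` over the trailing loop dim
///     -> Contiguous;
///   * `%t[%c0, %c5]` with all indices loop-invariant -> ScalarBroadcast;
///   * indices computed through e.g. `arith.muli` of a loop index -> Gather.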
917 | static VectorMemoryAccessKind |
918 | getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, |
919 |                                     LinalgOp &linalgOp) { |
920 | |
921 | auto targetShape = linalgOp.getStaticLoopRanges(); |
922 | auto inputShape = cast<ShapedType>(extractOp.getTensor().getType()); |
923 | |
924 | // 0.1 Is this a 0-D vector? If yes then this is a scalar broadcast. |
925 | if (inputShape.getShape().empty()) |
926 | return VectorMemoryAccessKind::ScalarBroadcast; |
927 | |
928 | // 0.2 In the case of dynamic shapes just bail-out and assume that it's a |
929 | // gather load. |
930 | // TODO: Relax this condition. |
931 | if (linalgOp.hasDynamicShape()) |
932 | return VectorMemoryAccessKind::Gather; |
933 | |
934 | // 1. Assume that it's a gather load when reading _into_: |
935 |   //    * an n-D "vector", like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or |
936 |   //    * a 1-D "vector" with the trailing dim equal 1, e.g. `tensor<1x4x1xi32>`. |
937 | // TODO: Relax these conditions. |
938 | // FIXME: This condition assumes non-dynamic sizes. |
939 | if ((llvm::count_if(targetShape, |
940 | [](int64_t dimSize) { return dimSize > 1; }) != 1) || |
941 | targetShape.back() == 1) |
942 | return VectorMemoryAccessKind::Gather; |
943 | |
944 | // 2. Assume that it's a gather load when reading _from_ a tensor for which |
945 | // the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`. |
946 | // TODO: Relax this condition. |
947 | if (inputShape.getShape().back() == 1) |
948 | return VectorMemoryAccessKind::Gather; |
949 | |
950 | bool leadingIdxsLoopInvariant = true; |
951 | |
952 | // 3. Analyze the leading indices of `extractOp`. |
953 | // Look at the way each index is calculated and decide whether it is suitable |
954 | // for a contiguous load, i.e. whether it's loop invariant. |
955 | auto indices = extractOp.getIndices(); |
956 | auto leadIndices = indices.drop_back(1); |
957 | |
958 | for (auto [i, indexVal] : llvm::enumerate(leadIndices)) { |
959 | if (inputShape.getShape()[i] == 1) |
960 | continue; |
961 | |
962 | leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal); |
963 | } |
964 | |
965 | if (!leadingIdxsLoopInvariant) { |
966 | LDBG("Found gather load: " << extractOp); |
967 | return VectorMemoryAccessKind::Gather; |
968 | } |
969 | |
970 | // 4. Analyze the trailing index for `extractOp`. |
971 | // At this point we know that the leading indices are loop invariant. This |
972 |   // means that this is potentially a scalar or a contiguous load. We can decide |
973 | // based on the trailing idx. |
974 |   auto extractOpTrailingIdx = indices.back(); |
975 | |
976 | // 4a. Scalar broadcast load |
977 | // If the trailing index is loop invariant then this is a scalar load. |
978 | if (leadingIdxsLoopInvariant && |
979 | isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) { |
980 | LDBG("Found scalar broadcast load: " << extractOp); |
981 | |
982 | return VectorMemoryAccessKind::ScalarBroadcast; |
983 | } |
984 | |
985 | // 4b. Contiguous loads |
986 | // The trailing `extractOp` index should increment with every loop iteration. |
987 | // This effectively means that it must be based on the trailing loop index. |
988 | // This is what the following bool captures. |
989 | bool foundIndexOp = false; |
990 | bool isContiguousLoad = |
991 | isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp); |
992 | isContiguousLoad &= foundIndexOp; |
993 | |
994 | if (isContiguousLoad) { |
995 |     LDBG("Found contiguous load: " << extractOp); |
996 | return VectorMemoryAccessKind::Contiguous; |
997 | } |
998 | |
999 | // 5. Fallback case - gather load. |
1000 | LDBG("Found gather load: " << extractOp); |
1001 | return VectorMemoryAccessKind::Gather; |
1002 | } |
1003 | |
1004 | /// Helper function to vectorize the tensor.extract operations. Returns |
1005 | /// VectorizationStatus::NewOp to signal the vectorization algorithm that it |
1006 | /// should map the produced operations. This function is meant to be used as a |
1007 | /// CustomVectorizationHook. |
1008 | static VectorizationResult |
1009 | vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, |
1010 |                        Operation *op, LinalgOp linalgOp, const IRMapping &bvm) { |
1011 |   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op); |
1012 | if (!extractOp) |
1013 |     return VectorizationResult{VectorizationStatus::Failure, nullptr}; |
1014 | auto loc = extractOp.getLoc(); |
1015 | |
1016 | // Compute the static loop sizes of the extract op. |
1017 | auto resultType = state.getCanonicalVecType(extractOp.getResult().getType()); |
1018 | auto maskConstantOp = rewriter.create<arith::ConstantOp>( |
1019 | loc, |
1020 | DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()), |
1021 | /*value=*/true)); |
1022 | auto passThruConstantOp = |
1023 | rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType)); |
1024 | |
1025 | // Base indices are currently set to 0. We will need to re-visit if more |
1026 | // generic scenarios are to be supported. |
1027 | SmallVector<Value> baseIndices( |
1028 | extractOp.getIndices().size(), |
1029 | rewriter.create<arith::ConstantIndexOp>(loc, 0)); |
1030 | |
1031 | VectorMemoryAccessKind memAccessKind = |
1032 | getTensorExtractMemoryAccessPattern(extractOp, linalgOp); |
1033 | |
1034 | // 1. Handle gather access |
1035 | if (memAccessKind == VectorMemoryAccessKind::Gather) { |
1036 | Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm); |
1037 | |
1038 | // Generate the gather load |
1039 | Operation *gatherOp = rewriter.create<vector::GatherOp>( |
1040 | loc, resultType, extractOp.getTensor(), baseIndices, offset, |
1041 | maskConstantOp, passThruConstantOp); |
1042 |     gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp); |
1043 | |
1044 | LDBG("Vectorised as gather load: " << extractOp << "\n" ); |
1045 |     return VectorizationResult{VectorizationStatus::NewOp, gatherOp}; |
1046 | } |
1047 | |
1048 | // 2. Handle: |
1049 | // a. scalar loads + broadcast, |
1050 | // b. contiguous loads. |
1051 | // Both cases use vector.transfer_read. |
1052 | |
1053 | // Collect indices for `vector.transfer_read`. At this point, the indices will |
1054 | // either be scalars or would have been broadcast to vectors matching the |
1055 | // result type. For indices that are vectors, there are two options: |
1056 | // * for non-trailing indices, all elements are identical (contiguous |
1057 | // loads are identified by looking for non-trailing indices that are |
1058 | // invariant with respect to the corresponding linalg.generic), or |
1059 | // * for trailing indices, the index vector will contain values with stride |
1060 | // one, but for `vector.transfer_read` only the first (i.e. 0th) index is |
1061 | // needed. |
1062 | // This means that |
1063 | // * for scalar indices - just re-use it, |
1064 | // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom |
1065 | // (0th) element and use that. |
1066 | SmallVector<Value> transferReadIdxs; |
1067 | auto resTrailingDim = resultType.getShape().back(); |
1068 | auto zero = rewriter.create<arith::ConstantOp>( |
1069 | loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type())); |
1070 | for (size_t i = 0; i < extractOp.getIndices().size(); i++) { |
1071 | auto idx = bvm.lookup(extractOp.getIndices()[i]); |
1072 | if (idx.getType().isIndex()) { |
1073 |       transferReadIdxs.push_back(idx); |
1074 | continue; |
1075 | } |
1076 | |
1077 | auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>( |
1078 | loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()), |
1079 | bvm.lookup(extractOp.getIndices()[i])); |
1080 | transferReadIdxs.push_back( |
1081 | rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero)); |
1082 | } |
1083 | |
1084 | // `tensor.extract_element` is always in-bounds, hence the following holds. |
1085 | auto dstRank = resultType.getRank(); |
1086 | auto srcRank = extractOp.getTensor().getType().getRank(); |
1087 | SmallVector<bool> inBounds(dstRank, true); |
1088 | |
1089 | // 2a. Handle scalar broadcast access. |
1090 | if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) { |
1091 | MLIRContext *ctx = rewriter.getContext(); |
1092 |     SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx)); |
1093 | auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx); |
1094 | |
1095 | auto transferReadOp = rewriter.create<vector::TransferReadOp>( |
1096 | loc, resultType, extractOp.getTensor(), transferReadIdxs, |
1097 | permutationMap, inBounds); |
1098 | |
1099 | LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n" ); |
1100 | return VectorizationResult{VectorizationStatus::NewOp, transferReadOp}; |
1101 | } |
1102 | |
1103 | // 2b. Handle contiguous access. |
1104 | auto permutationMap = AffineMap::getMinorIdentityMap( |
1105 |       srcRank, std::min(dstRank, srcRank), rewriter.getContext()); |
1106 | |
1107 | int32_t rankDiff = dstRank - srcRank; |
1108 | // When dstRank > srcRank, broadcast the source tensor to the unitary leading |
1109 | // dims so that the ranks match. This is done by extending the map with 0s. |
1110 | // For example, for dstRank = 3, srcRank = 2, the following map created |
1111 | // above: |
1112 | // (d0, d1) --> (d0, d1) |
1113 | // is extended as: |
1114 | // (d0, d1) --> (0, d0, d1) |
1115 | while (rankDiff > 0) { |
1116 | permutationMap = permutationMap.insertResult( |
1117 |         mlir::getAffineConstantExpr(0, rewriter.getContext()), 0); |
1118 | rankDiff--; |
1119 | } |
1120 | |
1121 | auto transferReadOp = rewriter.create<vector::TransferReadOp>( |
1122 | loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap, |
1123 | inBounds); |
1124 | |
1125 | LDBG("Vectorised as contiguous load: " << extractOp); |
1126 | return VectorizationResult{VectorizationStatus::NewOp, transferReadOp}; |
1127 | } |
1128 | |
1129 | /// Emit reduction operations if the shape of the value to reduce is different |
1130 | /// from the result shape. |
1131 | // Note: this is a true builder that notifies the OpBuilder listener. |
1132 | // TODO: Consider moving as a static helper on the ReduceOp. |
1133 | static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, |
1134 | Value reduceValue, Value initialValue, |
1135 | const IRMapping &bvm) { |
1136 |   Value reduceVec = bvm.lookup(reduceValue); |
1137 |   Value outputVec = bvm.lookup(initialValue); |
1138 | auto reduceType = dyn_cast<VectorType>(reduceVec.getType()); |
1139 | auto outputType = dyn_cast<VectorType>(outputVec.getType()); |
1140 |   // Reduce only if needed as the value may already have been reduced for |
1141 | // contraction vectorization. |
1142 | if (!reduceType || |
1143 | (outputType && reduceType.getShape() == outputType.getShape())) |
1144 | return nullptr; |
1145 | SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp); |
1146 |   return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask); |
1147 | } |
1148 | |
1149 | /// Generic vectorization for a single operation `op`, given already vectorized |
1150 | /// operands carried by `bvm`. Vectorization occurs as follows: |
1151 | /// 1. Try to apply any of the `customVectorizationHooks` and return its |
1152 | /// result on success. |
1153 | /// 2. Clone any constant in the current scope without vectorization: each |
1154 | /// consumer of the constant will later determine the shape to which the |
1155 | ///    constant needs to be broadcast. |
1156 | /// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose |
1157 | /// of the `customVectorizationHooks` to cover such cases. |
1158 | /// 4. Clone `op` in vector form to a vector of shape prescribed by the first |
1159 | /// operand of maximal rank. Other operands have smaller rank and are |
1160 | /// broadcast accordingly. It is assumed this broadcast is always legal, |
1161 | /// otherwise, it means one of the `customVectorizationHooks` is incorrect. |
1162 | /// |
1163 | /// This function assumes all operands of `op` have been vectorized and are in |
1164 | /// the `bvm` mapping. As a consequence, this function is meant to be called on |
1165 | /// a topologically-sorted list of ops. |
1166 | /// This function does not update `bvm` but returns a VectorizationStatus that |
1167 | /// instructs the caller what `bvm` update needs to occur. |
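/// Illustrative step 4 (types assumed): an `arith.addf %s, %v` whose mapped
/// operands are an f32 and a vector<4x8xf32> is rebuilt as an arith.addf on
/// two vector<4x8xf32> values, the scalar side being vector.broadcast first.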
1168 | static VectorizationResult |
1169 | vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, |
1170 | LinalgOp linalgOp, Operation *op, const IRMapping &bvm, |
1171 | ArrayRef<CustomVectorizationHook> customVectorizationHooks) { |
1172 | LDBG("vectorize op " << *op << "\n" ); |
1173 | |
1174 | // 1. Try to apply any CustomVectorizationHook. |
1175 | if (!customVectorizationHooks.empty()) { |
1176 | for (auto &customFunc : customVectorizationHooks) { |
1177 | VectorizationResult result = customFunc(op, bvm); |
1178 | if (result.status == VectorizationStatus::Failure) |
1179 | continue; |
1180 | return result; |
1181 | } |
1182 | } |
1183 | |
1184 |   // 2. Constant ops don't get vectorized but rather broadcasted at their users. |
1185 |   // Clone so that the constant is not confined to the linalgOp block. |
1186 |   if (isa<arith::ConstantOp, func::ConstantOp>(op)) |
1187 |     return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)}; |
1188 | |
1189 | // 3. Only ElementwiseMappable are allowed in the generic vectorization. |
1190 | if (!OpTrait::hasElementwiseMappableTraits(op)) |
1191 |     return VectorizationResult{VectorizationStatus::Failure, nullptr}; |
1192 | |
1193 |   // 4. Check if the operation is a reduction. |
1194 | SmallVector<std::pair<Value, Value>> reductionOperands; |
1195 | for (Value operand : op->getOperands()) { |
    auto blockArg = dyn_cast<BlockArgument>(operand);
1197 | if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() || |
1198 | blockArg.getArgNumber() < linalgOp.getNumDpsInputs()) |
1199 | continue; |
1200 | SmallVector<Operation *> reductionOps; |
1201 | Value reduceValue = matchReduction( |
1202 | linalgOp.getRegionOutputArgs(), |
1203 | blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps); |
1204 | if (!reduceValue) |
1205 | continue; |
    reductionOperands.push_back(std::make_pair(reduceValue, operand));
1207 | } |
1208 | if (!reductionOperands.empty()) { |
1209 | assert(reductionOperands.size() == 1); |
1210 | Operation *reduceOp = |
1211 | reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first, |
1212 | reductionOperands[0].second, bvm); |
1213 | if (reduceOp) |
      return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
1215 | } |
1216 | |
1217 | // 5. Generic vectorization path for ElementwiseMappable ops. |
1218 | // a. Get the first max ranked shape. |
1219 | VectorType firstMaxRankedType; |
1220 | for (Value operand : op->getOperands()) { |
    auto vecOperand = bvm.lookup(operand);
    assert(vecOperand && "Vector operand couldn't be found");
1223 | |
1224 | auto vecType = dyn_cast<VectorType>(vecOperand.getType()); |
1225 | if (vecType && (!firstMaxRankedType || |
1226 | firstMaxRankedType.getRank() < vecType.getRank())) |
1227 | firstMaxRankedType = vecType; |
1228 | } |
1229 | // b. Broadcast each op if needed. |
1230 | SmallVector<Value> vecOperands; |
1231 | for (Value scalarOperand : op->getOperands()) { |
    Value vecOperand = bvm.lookup(scalarOperand);
    assert(vecOperand && "Vector operand couldn't be found");
1234 | |
1235 | if (firstMaxRankedType) { |
1236 | auto vecType = VectorType::get(firstMaxRankedType.getShape(), |
1237 | getElementTypeOrSelf(vecOperand.getType()), |
1238 | firstMaxRankedType.getScalableDims()); |
      vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1240 | } else { |
      vecOperands.push_back(vecOperand);
1242 | } |
1243 | } |
  // c. For elementwise ops, the result is a vector with the shape of the
  // first max-ranked operand.
1245 | SmallVector<Type> resultTypes; |
1246 | for (Type resultType : op->getResultTypes()) { |
1247 | resultTypes.push_back( |
1248 | firstMaxRankedType |
1249 | ? VectorType::get(firstMaxRankedType.getShape(), resultType, |
1250 | firstMaxRankedType.getScalableDims()) |
1251 | : resultType); |
1252 | } |
1253 | // d. Build and return the new op. |
  return VectorizationResult{
      VectorizationStatus::NewOp,
      rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
                      resultTypes, op->getAttrs())};
1258 | } |
1259 | |
1260 | /// Generic vectorization function that rewrites the body of a `linalgOp` into |
1261 | /// vector form. Generic vectorization proceeds as follows: |
1262 | /// 1. Verify the `linalgOp` has one non-empty region. |
1263 | /// 2. Values defined above the region are mapped to themselves and will be |
1264 | /// broadcasted on a per-need basis by their consumers. |
1265 | /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d |
1266 | /// load). |
1267 | /// TODO: Reuse opportunities for RAR dependencies. |
1268 | /// 4a. Register CustomVectorizationHook for YieldOp to capture the results. |
/// 4b. Register CustomVectorizationHook for IndexOp to access the
///     iteration indices.
1271 | /// 5. Iteratively call vectorizeOneOp on the region operations. |
1272 | /// |
1273 | /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is |
1274 | /// performed to the maximal common vector size implied by the `linalgOp` |
1275 | /// iteration space. This eager broadcasting is introduced in the |
1276 | /// permutation_map of the vector.transfer_read operations. The eager |
/// broadcasting makes it trivial to determine where broadcasts, transposes and
1278 | /// reductions should occur, without any bookkeeping. The tradeoff is that, in |
1279 | /// the absence of good canonicalizations, the amount of work increases. |
1280 | /// This is not deemed a problem as we expect canonicalizations and foldings to |
1281 | /// aggressively clean up the useless work. |
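///
/// As a rough sketch (exact IR depends on the indexing maps and on masking),
/// an elementwise linalg.generic adding two tensor<8x16xf32> operands becomes:
///   %lhs = vector.transfer_read %arg0[%c0, %c0], %cst
///       : tensor<8x16xf32>, vector<8x16xf32>
///   %rhs = vector.transfer_read %arg1[%c0, %c0], %cst
///       : tensor<8x16xf32>, vector<8x16xf32>
///   %add = arith.addf %lhs, %rhs : vector<8x16xf32>
///   %res = vector.transfer_write %add, %init[%c0, %c0]
///       : vector<8x16xf32>, tensor<8x16xf32>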
1282 | static LogicalResult |
1283 | vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, |
1284 | LinalgOp linalgOp, |
1285 | SmallVectorImpl<Value> &newResults) { |
1286 | LDBG("Vectorizing operation as linalg generic\n" ); |
1287 | Block *block = linalgOp.getBlock(); |
1288 | |
1289 | // 2. Values defined above the region can only be broadcast for now. Make them |
1290 | // map to themselves. |
1291 | IRMapping bvm; |
1292 | SetVector<Value> valuesSet; |
1293 | mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet); |
  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1295 | |
1296 | if (linalgOp.getNumDpsInits() == 0) |
1297 | return failure(); |
1298 | |
1299 | // 3. Turn all BBArgs into vector.transfer_read / load. |
1300 | Location loc = linalgOp.getLoc(); |
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1302 | for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) { |
1303 | BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand); |
1304 | if (linalgOp.isScalar(opOperand)) { |
1305 | bvm.map(bbarg, opOperand->get()); |
1306 | continue; |
1307 | } |
1308 | |
1309 | // 3.a. Convert the indexing map for this input/output to a transfer read |
1310 | // permutation map and masking map. |
1311 | AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand); |
1312 | |
1313 | // Remove zeros from indexing map to use it as masking map. |
1314 | SmallVector<int64_t> zeroPos; |
1315 | auto results = indexingMap.getResults(); |
1316 | for (const auto &result : llvm::enumerate(results)) { |
1317 | if (isa<AffineConstantExpr>(result.value())) { |
1318 | zeroPos.push_back(result.index()); |
1319 | } |
1320 | } |
1321 | AffineMap maskingMap = indexingMap.dropResults(zeroPos); |
1322 | |
1323 | AffineMap readMap; |
1324 | VectorType readType; |
1325 | Type elemType = getElementTypeOrSelf(opOperand->get()); |
1326 | if (linalgOp.isDpsInput(opOperand)) { |
1327 | // 3.a.i. For input reads we use the canonical vector shape. |
1328 | readMap = inverseAndBroadcastProjectedPermutation(indexingMap); |
1329 | readType = state.getCanonicalVecType(elemType); |
1330 | } else { |
1331 | // 3.a.ii. For output reads (iteration-carried dependence, e.g., |
1332 | // reductions), the vector shape is computed by mapping the canonical |
1333 | // vector shape to the output domain and back to the canonical domain. |
1334 | readMap = inversePermutation(reindexIndexingMap(indexingMap)); |
1335 | readType = |
1336 | state.getCanonicalVecType(elemType, readMap.compose(indexingMap)); |
1337 | } |
1338 | |
1339 | SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero); |
1340 | |
1341 | Operation *read = rewriter.create<vector::TransferReadOp>( |
1342 | loc, readType, opOperand->get(), indices, readMap); |
1343 | read = state.maskOperation(rewriter, read, linalgOp, maskingMap); |
1344 | Value readValue = read->getResult(0); |
1345 | |
1346 | // 3.b. If masked, set in-bounds to true. Masking guarantees that the access |
1347 | // will be in-bounds. |
1348 | if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) { |
1349 | SmallVector<bool> inBounds(readType.getRank(), true); |
1350 | cast<vector::TransferReadOp>(maskOp.getMaskableOp()) |
1351 | .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds)); |
1352 | } |
1353 | |
1354 | // 3.c. Not all ops support 0-d vectors, extract the scalar for now. |
1355 | // TODO: remove this. |
1356 | if (readType.getRank() == 0) |
1357 | readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue); |
1358 | |
1359 | LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue |
1360 | << "\n" ); |
1361 | bvm.map(bbarg, readValue); |
1362 | bvm.map(opOperand->get(), readValue); |
1363 | } |
1364 | |
1365 | SmallVector<CustomVectorizationHook> hooks; |
1366 | // 4a. Register CustomVectorizationHook for yieldOp. |
1367 | CustomVectorizationHook vectorizeYield = |
1368 | [&](Operation *op, const IRMapping &bvm) -> VectorizationResult { |
1369 | return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults); |
1370 | }; |
  hooks.push_back(vectorizeYield);
1372 | |
1373 | // 4b. Register CustomVectorizationHook for indexOp. |
1374 | CustomVectorizationHook vectorizeIndex = |
1375 | [&](Operation *op, const IRMapping &bvm) -> VectorizationResult { |
1376 | return vectorizeLinalgIndex(rewriter, state, op, linalgOp); |
1377 | }; |
  hooks.push_back(vectorizeIndex);
1379 | |
1380 | // 4c. Register CustomVectorizationHook for extractOp. |
  CustomVectorizationHook vectorizeExtract =
1382 | [&](Operation *op, const IRMapping &bvm) -> VectorizationResult { |
1383 | return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm); |
1384 | }; |
  hooks.push_back(vectorizeExtract);
1386 | |
  // 5. Iteratively call `vectorizeOneOp` on each op in the slice.
1388 | for (Operation &op : block->getOperations()) { |
1389 | VectorizationResult result = |
1390 | vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks); |
1391 | if (result.status == VectorizationStatus::Failure) { |
1392 | LDBG("failed to vectorize: " << op << "\n" ); |
1393 | return failure(); |
1394 | } |
1395 | if (result.status == VectorizationStatus::NewOp) { |
1396 | Operation *maybeMaskedOp = |
1397 | state.maskOperation(rewriter, result.newOp, linalgOp); |
1398 | LDBG("New vector op: " << *maybeMaskedOp << "\n" ); |
1399 | bvm.map(op.getResults(), maybeMaskedOp->getResults()); |
1400 | } |
1401 | } |
1402 | |
1403 | return success(); |
1404 | } |
1405 | |
1406 | /// Given a tensor::PackOp, return the `dest` shape before any packing |
1407 | /// permutations. |
1408 | static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp, |
1409 | ArrayRef<int64_t> destShape) { |
1410 | return applyPermutation(destShape, tensor::getPackInverseDestPerm(packOp)); |
1411 | } |
1412 | |
1413 | /// Given an input, the mixed destSizes, and the vector sizes for vectorization, |
1414 | /// create an empty destination tensor and create a TransferWriteOp from the |
1415 | /// input to the empty tensor. If the destination shape is not the same as the |
1416 | /// inputVectorSizes for the first rank(inputVectorSizes) dims, then create a |
1417 | /// mask for the write. |
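///
/// A sketch of the masked form, assuming `inputVectorSizes` = [8, 16] and a
/// destination whose leading size %d0 is dynamic (illustrative only):
///   %empty = tensor.empty(%d0) : tensor<?x16xf32>
///   %mask = vector.create_mask %d0, %c16 : vector<8x16xi1>
///   %write = vector.mask %mask {
///     vector.transfer_write %input, %empty[%c0, %c0]
///       {in_bounds = [true, true]} : vector<8x16xf32>, tensor<?x16xf32>
///   } : vector<8x16xi1> -> tensor<?x16xf32>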
1418 | static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc, |
1419 | Value input, |
1420 | SmallVector<OpFoldResult> destSizes, |
1421 | ArrayRef<int64_t> inputVectorSizes) { |
1422 | auto inputType = cast<VectorType>(input.getType()); |
1423 | Value dest = builder.create<tensor::EmptyOp>(loc, destSizes, |
1424 | inputType.getElementType()); |
1425 | int64_t rank = cast<ShapedType>(dest.getType()).getRank(); |
  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1427 | Operation *write = builder.create<vector::TransferWriteOp>( |
1428 | loc, |
1429 | /*vector=*/input, |
1430 | /*source=*/dest, |
1431 | /*indices=*/SmallVector<Value>(rank, zero), |
1432 | /*inBounds=*/SmallVector<bool>(rank, true)); |
1433 | auto destShape = cast<ShapedType>(dest.getType()).getShape(); |
1434 | assert(llvm::none_of( |
1435 | destShape.drop_front(inputVectorSizes.size()), |
1436 | [](int64_t size) { return size == ShapedType::kDynamic; }) && |
1437 | "Only dims aligned with inputVectorSizes may be dynamic" ); |
1438 | bool needMaskForWrite = !llvm::equal( |
1439 | inputVectorSizes, destShape.take_front(inputVectorSizes.size())); |
1440 | if (needMaskForWrite) { |
1441 | SmallVector<int64_t> writeMaskShape; |
    writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
1443 | writeMaskShape.append(destShape.begin() + inputVectorSizes.size(), |
1444 | destShape.end()); |
1445 | auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type()); |
1446 | Value maskForWrite = |
1447 | builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes); |
    write = mlir::vector::maskOperation(builder, write, maskForWrite);
1449 | } |
1450 | return write; |
1451 | } |
1452 | |
/// Vectorize tensor::PackOp with (1) static innerTiles, (2) constant
1454 | /// padding value and (3) input vector sizes into: |
1455 | /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds |
1456 | /// As in the following example: |
1457 | /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2] |
1458 | /// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32> |
1459 | /// |
1460 | /// This pack would be vectorized to: |
1461 | /// |
1462 | /// %load = vector.mask %mask { |
1463 | /// vector.transfer_read %arg0[%c0, %c0, %c0], %cst |
1464 | /// {in_bounds = [true, true, true]} : |
1465 | /// tensor<32x7x16xf32>, vector<32x8x16xf32> |
1466 | /// } : vector<32x8x16xi1> -> vector<32x8x16xf32> |
1467 | /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32> |
1468 | /// to vector<32x4x2x1x16xf32> |
1469 | /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2] |
1470 | /// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32> |
1471 | /// %write = vector.transfer_write %transpose, |
1472 | /// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0] |
1473 | /// {in_bounds = [true, true, true, true, true]} |
1474 | /// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32> |
1475 | /// |
1476 | /// If the (3) input vector sizes are not provided, the vector sizes are |
1477 | /// determined by the result tensor shape. Also, we update the inBounds |
1478 | /// attribute instead of masking. |
1479 | static LogicalResult |
1480 | vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp, |
1481 | ArrayRef<int64_t> inputVectorSizes, |
1482 | SmallVectorImpl<Value> &newResults) { |
1483 | OpBuilder::InsertionGuard g(rewriter); |
1484 | rewriter.setInsertionPoint(packOp); |
1485 | |
1486 | Location loc = packOp.getLoc(); |
1487 | auto padValue = packOp.getPaddingValue(); |
1488 | if (!padValue) { |
1489 | padValue = rewriter.create<arith::ConstantOp>( |
1490 | loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType())); |
1491 | } |
1492 | ReifiedRankedShapedTypeDims reifiedReturnShapes; |
1493 | LogicalResult status = |
1494 | cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation()) |
1495 | .reifyResultShapes(rewriter, reifiedReturnShapes); |
1496 | (void)status; // prevent unused variable warning on non-assert builds. |
1497 | assert(succeeded(status) && "failed to reify result shapes" ); |
1498 | |
1499 | // If the input vector sizes are not provided, then the vector sizes are |
1500 | // determined by the result tensor shape. In case the vector sizes aren't |
1501 | // provided, we update the inBounds attribute instead of masking. |
1502 | bool useInBoundsInsteadOfMasking = true; |
1503 | if (inputVectorSizes.empty()) { |
1504 | ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape(); |
    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1506 | useInBoundsInsteadOfMasking = false; |
1507 | } |
1508 | |
1509 | // Create masked TransferReadOp. |
1510 | SmallVector<int64_t> inputShape(inputVectorSizes); |
1511 | auto innerTiles = packOp.getStaticInnerTiles(); |
1512 | auto innerDimsPos = packOp.getInnerDimsPos(); |
1513 | auto outerDimsPerm = packOp.getOuterDimsPerm(); |
1514 | if (!outerDimsPerm.empty()) |
1515 | applyPermutationToVector(inputShape, |
1516 | invertPermutationVector(outerDimsPerm)); |
1517 | for (auto [idx, size] : enumerate(innerTiles)) |
1518 | inputShape[innerDimsPos[idx]] *= size; |
1519 | auto maskedRead = vector::createReadOrMaskedRead( |
1520 | builder&: rewriter, loc, source: packOp.getSource(), readShape: inputShape, padValue: padValue, |
1521 | useInBoundsInsteadOfMasking); |
1522 | |
1523 | // Create ShapeCastOp. |
1524 | SmallVector<int64_t> destShape(inputVectorSizes); |
1525 | destShape.append(innerTiles.begin(), innerTiles.end()); |
1526 | auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape), |
1527 | packOp.getDestType().getElementType()); |
1528 | auto shapeCastOp = |
1529 | rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead); |
1530 | |
1531 | // Create TransposeOp. |
1532 | auto destPermutation = |
1533 | invertPermutationVector(tensor::getPackInverseDestPerm(packOp)); |
1534 | auto transposeOp = rewriter.create<vector::TransposeOp>( |
1535 | loc, shapeCastOp.getResult(), destPermutation); |
1536 | |
1537 | // Create TransferWriteOp. |
1538 | Operation *write = |
1539 | createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), |
1540 | reifiedReturnShapes[0], inputVectorSizes); |
  newResults.push_back(write->getResult(0));
1542 | return success(); |
1543 | } |
1544 | |
/// Vectorize a `tensor::UnPackOp` into these four ops:
///   vector::TransferReadOp  - Reads a vector from the source tensor.
///   vector::TransposeOp     - Transposes the read vector.
///   vector::ShapeCastOp     - Reshapes the data based on the target shape.
///   vector::TransferWriteOp - Writes the result vector back to the
///                             destination tensor.
1551 | static LogicalResult |
1552 | vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp, |
1553 | ArrayRef<int64_t> inputVectorSizes, |
1554 | SmallVectorImpl<Value> &newResults) { |
1555 | |
1556 | OpBuilder::InsertionGuard g(rewriter); |
1557 | rewriter.setInsertionPoint(unpackOp); |
1558 | |
1559 | RankedTensorType unpackTensorType = unpackOp.getSourceType(); |
1560 | |
1561 | ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos(); |
1562 | ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles(); |
1563 | |
1564 | SmallVector<int64_t> readMaskShape(inputVectorSizes.begin(), |
1565 | inputVectorSizes.end()); |
1566 | ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm(); |
1567 | ArrayRef<int64_t> sourceShape = unpackTensorType.getShape(); |
1568 | |
  // readMaskShape is the shape of the vector used to read the source tensor
  // and to apply the mask. It is computed as follows. Let the vector sizes
  // (VS) be of size 'N', the source shape (SS) of size 'M' with M >= N, and
  // the inner tile sizes (IT) of size M-N. Then:
  //   - initially, readMaskShape = inputVectorSizes;
  //   - divide each readMaskShape entry pointed to by innerDimPos by the
  //     corresponding inner tile size;
  //   - if outer_dims_perm is present, apply that permutation to
  //     readMaskShape;
  //   - append the remaining dims of SS.
  // E.g. let unpackTensorType.getShape() = <8x8x32x16>,
  // inner_dims_pos = [0, 1], inner_tiles = [32, 16],
  // vector_sizes = [512, 128] and outer_dims_perm = [1, 0]. Then:
  //   readMaskShape (initial):          [512, 128]
  //   after the inner-dim adjustment:   [512/32, 128/16] = [16, 8]
  //   after applying outer_dims_perm:   [8, 16]
  //   after appending the rest of SS:   [8, 16, 32, 16]
1587 | |
1588 | for (auto [index, size] : enumerate(innerTiles)) { |
1589 | readMaskShape[innerDimPos[index]] = |
1590 | llvm::divideCeil(readMaskShape[innerDimPos[index]], size); |
1591 | } |
1592 | if (!outerDimsPerm.empty()) { |
1593 | applyPermutationToVector(inVec&: readMaskShape, permutation: outerDimsPerm); |
1594 | } |
  readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(),
                       sourceShape.end());
1597 | |
1598 | ReifiedRankedShapedTypeDims reifiedRetShapes; |
1599 | LogicalResult status = |
1600 | cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation()) |
1601 | .reifyResultShapes(rewriter, reifiedRetShapes); |
1602 | if (status.failed()) { |
1603 | LDBG("Unable to reify result shapes of " << unpackOp); |
1604 | return failure(); |
1605 | } |
1606 | Location loc = unpackOp->getLoc(); |
1607 | |
1608 | auto padValue = rewriter.create<arith::ConstantOp>( |
1609 | loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType())); |
1610 | |
1611 | // Read result, mask if necessary. If transferReadOp shape is not equal |
1612 | // to shape of source, then a mask is necessary. |
1613 | Value readResult = vector::createReadOrMaskedRead( |
1614 | builder&: rewriter, loc, source: unpackOp.getSource(), |
1615 | readShape: ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue: padValue); |
1616 | |
1617 | PackingMetadata packMetadata; |
1618 | SmallVector<int64_t> lastDimToInsertPosPerm = |
1619 | tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata); |
1620 | ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType()); |
1621 | SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape()); |
1622 | mlir::Type stripMineElemType = maskedOpShapedType.getElementType(); |
1623 | applyPermutationToVector(inVec&: stripMineShape, permutation: lastDimToInsertPosPerm); |
1624 | RankedTensorType stripMineTensorType = |
1625 | RankedTensorType::get(stripMineShape, stripMineElemType); |
1626 | // Transpose the appropriate rows to match output. |
1627 | vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>( |
1628 | loc, readResult, lastDimToInsertPosPerm); |
1629 | |
1630 | // Collapse the vector to the size required by result. |
1631 | RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType( |
1632 | stripMineTensorType, packMetadata.reassociations); |
1633 | mlir::VectorType vecCollapsedType = |
1634 | VectorType::get(collapsedType.getShape(), collapsedType.getElementType()); |
1635 | vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>( |
1636 | loc, vecCollapsedType, transposeOp->getResult(0)); |
1637 | |
  // writeMaskShape must match the shape_cast result shape for dynamic sizes;
  // otherwise the verifier rejects the mask size as invalid.
1640 | SmallVector<int64_t> writeMaskShape( |
1641 | unpackOp.getDestType().hasStaticShape() |
1642 | ? inputVectorSizes |
1643 | : shapeCastOp.getResultVectorType().getShape()); |
1644 | Operation *write = |
1645 | createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(), |
1646 | reifiedRetShapes[0], writeMaskShape); |
  newResults.push_back(write->getResult(0));
1648 | return success(); |
1649 | } |
1650 | |
1651 | /// Vectorize a `padOp` with (1) static result type, (2) constant padding value |
1652 | /// and (3) all-zero lowPad to |
1653 | /// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`. |
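///
/// For example (a sketch; the exact masking depends on `inputVectorSizes`):
///   %pad = tensor.pad %src low[0, 0] high[%h0, %h1] { ... }
///       : tensor<?x?xf32> to tensor<8x16xf32>
/// with vector sizes [8, 16] is turned into a masked vector.transfer_read of
/// %src (padded with the constant padding value), followed by an in-bounds
/// vector.transfer_write of that vector into a tensor.empty of the reified
/// result shape.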
1654 | static LogicalResult |
1655 | vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, |
1656 | ArrayRef<int64_t> inputVectorSizes, |
1657 | SmallVectorImpl<Value> &newResults) { |
1658 | auto padValue = padOp.getConstantPaddingValue(); |
1659 | Location loc = padOp.getLoc(); |
1660 | |
1661 | // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value)) |
1662 | OpBuilder::InsertionGuard g(rewriter); |
1663 | rewriter.setInsertionPoint(padOp); |
1664 | |
1665 | ReifiedRankedShapedTypeDims reifiedReturnShapes; |
1666 | LogicalResult status = |
1667 | cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation()) |
1668 | .reifyResultShapes(rewriter, reifiedReturnShapes); |
1669 | (void)status; // prevent unused variable warning on non-assert builds |
1670 | assert(succeeded(status) && "failed to reify result shapes" ); |
1671 | auto maskedRead = vector::createReadOrMaskedRead( |
1672 | builder&: rewriter, loc, source: padOp.getSource(), readShape: inputVectorSizes, padValue: padValue); |
1673 | Operation *write = createWriteOrMaskedWrite( |
1674 | rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes); |
  newResults.push_back(write->getResult(0));
1676 | return success(); |
1677 | } |
1678 | |
1679 | // TODO: probably need some extra checks for reduction followed by consumer |
1680 | // ops that may not commute (e.g. linear reduction + non-linear instructions). |
1681 | static LogicalResult reductionPreconditions(LinalgOp op) { |
1682 | if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) { |
1683 | LDBG("reduction precondition failed: no reduction iterator\n" ); |
1684 | return failure(); |
1685 | } |
1686 | for (OpOperand &opOperand : op.getDpsInitsMutable()) { |
1687 | AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand); |
1688 | if (indexingMap.isPermutation()) |
1689 | continue; |
1690 | |
1691 | Operation *reduceOp = matchLinalgReduction(&opOperand); |
1692 | if (!reduceOp || !getCombinerOpKind(reduceOp)) { |
1693 | LDBG("reduction precondition failed: reduction detection failed\n" ); |
1694 | return failure(); |
1695 | } |
1696 | } |
1697 | return success(); |
1698 | } |
1699 | |
1700 | static LogicalResult |
1701 | vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, |
1702 | bool flatten1DDepthwiseConv) { |
1703 | if (flatten1DDepthwiseConv) { |
1704 | LDBG("Vectorization of flattened convs with dynamic shapes is not " |
1705 | "supported\n" ); |
1706 | return failure(); |
1707 | } |
1708 | |
1709 | if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) { |
1710 | LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n" ); |
1711 | return failure(); |
1712 | } |
1713 | |
1714 | // Support dynamic shapes in 1D depthwise convolution, but only in the |
1715 | // _channel_ dimension. |
1716 | Value lhs = conv.getDpsInputOperand(0)->get(); |
1717 | ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape(); |
  auto shapeWithoutCh = lhsShape.drop_back(1);
1719 | if (ShapedType::isDynamicShape(shapeWithoutCh)) { |
1720 | LDBG("Dynamically-shaped op vectorization precondition failed: only " |
1721 | "channel dim can be dynamic\n" ); |
1722 | return failure(); |
1723 | } |
1724 | |
1725 | return success(); |
1726 | } |
1727 | |
1728 | static LogicalResult |
1729 | vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, |
1730 | bool flatten1DDepthwiseConv) { |
1731 | if (isa<ConvolutionOpInterface>(op.getOperation())) |
1732 | return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv); |
1733 | |
1734 | // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops, |
1735 | // linalg.copy ops and ops that implement ContractionOpInterface for now. |
1736 | if (!isElementwise(op) && |
1737 | !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>( |
1738 | op.getOperation())) |
1739 | return failure(); |
1740 | |
1741 | LDBG("Dynamically-shaped op meets vectorization pre-conditions\n" ); |
1742 | return success(); |
1743 | } |
1744 | |
/// Preconditions for vectorizing a tensor::UnPackOp: the inner tiles must be
/// static/constant and, if provided, the input vector sizes must be valid for
/// the result shape.
1746 | static LogicalResult |
1747 | vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp, |
1748 | ArrayRef<int64_t> inputVectorSizes) { |
1749 | |
1750 | if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) { |
        return !getConstantIntValue(res).has_value();
1752 | })) { |
1753 | LDBG("Inner-tiles must be constant: " << unpackOp << "\n" ); |
1754 | return failure(); |
1755 | } |
1756 | llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape(); |
1757 | if (!inputVectorSizes.empty() && |
      failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
1759 | return failure(); |
1760 | |
1761 | return success(); |
1762 | } |
1763 | |
1764 | static LogicalResult vectorizeLinalgOpPrecondition( |
1765 | LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes, |
    bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
  // A tensor with a dimension of size 0 cannot be vectorized.
1768 | if (llvm::is_contained(linalgOp.getStaticShape(), 0)) |
1769 | return failure(); |
1770 | // Check API contract for input vector sizes. |
1771 | if (!inputVectorSizes.empty() && |
      failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
1773 | inputVectorSizes))) |
1774 | return failure(); |
1775 | |
1776 | if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition( |
1777 | linalgOp, flatten1DDepthwiseConv))) { |
1778 | LDBG("Dynamically-shaped op failed vectorization pre-conditions\n" ); |
1779 | return failure(); |
1780 | } |
1781 | |
1782 | SmallVector<CustomVectorizationPrecondition> customPreconditions; |
1783 | |
1784 | // Register CustomVectorizationPrecondition for extractOp. |
  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
1786 | |
1787 | // All types in the body should be a supported element type for VectorType. |
1788 | for (Operation &innerOp : linalgOp->getRegion(0).front()) { |
1789 | // Check if any custom hook can vectorize the inner op. |
1790 | if (llvm::any_of( |
1791 | customPreconditions, |
1792 | [&](const CustomVectorizationPrecondition &customPrecondition) { |
1793 | return succeeded( |
1794 | customPrecondition(&innerOp, vectorizeNDExtract)); |
1795 | })) { |
1796 | continue; |
1797 | } |
1798 | if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) { |
1799 | return !VectorType::isValidElementType(type); |
1800 | })) { |
1801 | return failure(); |
1802 | } |
1803 | if (llvm::any_of(innerOp.getResultTypes(), [](Type type) { |
1804 | return !VectorType::isValidElementType(type); |
1805 | })) { |
1806 | return failure(); |
1807 | } |
1808 | } |
1809 | if (isElementwise(linalgOp)) |
1810 | return success(); |
1811 | |
1812 | // TODO: isaConvolutionOpInterface that can also infer from generic |
1813 | // features. But we will still need stride/dilation attributes that will be |
1814 | // annoying to reverse-engineer... |
1815 | if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) |
1816 | return success(); |
1817 | // TODO: the common vector shape is equal to the static loop sizes only when |
1818 | // all indexing maps are projected permutations. For convs and stencils the |
1819 | // logic will need to evolve. |
1820 | if (!allIndexingsAreProjectedPermutation(linalgOp)) { |
1821 | LDBG("precondition failed: not projected permutations\n" ); |
1822 | return failure(); |
1823 | } |
1824 | if (failed(reductionPreconditions(linalgOp))) { |
1825 | LDBG("precondition failed: reduction preconditions\n" ); |
1826 | return failure(); |
1827 | } |
1828 | return success(); |
1829 | } |
1830 | |
1831 | static LogicalResult |
1832 | vectorizePackOpPrecondition(tensor::PackOp packOp, |
1833 | ArrayRef<int64_t> inputVectorSizes) { |
1834 | auto padValue = packOp.getPaddingValue(); |
1835 | Attribute cstAttr; |
  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
1837 | LDBG("pad value is not constant: " << packOp << "\n" ); |
1838 | return failure(); |
1839 | } |
1840 | ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape(); |
1841 | bool satisfyEmptyCond = true; |
1842 | if (inputVectorSizes.empty()) { |
1843 | if (!packOp.getDestType().hasStaticShape() || |
1844 | !packOp.getSourceType().hasStaticShape()) |
1845 | satisfyEmptyCond = false; |
1846 | } |
1847 | |
1848 | if (!satisfyEmptyCond && |
1849 | failed(vector::isValidMaskedInputVector( |
          resultTensorShape.take_front(packOp.getSourceRank()),
1851 | inputVectorSizes))) |
1852 | return failure(); |
1853 | |
1854 | if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) { |
        return !getConstantIntValue(v).has_value();
1856 | })) { |
1857 | LDBG("inner_tiles must be constant: " << packOp << "\n" ); |
1858 | return failure(); |
1859 | } |
1860 | |
1861 | return success(); |
1862 | } |
1863 | |
1864 | static LogicalResult |
1865 | vectorizePadOpPrecondition(tensor::PadOp padOp, |
1866 | ArrayRef<int64_t> inputVectorSizes) { |
1867 | auto padValue = padOp.getConstantPaddingValue(); |
1868 | if (!padValue) { |
1869 | LDBG("pad value is not constant: " << padOp << "\n" ); |
1870 | return failure(); |
1871 | } |
1872 | |
1873 | ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape(); |
  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
1875 | inputVectorSizes))) |
1876 | return failure(); |
1877 | |
1878 | if (llvm::any_of(padOp.getLow(), [](Value v) { |
        std::optional<int64_t> res = getConstantIntValue(v);
1880 | return !res.has_value() || res.value() != 0; |
1881 | })) { |
1882 | LDBG("low pad must all be zero: " << padOp << "\n" ); |
1883 | return failure(); |
1884 | } |
1885 | |
1886 | return success(); |
1887 | } |
1888 | |
1889 | /// Preconditions for scalable vectors. |
1890 | static LogicalResult |
1891 | vectorizeScalableVectorPrecondition(Operation *op, |
1892 | ArrayRef<int64_t> inputVectorSizes, |
1893 | ArrayRef<bool> inputScalableVecDims) { |
1894 | assert(inputVectorSizes.size() == inputScalableVecDims.size() && |
1895 | "Number of input vector sizes and scalable dims doesn't match" ); |
1896 | |
1897 | if (inputVectorSizes.empty()) |
1898 | return success(); |
1899 | |
1900 | bool isScalable = inputScalableVecDims.back(); |
1901 | if (!isScalable) |
1902 | return success(); |
1903 | |
1904 | // Only element-wise and 1d depthwise conv ops supported in the presence of |
1905 | // scalable dims. |
1906 | auto linalgOp = dyn_cast<LinalgOp>(op); |
1907 | return success(linalgOp && (isElementwise(linalgOp) || |
1908 | isa<linalg::DepthwiseConv1DNwcWcOp>(op))); |
1909 | } |
1910 | |
1911 | LogicalResult mlir::linalg::vectorizeOpPrecondition( |
1912 | Operation *op, ArrayRef<int64_t> inputVectorSizes, |
    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
1914 | bool flatten1DDepthwiseConv) { |
  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
1916 | inputScalableVecDims))) |
1917 | return failure(); |
1918 | |
1919 | return TypeSwitch<Operation *, LogicalResult>(op) |
1920 | .Case<linalg::LinalgOp>([&](auto linalgOp) { |
1921 | return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes, |
1922 | vectorizeNDExtract, |
1923 | flatten1DDepthwiseConv); |
1924 | }) |
1925 | .Case<tensor::PadOp>([&](auto padOp) { |
1926 | return vectorizePadOpPrecondition(padOp, inputVectorSizes); |
1927 | }) |
1928 | .Case<tensor::PackOp>([&](auto packOp) { |
1929 | return vectorizePackOpPrecondition(packOp, inputVectorSizes); |
1930 | }) |
1931 | .Case<tensor::UnPackOp>([&](auto unpackOp) { |
1932 | return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes); |
1933 | }) |
1934 | .Default([](auto) { return failure(); }); |
1935 | } |
1936 | |
1937 | /// Converts affine.apply Ops to arithmetic operations. |
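/// For instance (illustrative), within the body of a linalg op:
///   %0 = affine.apply affine_map<(d0) -> (d0 + 42)>(%i)
/// is expanded to:
///   %c42 = arith.constant 42 : index
///   %0 = arith.addi %i, %c42 : index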
1938 | static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) { |
1939 | OpBuilder::InsertionGuard g(rewriter); |
1940 | auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>(); |
1941 | |
1942 | for (auto op : make_early_inc_range(toReplace)) { |
1943 | rewriter.setInsertionPoint(op); |
1944 | auto expanded = affine::expandAffineExpr( |
1945 | rewriter, op->getLoc(), op.getAffineMap().getResult(0), |
1946 | op.getOperands().take_front(op.getAffineMap().getNumDims()), |
1947 | op.getOperands().take_back(op.getAffineMap().getNumSymbols())); |
1948 | rewriter.replaceOp(op, expanded); |
1949 | } |
1950 | } |
1951 | |
1952 | /// Emit a suitable vector form for an operation. If provided, |
1953 | /// `inputVectorSizes` are used to vectorize this operation. |
1954 | /// `inputVectorSizes` must match the rank of the iteration space of the |
1955 | /// operation and the input vector sizes must be greater than or equal to |
/// their counterpart iteration space sizes, if static. `inputVectorSizes`
/// also allows the vectorization of operations with dynamic shapes.
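///
/// A minimal usage sketch (the surrounding setup is illustrative, not part of
/// this file):
///   IRRewriter rewriter(linalgOp->getContext());
///   if (failed(linalg::vectorize(rewriter, linalgOp.getOperation(),
///                                /*inputVectorSizes=*/{8, 16},
///                                /*inputScalableVecDims=*/{false, false},
///                                /*vectorizeNDExtract=*/false,
///                                /*flatten1DDepthwiseConv=*/false)))
///     return failure();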
1958 | LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op, |
1959 | ArrayRef<int64_t> inputVectorSizes, |
1960 | ArrayRef<bool> inputScalableVecDims, |
                                      bool vectorizeNDExtract,
1962 | bool flatten1DDepthwiseConv) { |
1963 | LDBG("Attempting to vectorize:\n" << *op << "\n" ); |
1964 | LDBG("Input vector sizes: " ); |
1965 | LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs())); |
1966 | LLVM_DEBUG(llvm::dbgs() << "\n" ); |
1967 | LDBG("Input scalable vector dims: " ); |
1968 | LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs())); |
1969 | LLVM_DEBUG(llvm::dbgs() << "\n" ); |
1970 | |
  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
1972 | vectorizeNDExtract, |
1973 | flatten1DDepthwiseConv))) { |
1974 | LDBG("Vectorization pre-conditions failed\n" ); |
1975 | return failure(); |
1976 | } |
1977 | |
1978 | // Initialize vectorization state. |
1979 | VectorizationState state(rewriter); |
1980 | if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) { |
    if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
                               inputScalableVecDims))) {
      LDBG("Vectorization state couldn't be initialized\n");
1984 | return failure(); |
1985 | } |
1986 | } |
1987 | |
1988 | SmallVector<Value> results; |
1989 | auto vectorizeResult = |
1990 | TypeSwitch<Operation *, LogicalResult>(op) |
1991 | .Case<linalg::LinalgOp>([&](auto linalgOp) { |
1992 | // TODO: isaConvolutionOpInterface that can also infer from |
1993 | // generic features. Will require stride/dilation attributes |
1994 | // inference. |
1995 | if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) { |
1996 | FailureOr<Operation *> convOr = vectorizeConvolution( |
1997 | rewriter, linalgOp, inputVectorSizes, inputScalableVecDims, |
1998 | flatten1DDepthwiseConv); |
1999 | if (succeeded(convOr)) { |
2000 | llvm::append_range(results, (*convOr)->getResults()); |
2001 | return success(); |
2002 | } |
2003 | |
2004 | LDBG("Unsupported convolution can't be vectorized.\n" ); |
2005 | return failure(); |
2006 | } |
2007 | |
2008 | LDBG("Vectorize generic by broadcasting to the canonical vector " |
2009 | "shape\n" ); |
2010 | |
2011 | // Pre-process before proceeding. |
2012 | convertAffineApply(rewriter, linalgOp); |
2013 | |
2014 | // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted |
2015 | // to 'OpBuilder' when it is passed over to some methods like |
2016 | // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we |
2017 | // erase an op within these methods, the actual rewriter won't be |
2018 | // notified and we will end up with read-after-free issues! |
2019 | return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results); |
2020 | }) |
2021 | .Case<tensor::PadOp>([&](auto padOp) { |
2022 | return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes, |
2023 | results); |
2024 | }) |
2025 | .Case<tensor::PackOp>([&](auto packOp) { |
2026 | return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes, |
2027 | results); |
2028 | }) |
2029 | .Case<tensor::UnPackOp>([&](auto unpackOp) { |
2030 | return vectorizeAsTensorUnpackOp(rewriter, unpackOp, |
2031 | inputVectorSizes, results); |
2032 | }) |
2033 | .Default([](auto) { return failure(); }); |
2034 | |
2035 | if (failed(vectorizeResult)) { |
2036 | LDBG("Vectorization failed\n" ); |
2037 | return failure(); |
2038 | } |
2039 | |
2040 | if (!results.empty()) |
    rewriter.replaceOp(op, results);
2042 | else |
2043 | rewriter.eraseOp(op); |
2044 | |
2045 | return success(); |
2046 | } |
2047 | |
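/// Vectorize a statically shaped `memref.copy` into a full-size
/// vector.transfer_read of the source followed by a vector.transfer_write
/// into the target, roughly (a sketch for a memref<8x16xf32> copy):
///   %v = vector.transfer_read %src[%c0, %c0], %cst
///       : memref<8x16xf32>, vector<8x16xf32>
///   vector.transfer_write %v, %dst[%c0, %c0]
///       : vector<8x16xf32>, memref<8x16xf32>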
2048 | LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter, |
2049 | memref::CopyOp copyOp) { |
2050 | auto srcType = cast<MemRefType>(copyOp.getSource().getType()); |
2051 | auto dstType = cast<MemRefType>(copyOp.getTarget().getType()); |
2052 | if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) |
2053 | return failure(); |
2054 | |
2055 | auto srcElementType = getElementTypeOrSelf(srcType); |
2056 | auto dstElementType = getElementTypeOrSelf(dstType); |
2057 | if (!VectorType::isValidElementType(srcElementType) || |
2058 | !VectorType::isValidElementType(dstElementType)) |
2059 | return failure(); |
2060 | |
2061 | auto readType = VectorType::get(srcType.getShape(), srcElementType); |
2062 | auto writeType = VectorType::get(dstType.getShape(), dstElementType); |
2063 | |
2064 | Location loc = copyOp->getLoc(); |
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2066 | SmallVector<Value> indices(srcType.getRank(), zero); |
2067 | |
2068 | Value readValue = rewriter.create<vector::TransferReadOp>( |
2069 | loc, readType, copyOp.getSource(), indices, |
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
2071 | if (cast<VectorType>(readValue.getType()).getRank() == 0) { |
2072 | readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue); |
2073 | readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue); |
2074 | } |
2075 | Operation *writeValue = rewriter.create<vector::TransferWriteOp>( |
2076 | loc, readValue, copyOp.getTarget(), indices, |
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
2078 | rewriter.replaceOp(copyOp, writeValue->getResults()); |
2079 | return success(); |
2080 | } |
2081 | |
2082 | //----------------------------------------------------------------------------// |
2083 | // Misc. vectorization patterns. |
2084 | //----------------------------------------------------------------------------// |
2085 | |
2086 | /// Helper function that retrieves the value of an IntegerAttr. |
2087 | static int64_t getIntFromAttr(Attribute attr) { |
2088 | return cast<IntegerAttr>(attr).getInt(); |
2089 | } |
2090 | |
2091 | /// Given an ArrayRef of OpFoldResults, return a vector of Values. |
2092 | /// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are |
2093 | /// not supported. |
2094 | static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc, |
2095 | ArrayRef<OpFoldResult> ofrs) { |
2096 | SmallVector<Value> result; |
2097 | for (auto o : ofrs) { |
2098 | if (auto val = llvm::dyn_cast_if_present<Value>(Val&: o)) { |
2099 | result.push_back(Elt: val); |
2100 | } else { |
      result.push_back(rewriter.create<arith::ConstantIndexOp>(
          loc, getIntFromAttr(o.template get<Attribute>())));
2103 | } |
2104 | } |
2105 | return result; |
2106 | } |
2107 | |
2108 | /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and |
2109 | /// InsertSliceOp. For now, only constant padding values are supported. |
2110 | /// If there is enough static type information, TransferReadOps and |
2111 | /// TransferWriteOps may be generated instead of InsertSliceOps. |
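///
/// A sketch of the generalized (non-vectorized) fallback form, which the
/// vectorized path replaces with transfer ops when enough static information
/// is available:
///   %init = tensor.empty() : tensor<17x5xf32>
///   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<17x5xf32>)
///             -> tensor<17x5xf32>
///   %res  = tensor.insert_slice %src into %fill[0, 0] [%s0, %s1] [1, 1]
///             : tensor<?x?xf32> into tensor<17x5xf32>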
2112 | struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern { |
2113 | GenericPadOpVectorizationPattern(MLIRContext *context, |
2114 | PatternBenefit benefit = 1) |
2115 | : GeneralizePadOpPattern(context, tryVectorizeCopy, benefit) {} |
2116 | /// Vectorize the copying of a tensor::PadOp's source. This is possible if |
  /// each dimension size is statically known in the source type or the result
2118 | /// type (or both). |
2119 | static LogicalResult tryVectorizeCopy(RewriterBase &rewriter, |
2120 | tensor::PadOp padOp, Value dest) { |
2121 | auto sourceType = padOp.getSourceType(); |
2122 | auto resultType = padOp.getResultType(); |
2123 | if (!VectorType::isValidElementType(sourceType.getElementType())) |
2124 | return failure(); |
2125 | |
2126 | // Copy cannot be vectorized if pad value is non-constant and source shape |
2127 | // is dynamic. In case of a dynamic source shape, padding must be appended |
2128 | // by TransferReadOp, but TransferReadOp supports only constant padding. |
2129 | auto padValue = padOp.getConstantPaddingValue(); |
2130 | if (!padValue) { |
2131 | if (!sourceType.hasStaticShape()) |
2132 | return failure(); |
2133 | // Create dummy padding value. |
2134 | auto elemType = sourceType.getElementType(); |
2135 | padValue = rewriter.create<arith::ConstantOp>( |
2136 | padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType)); |
2137 | } |
2138 | |
2139 | SmallVector<int64_t> vecShape; |
2140 | SmallVector<bool> readInBounds; |
2141 | SmallVector<bool> writeInBounds; |
2142 | for (unsigned i = 0; i < sourceType.getRank(); ++i) { |
2143 | if (!sourceType.isDynamicDim(i)) { |
        vecShape.push_back(sourceType.getDimSize(i));
        // Source shape is statically known: neither read nor write is
        // out-of-bounds.
        readInBounds.push_back(true);
        writeInBounds.push_back(true);
2149 | } else if (!resultType.isDynamicDim(i)) { |
2150 | // Source shape is not statically known, but result shape is. |
2151 | // Vectorize with size of result shape. This may be larger than the |
2152 | // source size. |
        vecShape.push_back(resultType.getDimSize(i));
2154 | // Read may be out-of-bounds because the result size could be larger |
2155 | // than the source size. |
        readInBounds.push_back(false);
2157 | // Write is out-of-bounds if low padding > 0. |
        writeInBounds.push_back(
            getConstantIntValue(padOp.getMixedLowPad()[i]) ==
            static_cast<int64_t>(0));
2161 | } else { |
2162 | // Neither source nor result dim of padOp is static. Cannot vectorize |
2163 | // the copy. |
2164 | return failure(); |
2165 | } |
2166 | } |
2167 | auto vecType = VectorType::get(vecShape, sourceType.getElementType()); |
2168 | |
2169 | // Generate TransferReadOp. |
2170 | SmallVector<Value> readIndices( |
2171 | vecType.getRank(), |
2172 | rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0)); |
2173 | auto read = rewriter.create<vector::TransferReadOp>( |
2174 | padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue, |
2175 | ArrayRef<bool>{readInBounds}); |
2176 | |
2177 | // If `dest` is a FillOp and the TransferWriteOp would overwrite the |
2178 | // entire tensor, write directly to the FillOp's operand. |
2179 | if (llvm::equal(vecShape, resultType.getShape()) && |
2180 | llvm::all_of(Range&: writeInBounds, P: [](bool b) { return b; })) |
2181 | if (auto fill = dest.getDefiningOp<FillOp>()) |
2182 | dest = fill.output(); |
2183 | |
2184 | // Generate TransferWriteOp. |
2185 | auto writeIndices = |
2186 | ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad()); |
2187 | rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( |
2188 | padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds}); |
2189 | |
2190 | return success(); |
2191 | } |
2192 | }; |
2193 | |
2194 | /// Base pattern for rewriting tensor::PadOps whose result is consumed by a |
2195 | /// given operation type OpTy. |
2196 | template <typename OpTy> |
2197 | struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> { |
2198 | using OpRewritePattern<tensor::PadOp>::OpRewritePattern; |
2199 | |
2200 | LogicalResult matchAndRewrite(tensor::PadOp padOp, |
2201 | PatternRewriter &rewriter) const final { |
2202 | bool changed = false; |
    // Collect users in a vector first; some of them may be replaced/removed.
2204 | for (auto *user : llvm::to_vector<4>(padOp->getUsers())) |
2205 | if (auto op = dyn_cast<OpTy>(user)) |
2206 | changed |= rewriteUser(rewriter, padOp, op).succeeded(); |
    return success(changed);
2208 | } |
2209 | |
2210 | protected: |
2211 | virtual LogicalResult rewriteUser(PatternRewriter &rewriter, |
2212 | tensor::PadOp padOp, OpTy op) const = 0; |
2213 | }; |
2214 | |
2215 | /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.: |
2216 | /// ``` |
2217 | /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32> |
2218 | /// %r = vector.transfer_read %0[%c0, %c0], %cst |
2219 | /// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32> |
2220 | /// ``` |
2221 | /// is rewritten to: |
2222 | /// ``` |
2223 | /// %r = vector.transfer_read %src[%c0, %c0], %padding |
2224 | /// {in_bounds = [true, true]} |
2225 | /// : tensor<?x?xf32>, vector<17x5xf32> |
2226 | /// ``` |
2227 | /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be |
2228 | /// sure that the original padding value %cst was never used. |
2229 | /// |
2230 | /// This rewrite is possible if: |
2231 | /// - `xferOp` has no out-of-bounds dims or mask. |
2232 | /// - Low padding is static 0. |
2233 | /// - Single, scalar padding value. |
2234 | struct PadOpVectorizationWithTransferReadPattern |
2235 | : public VectorizePadOpUserPattern<vector::TransferReadOp> { |
2236 | using VectorizePadOpUserPattern< |
2237 | vector::TransferReadOp>::VectorizePadOpUserPattern; |
2238 | |
2239 | LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, |
2240 | vector::TransferReadOp xferOp) const override { |
2241 | // Low padding must be static 0. |
2242 | if (!padOp.hasZeroLowPad()) |
2243 | return failure(); |
2244 | // Pad value must be a constant. |
2245 | auto padValue = padOp.getConstantPaddingValue(); |
2246 | if (!padValue) |
2247 | return failure(); |
2248 | // Padding value of existing `xferOp` is unused. |
2249 | if (xferOp.hasOutOfBoundsDim() || xferOp.getMask()) |
2250 | return failure(); |
2251 | |
2252 | rewriter.modifyOpInPlace(xferOp, [&]() { |
2253 | SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false); |
2254 | xferOp->setAttr(xferOp.getInBoundsAttrName(), |
2255 | rewriter.getBoolArrayAttr(inBounds)); |
2256 | xferOp.getSourceMutable().assign(padOp.getSource()); |
2257 | xferOp.getPaddingMutable().assign(padValue); |
2258 | }); |
2259 | |
2260 | return success(); |
2261 | } |
2262 | }; |
2263 | |
2264 | /// Rewrite use of tensor::PadOp result in TransferWriteOp. |
2265 | /// This pattern rewrites TransferWriteOps that write to a padded tensor |
2266 | /// value, where the same amount of padding is immediately removed again after |
2267 | /// the write. In such cases, the TransferWriteOp can write to the non-padded |
2268 | /// tensor value and apply out-of-bounds masking. E.g.: |
2269 | /// ``` |
2270 | /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1] |
2271 | /// : tensor<...> to tensor<?x?xf32> |
2272 | /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32> |
2273 | /// %2 = vector.transfer_write %vec, %1[...] |
2274 | /// : vector<17x5xf32>, tensor<17x5xf32> |
2275 | /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1] |
2276 | /// : tensor<17x5xf32> to tensor<?x?xf32> |
2277 | /// ``` |
2278 | /// is rewritten to: |
2279 | /// ``` |
2280 | /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1] |
2281 | /// : tensor<...> to tensor<?x?xf32> |
2282 | /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>, |
2283 | /// tensor<?x?xf32> |
2284 | /// ``` |
2285 | /// Note: It is important that the ExtractSliceOp %r resizes the result of the |
2286 | /// TransferWriteOp to the same size as the input of the TensorPadOp (or an |
2287 | /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ |
2288 | /// from %r's old dimensions. |
2289 | /// |
2290 | /// This rewrite is possible if: |
2291 | /// - Low padding is static 0. |
2292 | /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This |
2293 | /// ExtractSliceOp trims the same amount of padding that was added |
2294 | /// beforehand. |
2295 | /// - Single, scalar padding value. |
2296 | struct PadOpVectorizationWithTransferWritePattern |
2297 | : public VectorizePadOpUserPattern<vector::TransferWriteOp> { |
2298 | using VectorizePadOpUserPattern< |
2299 | vector::TransferWriteOp>::VectorizePadOpUserPattern; |
2300 | |
2301 | LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, |
2302 | vector::TransferWriteOp xferOp) const override { |
2303 | // TODO: support 0-d corner case. |
2304 | if (xferOp.getTransferRank() == 0) |
2305 | return failure(); |
2306 | |
2307 | // Low padding must be static 0. |
2308 | if (!padOp.hasZeroLowPad()) |
2309 | return failure(); |
2310 | // Pad value must be a constant. |
2311 | auto padValue = padOp.getConstantPaddingValue(); |
2312 | if (!padValue) |
2313 | return failure(); |
2314 | // TransferWriteOp result must be directly consumed by an ExtractSliceOp. |
2315 | if (!xferOp->hasOneUse()) |
2316 | return failure(); |
2317 | auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin()); |
2318 | if (!trimPadding) |
2319 | return failure(); |
2320 | // Only static zero offsets supported when trimming padding. |
2321 | if (!trimPadding.hasZeroOffset()) |
2322 | return failure(); |
2323 | // trimPadding must remove the amount of padding that was added earlier. |
    if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2325 | return failure(); |
2326 | |
2327 | // Insert the new TransferWriteOp at position of the old TransferWriteOp. |
2328 | rewriter.setInsertionPoint(xferOp); |
2329 | |
2330 | SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false); |
2331 | auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( |
2332 | xferOp, padOp.getSource().getType(), xferOp.getVector(), |
2333 | padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(), |
2334 | xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds)); |
2335 | rewriter.replaceOp(trimPadding, newXferOp->getResult(0)); |
2336 | |
2337 | return success(); |
2338 | } |
2339 | |
2340 | /// Check if `beforePadding` and `afterTrimming` have the same tensor size, |
2341 | /// i.e., same dimensions. |
2342 | /// |
2343 | /// Dimensions may be static, dynamic or mix of both. In case of dynamic |
2344 | /// dimensions, this function tries to infer the (static) tensor size by |
2345 | /// looking at the defining op and utilizing op-specific knowledge. |
2346 | /// |
2347 | /// This is a conservative analysis. In case equal tensor sizes cannot be |
2348 | /// proven statically, this analysis returns `false` even though the tensor |
2349 | /// sizes may turn out to be equal at runtime. |
  bool hasSameTensorSize(Value beforePadding,
                         tensor::ExtractSliceOp afterTrimming) const {
2352 | // If the input to tensor::PadOp is a CastOp, try with both CastOp |
2353 | // result and CastOp operand. |
2354 | if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>()) |
      if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2356 | return true; |
2357 | |
2358 | auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType()); |
2359 | auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType()); |
2360 | // Only RankedTensorType supported. |
2361 | if (!t1 || !t2) |
2362 | return false; |
2363 | // Rank of both values must be the same. |
2364 | if (t1.getRank() != t2.getRank()) |
2365 | return false; |
2366 | |
2367 | // All static dimensions must be the same. Mixed cases (e.g., dimension |
2368 | // static in `t1` but dynamic in `t2`) are not supported. |
2369 | for (unsigned i = 0; i < t1.getRank(); ++i) { |
2370 | if (t1.isDynamicDim(i) != t2.isDynamicDim(i)) |
2371 | return false; |
2372 | if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i)) |
2373 | return false; |
2374 | } |
2375 | |
2376 | // Nothing more to check if all dimensions are static. |
2377 | if (t1.getNumDynamicDims() == 0) |
2378 | return true; |
2379 | |
2380 | // All dynamic sizes must be the same. The only supported case at the |
2381 | // moment is when `beforePadding` is an ExtractSliceOp (or a cast |
2382 | // thereof). |
2383 | |
2384 | // Apart from CastOp, only ExtractSliceOp is supported. |
2385 | auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>(); |
2386 | if (!beforeSlice) |
2387 | return false; |
2388 | |
2389 | assert(static_cast<size_t>(t1.getRank()) == |
2390 | beforeSlice.getMixedSizes().size()); |
2391 | assert(static_cast<size_t>(t2.getRank()) == |
2392 | afterTrimming.getMixedSizes().size()); |
2393 | |
2394 | for (unsigned i = 0; i < t1.getRank(); ++i) { |
2395 | // Skip static dimensions. |
2396 | if (!t1.isDynamicDim(i)) |
2397 | continue; |
2398 | auto size1 = beforeSlice.getMixedSizes()[i]; |
2399 | auto size2 = afterTrimming.getMixedSizes()[i]; |
2400 | |
2401 | // Case 1: Same value or same constant int. |
2402 | if (isEqualConstantIntOrValue(size1, size2)) |
2403 | continue; |
2404 | |
2405 | // Other cases: Take a deeper look at defining ops of values. |
2406 | auto v1 = llvm::dyn_cast_if_present<Value>(size1); |
2407 | auto v2 = llvm::dyn_cast_if_present<Value>(size2); |
2408 | if (!v1 || !v2) |
2409 | return false; |
2410 | |
2411 | // Case 2: Both values are identical AffineMinOps. (Should not happen if |
2412 | // CSE is run.) |
2413 | auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>(); |
2414 | auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>(); |
2415 | if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() && |
2416 | minOp1.getOperands() == minOp2.getOperands()) |
2417 | continue; |
2418 | |
2419 | // Add additional cases as needed. |
2420 | } |
2421 | |
2422 | // All tests passed. |
2423 | return true; |
2424 | } |
2425 | }; |
2426 | |
2427 | /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.: |
2428 | /// ``` |
2429 | /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32> |
2430 | /// %r = tensor.insert_slice %0 |
2431 | /// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1] |
2432 | /// : tensor<17x5xf32> into tensor<?x?x17x5xf32> |
2433 | /// ``` |
2434 | /// is rewritten to: |
2435 | /// ``` |
2436 | /// %0 = vector.transfer_read %src[%c0, %c0], %padding |
2437 | /// : tensor<?x?xf32>, vector<17x5xf32> |
2438 | /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0] |
2439 | /// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32> |
2440 | /// ``` |
2441 | /// |
2442 | /// This rewrite is possible if: |
2443 | /// - Low padding is static 0. |
2444 | /// - `padOp` result shape is static. |
2445 | /// - The entire padded tensor is inserted. |
2446 | /// (Implies that sizes of `insertOp` are all static.) |
2447 | /// - Only unit strides in `insertOp`. |
2448 | /// - Single, scalar padding value. |
2449 | /// - `padOp` result not used as destination. |
2450 | struct PadOpVectorizationWithInsertSlicePattern |
2451 | : public VectorizePadOpUserPattern<tensor::InsertSliceOp> { |
2452 | using VectorizePadOpUserPattern< |
2453 | tensor::InsertSliceOp>::VectorizePadOpUserPattern; |
2454 | |
2455 | LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, |
2456 | tensor::InsertSliceOp insertOp) const override { |
2457 | // Low padding must be static 0. |
2458 | if (!padOp.hasZeroLowPad()) |
2459 | return failure(); |
2460 | // Only unit stride supported. |
2461 | if (!insertOp.hasUnitStride()) |
2462 | return failure(); |
2463 | // Pad value must be a constant. |
2464 | auto padValue = padOp.getConstantPaddingValue(); |
2465 | if (!padValue) |
2466 | return failure(); |
2467 | // Dynamic shapes not supported. |
2468 | if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape()) |
2469 | return failure(); |
2470 | // Pad result not used as destination. |
2471 | if (insertOp.getDest() == padOp.getResult()) |
2472 | return failure(); |
2473 | |
2474 | auto vecType = VectorType::get(padOp.getType().getShape(), |
2475 | padOp.getType().getElementType()); |
2476 | unsigned vecRank = vecType.getRank(); |
2477 | unsigned tensorRank = insertOp.getType().getRank(); |
2478 | |
2479 | // Check if sizes match: Insert the entire tensor into most minor dims. |
2480 | // (No permutations allowed.) |
2481 | SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1); |
2482 | expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end()); |
2483 | if (!llvm::all_of( |
2484 | llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) { |
2485 | return getConstantIntValue(std::get<0>(it)) == std::get<1>(it); |
2486 | })) |
2487 | return failure(); |
2488 | |
2489 | // Insert the TransferReadOp and TransferWriteOp at the position of the |
2490 | // InsertSliceOp. |
2491 | rewriter.setInsertionPoint(insertOp); |
2492 | |
2493 | // Generate TransferReadOp: Read entire source tensor and add high |
2494 | // padding. |
2495 | SmallVector<Value> readIndices( |
2496 | vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0)); |
2497 | auto read = rewriter.create<vector::TransferReadOp>( |
2498 | padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue); |
2499 | |
2500 | // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at |
    // specified offsets. Write is fully in-bounds because an InsertSliceOp's
2502 | // source must fit into the destination at the specified offsets. |
2503 | auto writeIndices = |
2504 | ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets()); |
2505 | SmallVector<bool> inBounds(vecRank, true); |
2506 | rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( |
2507 | insertOp, read, insertOp.getDest(), writeIndices, |
2508 | ArrayRef<bool>{inBounds}); |
2509 | |
2510 | return success(); |
2511 | } |
2512 | }; |
2513 | |
2514 | void mlir::linalg::populatePadOpVectorizationPatterns( |
2515 | RewritePatternSet &patterns, PatternBenefit baseBenefit) { |
  patterns.add<GenericPadOpVectorizationPattern>(patterns.getContext(),
                                                 baseBenefit);
2518 | // Try these specialized patterns first before resorting to the generic one. |
2519 | patterns.add<PadOpVectorizationWithTransferReadPattern, |
2520 | PadOpVectorizationWithTransferWritePattern, |
2521 | PadOpVectorizationWithInsertSlicePattern>( |
      patterns.getContext(), baseBenefit.getBenefit() + 1);
2523 | } |
2524 | |
2525 | //----------------------------------------------------------------------------// |
2526 | // Forwarding patterns |
2527 | //----------------------------------------------------------------------------// |
2528 | |
2529 | /// Check whether there is any interleaved use of any `values` between |
2530 | /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value |
2531 | /// is in a different block. |
2532 | static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, |
2533 | ValueRange values) { |
2534 | if (firstOp->getBlock() != secondOp->getBlock() || |
      !firstOp->isBeforeInBlock(secondOp)) {
2536 | LDBG("interleavedUses precondition failed, firstOp: " |
2537 | << *firstOp << ", second op: " << *secondOp << "\n" ); |
2538 | return true; |
2539 | } |
2540 | for (auto v : values) { |
2541 | for (auto &u : v.getUses()) { |
2542 | Operation *owner = u.getOwner(); |
2543 | if (owner == firstOp || owner == secondOp) |
2544 | continue; |
2545 | // TODO: this is too conservative, use dominance info in the future. |
2546 | if (owner->getBlock() == firstOp->getBlock() && |
          (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
2548 | continue; |
2549 | LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp |
2550 | << ", second op: " << *secondOp << "\n" ); |
2551 | return true; |
2552 | } |
2553 | } |
2554 | return false; |
2555 | } |
2556 | |
2557 | /// Return the unique subview use of `v` if it is indeed unique, null |
2558 | /// otherwise. |
2559 | static memref::SubViewOp getSubViewUseIfUnique(Value v) { |
2560 | memref::SubViewOp subViewOp; |
2561 | for (auto &u : v.getUses()) { |
2562 | if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) { |
2563 | if (subViewOp) |
2564 | return memref::SubViewOp(); |
2565 | subViewOp = newSubViewOp; |
2566 | } |
2567 | } |
2568 | return subViewOp; |
2569 | } |
2570 | |
2571 | /// TODO: use interfaces, side-effects and aliasing analysis as appropriate, |
2572 | /// when available. |
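///
/// A sketch of the rewrite performed below (names, shapes and the exact
/// padded-buffer setup are illustrative): a `vector.transfer_read` from a
/// buffer that was filled and then partially overwritten by a `memref.copy`
/// is forwarded to read from the copy source directly, and the now-redundant
/// fill and copy are erased.
/// ```
///   linalg.fill ins(%cst : f32) outs(%alloc : memref<32x32xf32>)
///   %sv = memref.subview %alloc[0, 0] [%s0, %s1] [1, 1]
///       : memref<32x32xf32> to memref<?x?xf32, strided<[32, 1]>>
///   memref.copy %in, %sv
///       : memref<?x?xf32> to memref<?x?xf32, strided<[32, 1]>>
///   %v = vector.transfer_read %alloc[%c0, %c0], %cst
///       : memref<32x32xf32>, vector<32x32xf32>
/// ```
/// becomes
/// ```
///   %v = vector.transfer_read %in[%c0, %c0], %cst
///       : memref<?x?xf32>, vector<32x32xf32>
/// ```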
2573 | LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite( |
2574 | vector::TransferReadOp xferOp, PatternRewriter &rewriter) const { |
2575 | |
2576 | // TODO: support mask. |
2577 | if (xferOp.getMask()) |
2578 | return rewriter.notifyMatchFailure(xferOp, "unsupported mask" ); |
2579 | |
2580 | // Transfer into `view`. |
2581 | Value viewOrAlloc = xferOp.getSource(); |
2582 | if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() && |
2583 | !viewOrAlloc.getDefiningOp<memref::AllocOp>()) |
2584 | return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc" ); |
2585 | |
2586 | // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`. |
2587 | memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc); |
2588 | if (!subViewOp) |
2589 | return rewriter.notifyMatchFailure(xferOp, "no subview found" ); |
2590 | Value subView = subViewOp.getResult(); |
2591 | |
2592 | // Find the copy into `subView` without interleaved uses. |
2593 | memref::CopyOp copyOp; |
2594 | for (auto &u : subView.getUses()) { |
2595 | if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) { |
2596 | assert(isa<MemRefType>(newCopyOp.getTarget().getType())); |
2597 | if (newCopyOp.getTarget() != subView) |
2598 | continue; |
2599 | if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView})) |
2600 | continue; |
2601 | copyOp = newCopyOp; |
2602 | break; |
2603 | } |
2604 | } |
2605 | if (!copyOp) |
2606 | return rewriter.notifyMatchFailure(xferOp, "no copy found" ); |
2607 | |
2608 | // Find the fill into `viewOrAlloc` without interleaved uses before the |
2609 | // copy. |
2610 | FillOp maybeFillOp; |
2611 | for (auto &u : viewOrAlloc.getUses()) { |
2612 | if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) { |
2613 | assert(isa<MemRefType>(newFillOp.output().getType())); |
2614 | if (newFillOp.output() != viewOrAlloc) |
2615 | continue; |
2616 | if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView})) |
2617 | continue; |
2618 | maybeFillOp = newFillOp; |
2619 | break; |
2620 | } |
2621 | } |
2622 | // Ensure padding matches. |
2623 | if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value()) |
2624 | return rewriter.notifyMatchFailure(xferOp, |
2625 | "padding value does not match fill" ); |
2626 | |
2627 | // `in` is the subview that memref.copy reads. Replace it. |
2628 | Value in = copyOp.getSource(); |
2629 | |
2630 | // memref.copy + linalg.fill can be used to create a padded local buffer. |
2631 | // The `masked` attribute is only valid on this padded buffer. |
2632 | // When forwarding to vector.transfer_read, the attribute must be reset |
2633 | // conservatively. |
2634 | Value res = rewriter.create<vector::TransferReadOp>( |
2635 | xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.getIndices(), |
2636 | xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(), |
2637 | // in_bounds is explicitly reset |
2638 | /*inBoundsAttr=*/ArrayAttr()); |
2639 | |
2640 | if (maybeFillOp) |
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
2643 | rewriter.replaceOp(xferOp, res); |
2644 | |
2645 | return success(); |
2646 | } |
2647 | |
2648 | /// TODO: use interfaces, side-effects and aliasing analysis as appropriate, |
2649 | /// when available. |
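///
/// A sketch of the rewrite performed below, mirroring the read-forwarding
/// case above (names and shapes are illustrative):
/// ```
///   %sv = memref.subview %alloc[0, 0] [%s0, %s1] [1, 1]
///       : memref<32x32xf32> to memref<?x?xf32, strided<[32, 1]>>
///   vector.transfer_write %vec, %alloc[%c0, %c0]
///       : vector<32x32xf32>, memref<32x32xf32>
///   memref.copy %sv, %out
///       : memref<?x?xf32, strided<[32, 1]>> to memref<?x?xf32>
/// ```
/// becomes a single write into the copy destination:
/// ```
///   vector.transfer_write %vec, %out[%c0, %c0]
///       : vector<32x32xf32>, memref<?x?xf32>
/// ```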
2650 | LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite( |
2651 | vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const { |
2652 | // TODO: support mask. |
2653 | if (xferOp.getMask()) |
2654 | return rewriter.notifyMatchFailure(xferOp, "unsupported mask" ); |
2655 | |
2656 | // Transfer into `viewOrAlloc`. |
2657 | Value viewOrAlloc = xferOp.getSource(); |
2658 | if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() && |
2659 | !viewOrAlloc.getDefiningOp<memref::AllocOp>()) |
2660 | return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc" ); |
2661 | |
2662 | // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`. |
2663 | memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc); |
2664 | if (!subViewOp) |
2665 | return rewriter.notifyMatchFailure(xferOp, "no subview found" ); |
2666 | Value subView = subViewOp.getResult(); |
2667 | |
2668 | // Find the copy from `subView` without interleaved uses. |
2669 | memref::CopyOp copyOp; |
2670 | for (auto &u : subViewOp.getResult().getUses()) { |
2671 | if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) { |
2672 | if (newCopyOp.getSource() != subView) |
2673 | continue; |
2674 | if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView})) |
2675 | continue; |
2676 | copyOp = newCopyOp; |
2677 | break; |
2678 | } |
2679 | } |
2680 | if (!copyOp) |
2681 | return rewriter.notifyMatchFailure(xferOp, "no copy found" ); |
2682 | |
2683 | // `out` is the subview copied into that we replace. |
2684 | assert(isa<MemRefType>(copyOp.getTarget().getType())); |
2685 | Value out = copyOp.getTarget(); |
2686 | |
2687 | // Forward vector.transfer into copy. |
2688 | // memref.copy + linalg.fill can be used to create a padded local buffer. |
2689 | // The `masked` attribute is only valid on this padded buffer. |
2690 | // When forwarding to vector.transfer_write, the attribute must be reset |
2691 | // conservatively. |
2692 | rewriter.create<vector::TransferWriteOp>( |
2693 | xferOp.getLoc(), xferOp.getVector(), out, xferOp.getIndices(), |
2694 | xferOp.getPermutationMapAttr(), xferOp.getMask(), |
2695 | // in_bounds is explicitly reset |
2696 | /*inBoundsAttr=*/ArrayAttr()); |
2697 | |
  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);
2700 | |
2701 | return success(); |
2702 | } |
2703 | |
2704 | //===----------------------------------------------------------------------===// |
2705 | // Convolution vectorization patterns |
2706 | //===----------------------------------------------------------------------===// |
2707 | |
2708 | template <int N> |
2709 | static void bindShapeDims(ShapedType shapedType) {} |
2710 | |
2711 | template <int N, typename IntTy, typename... IntTy2> |
2712 | static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) { |
2713 | val = shapedType.getShape()[N]; |
2714 | bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...); |
2715 | } |
2716 | |
2717 | /// Bind a pack of int& to the leading dimensions of shapedType.getShape(). |
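///
/// Minimal usage sketch (hypothetical shaped type and values): for a
/// `tensor<4x16x8xf32>`,
/// ```
///   int64_t n, w, c;
///   bindShapeDims(shapedType, n, w, c); // n == 4, w == 16, c == 8
/// ```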
2718 | template <typename... IntTy> |
2719 | static void bindShapeDims(ShapedType shapedType, IntTy &...vals) { |
2720 | bindShapeDims<0>(shapedType, vals...); |
2721 | } |
2722 | |
2723 | namespace { |
2724 | bool isCastOfBlockArgument(Operation *op) { |
2725 | return isa<CastOpInterface>(op) && op->getNumOperands() == 1 && |
2726 | isa<BlockArgument>(op->getOperand(0)); |
2727 | } |
2728 | |
2729 | bool isSupportedPoolKind(vector::CombiningKind kind) { |
2730 | switch (kind) { |
2731 | case vector::CombiningKind::ADD: |
2732 | case vector::CombiningKind::MAXNUMF: |
2733 | case vector::CombiningKind::MAXIMUMF: |
2734 | case vector::CombiningKind::MAXSI: |
2735 | case vector::CombiningKind::MAXUI: |
2736 | case vector::CombiningKind::MINNUMF: |
2737 | case vector::CombiningKind::MINIMUMF: |
2738 | case vector::CombiningKind::MINSI: |
2739 | case vector::CombiningKind::MINUI: |
2740 | return true; |
2741 | default: |
2742 | return false; |
2743 | } |
2744 | } |
2745 | |
/// Generate a vector implementation for one of the following:
2747 | /// ``` |
2748 | /// Op def: ( w, kw ) |
2749 | /// Iters: ({Par(), Red()}) |
2750 | /// Layout: {{w + kw}, {kw}, {w}} |
2751 | /// ``` |
2752 | /// kw is unrolled. |
2753 | /// |
2754 | /// or |
2755 | /// |
2756 | /// ``` |
2757 | /// Op def: ( n, w, c, kw, f ) |
2758 | /// Iters: ({Par(), Par(), Par(), Red(), Red()}) |
2759 | /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}} |
2760 | /// ``` |
2761 | /// kw is unrolled, w is unrolled iff dilationW > 1. |
2762 | /// |
2763 | /// or |
2764 | /// |
2765 | /// ``` |
2766 | /// Op def: ( n, c, w, f, kw ) |
2767 | /// Iters: ({Par(), Par(), Par(), Red(), Red()}) |
2768 | /// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}} |
2769 | /// ``` |
2770 | /// kw is unrolled, w is unrolled iff dilationW > 1. |
2771 | /// |
2772 | /// or |
2773 | /// |
2774 | /// ``` |
2775 | /// Op def: ( n, w, c, kw ) |
2776 | /// Iters: ({Par(), Par(), Par(), Red()}) |
2777 | /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}} |
2778 | /// ``` |
2779 | /// kw is unrolled, w is unrolled iff dilationW > 1. |
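///
/// For instance (an illustrative example with arbitrary static sizes), the
/// second (channeled NWC) form above corresponds to ops such as:
/// ```
///   %0 = linalg.conv_1d_nwc_wcf
///          {dilations = dense<1> : tensor<1xi64>,
///           strides = dense<1> : tensor<1xi64>}
///     ins(%input, %filter : tensor<1x8x3xf32>, tensor<2x3x4xf32>)
///     outs(%init : tensor<1x7x4xf32>) -> tensor<1x7x4xf32>
/// ```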
2780 | struct Conv1DGenerator |
2781 | : public StructuredGenerator<LinalgOp, utils::IteratorType> { |
2782 | Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW, |
2783 | int dilationW) |
2784 | : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp), |
2785 | strideW(strideW), dilationW(dilationW) { |
2786 | // Determine whether `linalgOp` can be generated with this generator |
2787 | if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1) |
2788 | return; |
2789 | lhsShaped = linalgOp.getDpsInputOperand(0)->get(); |
2790 | rhsShaped = linalgOp.getDpsInputOperand(1)->get(); |
2791 | resShaped = linalgOp.getDpsInitOperand(0)->get(); |
2792 | lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType()); |
2793 | rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType()); |
2794 | resShapedType = dyn_cast<ShapedType>(resShaped.getType()); |
2795 | if (!lhsShapedType || !rhsShapedType || !resShapedType) |
2796 | return; |
2797 | // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR |
2798 | // (non-channeled convolution -> LHS and RHS both have single dimensions). |
2799 | if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) && |
2800 | (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1)) |
2801 | return; |
2802 | |
2803 | Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0)); |
2804 | if (!reduceOp) |
2805 | return; |
2806 | redOp = reduceOp->getName().getIdentifier(); |
2807 | |
2808 | if (!setOperKind(reduceOp)) |
2809 | return; |
2810 | auto maybeKind = getCombinerOpKind(reduceOp); |
2811 | if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD && |
2812 | (oper != Pool || !isSupportedPoolKind(*maybeKind)))) { |
2813 | return; |
2814 | } |
2815 | |
2816 | auto rhsRank = rhsShapedType.getRank(); |
2817 | switch (oper) { |
2818 | case Conv: |
2819 | if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3) |
2820 | return; |
2821 | break; |
2822 | case Pool: |
2823 | if (rhsRank != 1) |
2824 | return; |
2825 | break; |
2826 | } |
2827 | // The op is now known to be valid. |
2828 | valid = true; |
2829 | } |
2830 | |
2831 | /// Generate a vector implementation for: |
2832 | /// ``` |
2833 | /// Op def: ( w, kw ) |
2834 | /// Iters: ({Par(), Red()}) |
2835 | /// Layout: {{w + kw}, {kw}, {w}} |
2836 | /// ``` |
2837 | /// kw is always unrolled. |
2838 | /// |
2839 | /// or |
2840 | /// |
2841 | /// ``` |
2842 | /// Op def: ( n, w, c, kw, f ) |
2843 | /// Iters: ({Par(), Par(), Par(), Red(), Red()}) |
2844 | /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}} |
2845 | /// ``` |
2846 | /// kw is always unrolled. |
  /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is > 1.
2849 | FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) { |
2850 | if (!valid) |
2851 | return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool" ); |
2852 | |
2853 | int64_t nSize, wSize, cSize, kwSize, fSize; |
2854 | SmallVector<int64_t, 3> lhsShape, rhsShape, resShape; |
2855 | bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W); |
2856 | switch (conv1DOpOrder) { |
2857 | case Conv1DOpOrder::W: |
2858 | // Initialize unused dimensions |
2859 | nSize = fSize = cSize = 0; |
2860 | // out{W} |
2861 | bindShapeDims(resShapedType, wSize); |
2862 | // kernel{kw} |
2863 | bindShapeDims(rhsShapedType, kwSize); |
2864 | lhsShape = {// iw = ow + kw - 1 |
2865 | // (i.e. 16 convolved with 3 -> 14) |
2866 | (wSize + kwSize - 1)}; |
2867 | rhsShape = {kwSize}; |
2868 | resShape = {wSize}; |
2869 | break; |
2870 | case Conv1DOpOrder::Nwc: |
2871 | // out{n, w, f} |
2872 | bindShapeDims(resShapedType, nSize, wSize, fSize); |
2873 | switch (oper) { |
2874 | case Conv: |
2875 | // kernel{kw, c, f} |
2876 | bindShapeDims(rhsShapedType, kwSize, cSize); |
2877 | break; |
2878 | case Pool: |
2879 | // kernel{kw} |
2880 | bindShapeDims(rhsShapedType, kwSize); |
2881 | cSize = fSize; |
2882 | break; |
2883 | } |
2884 | lhsShape = {nSize, |
2885 | // iw = ow * sw + kw * dw - 1 |
2886 | // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14) |
2887 | // Perform the proper inclusive -> exclusive -> inclusive. |
2888 | ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - |
2889 | 1, |
2890 | cSize}; |
2891 | switch (oper) { |
2892 | case Conv: |
2893 | rhsShape = {kwSize, cSize, fSize}; |
2894 | break; |
2895 | case Pool: |
2896 | rhsShape = {kwSize}; |
2897 | break; |
2898 | } |
2899 | resShape = {nSize, wSize, fSize}; |
2900 | break; |
2901 | case Conv1DOpOrder::Ncw: |
2902 | // out{n, f, w} |
2903 | bindShapeDims(resShapedType, nSize, fSize, wSize); |
2904 | switch (oper) { |
2905 | case Conv: |
2906 | // kernel{f, c, kw} |
2907 | bindShapeDims(rhsShapedType, fSize, cSize, kwSize); |
2908 | break; |
2909 | case Pool: |
2910 | // kernel{kw} |
2911 | bindShapeDims(rhsShapedType, kwSize); |
2912 | cSize = fSize; |
2913 | break; |
2914 | } |
2915 | lhsShape = {nSize, cSize, |
2916 | // iw = ow * sw + kw * dw - 1 |
2917 | // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14) |
2918 | // Perform the proper inclusive -> exclusive -> inclusive. |
2919 | ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - |
2920 | 1}; |
2921 | switch (oper) { |
2922 | case Conv: |
2923 | rhsShape = {fSize, cSize, kwSize}; |
2924 | break; |
2925 | case Pool: |
2926 | rhsShape = {kwSize}; |
2927 | break; |
2928 | } |
2929 | resShape = {nSize, fSize, wSize}; |
2930 | break; |
2931 | } |
2932 | |
2933 | vector::TransferWriteOp write; |
2934 | Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
2935 | |
2936 | // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1. |
2937 | // When strideW == 1, we can batch the contiguous loads and avoid |
2938 | // unrolling |
2939 | int64_t wSizeStep = strideW == 1 ? wSize : 1; |
2940 | |
2941 | Type lhsEltType = lhsShapedType.getElementType(); |
2942 | Type rhsEltType = rhsShapedType.getElementType(); |
2943 | Type resEltType = resShapedType.getElementType(); |
2944 | auto lhsType = VectorType::get(lhsShape, lhsEltType); |
2945 | auto rhsType = VectorType::get(rhsShape, rhsEltType); |
2946 | auto resType = VectorType::get(resShape, resEltType); |
2947 | // Zero padding with the corresponding dimensions for lhs, rhs and res. |
2948 | SmallVector<Value> lhsPadding(lhsShape.size(), zero); |
2949 | SmallVector<Value> rhsPadding(rhsShape.size(), zero); |
2950 | SmallVector<Value> resPadding(resShape.size(), zero); |
2951 | |
2952 | // Read the whole lhs, rhs and res in one shot (with zero padding). |
2953 | Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped, |
2954 | lhsPadding); |
2955 | // This is needed only for Conv. |
2956 | Value rhs = nullptr; |
2957 | if (oper == Conv) |
2958 | rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped, |
2959 | rhsPadding); |
2960 | Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped, |
2961 | resPadding); |
2962 | |
2963 | // The base vectorization case for channeled convolution is input: |
2964 | // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern |
    // vectorization case, we pre-transpose the input, weight, and output.
2966 | switch (conv1DOpOrder) { |
2967 | case Conv1DOpOrder::W: |
2968 | case Conv1DOpOrder::Nwc: |
2969 | // Base case, so no transposes necessary. |
2970 | break; |
2971 | case Conv1DOpOrder::Ncw: { |
2972 | // To match base vectorization case, we pre-transpose current case. |
2973 | // ncw -> nwc |
2974 | static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1}; |
2975 | lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs); |
2976 | // fcw -> wcf |
2977 | static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0}; |
2978 | |
2979 | // This is needed only for Conv. |
2980 | if (oper == Conv) |
2981 | rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs); |
2982 | // nfw -> nwf |
2983 | static constexpr std::array<int64_t, 3> permRes = {0, 2, 1}; |
2984 | res = rewriter.create<vector::TransposeOp>(loc, res, permRes); |
2985 | break; |
2986 | } |
2987 | } |
2988 | |
2989 | //===------------------------------------------------------------------===// |
2990 | // Begin vector-only rewrite part |
2991 | //===------------------------------------------------------------------===// |
2992 | // Unroll along kw and read slices of lhs and rhs. |
2993 | SmallVector<Value> lhsVals, rhsVals, resVals; |
2994 | lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize, |
2995 | kwSize, strideW, dilationW, wSizeStep, |
2996 | isSingleChanneled); |
2997 | // Do not do for pooling. |
2998 | if (oper == Conv) |
2999 | rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize); |
3000 | resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize, |
3001 | wSizeStep, isSingleChanneled); |
3002 | |
3003 | auto linearIndex = [&](int64_t kw, int64_t w) { |
3004 | return kw * (wSize / wSizeStep) + w; |
3005 | }; |
3006 | |
3007 | // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f} |
3008 | // or perform outerproduct for non-channeled convolution or perform simple |
3009 | // arith operation for pooling |
3010 | for (int64_t kw = 0; kw < kwSize; ++kw) { |
3011 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
3012 | switch (oper) { |
3013 | case Conv: |
3014 | if (isSingleChanneled) { |
3015 | resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc, |
3016 | lhsVals[linearIndex(kw, w)], |
3017 | rhsVals[kw], resVals[w]); |
3018 | } else { |
3019 | resVals[w] = conv1dSliceAsContraction(rewriter, loc, |
3020 | lhsVals[linearIndex(kw, w)], |
3021 | rhsVals[kw], resVals[w]); |
3022 | } |
3023 | break; |
3024 | case Pool: |
3025 | resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)], |
3026 | resVals[w]); |
3027 | break; |
3028 | } |
3029 | } |
3030 | } |
3031 | |
3032 | res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals, |
3033 | isSingleChanneled); |
3034 | //===------------------------------------------------------------------===// |
3035 | // End vector-only rewrite part |
3036 | //===------------------------------------------------------------------===// |
3037 | |
3038 | // The base vectorization case for channeled convolution is output: |
    // {n,w,f}. To reuse the result from the base pattern vectorization case,
    // we post-transpose the base case result.
3041 | switch (conv1DOpOrder) { |
3042 | case Conv1DOpOrder::W: |
3043 | case Conv1DOpOrder::Nwc: |
3044 | // Base case, so no transposes necessary. |
3045 | break; |
3046 | case Conv1DOpOrder::Ncw: { |
3047 | // nwf -> nfw |
3048 | static constexpr std::array<int64_t, 3> perm = {0, 2, 1}; |
3049 | res = rewriter.create<vector::TransposeOp>(loc, res, perm); |
3050 | break; |
3051 | } |
3052 | } |
3053 | |
3054 | return rewriter |
3055 | .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding) |
3056 | .getOperation(); |
3057 | } |
3058 | |
3059 | // Take a value and widen to have the same element type as `ty`. |
3060 | Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) { |
    const Type srcElementType = getElementTypeOrSelf(val.getType());
    const Type dstElementType = getElementTypeOrSelf(ty);
3063 | assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType)); |
3064 | if (srcElementType == dstElementType) |
3065 | return val; |
3066 | |
3067 | const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth(); |
3068 | const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth(); |
3069 | const Type dstType = |
3070 | cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType); |
3071 | |
    if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3073 | return rewriter.create<arith::SIToFPOp>(loc, dstType, val); |
3074 | } |
3075 | |
3076 | if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) && |
3077 | srcWidth < dstWidth) |
3078 | return rewriter.create<arith::ExtFOp>(loc, dstType, val); |
3079 | |
3080 | if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) && |
3081 | srcWidth < dstWidth) |
3082 | return rewriter.create<arith::ExtSIOp>(loc, dstType, val); |
3083 | |
3084 | assert(false && "unhandled promotion case" ); |
3085 | return nullptr; |
3086 | } |
3087 | |
3088 | // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f} |
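  // A sketch of the op this builds (types and sizes are illustrative):
  //   vector.contract {
  //     indexing_maps = [affine_map<(n, w, f, c) -> (n, w, c)>,
  //                      affine_map<(n, w, f, c) -> (c, f)>,
  //                      affine_map<(n, w, f, c) -> (n, w, f)>],
  //     iterator_types = ["parallel", "parallel", "parallel", "reduction"],
  //     kind = #vector.kind<add>}
  //     %lhs, %rhs, %res
  //     : vector<1x7x3xf32>, vector<3x4xf32> into vector<1x7x4xf32>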
3089 | Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc, |
3090 | Value lhs, Value rhs, Value res) { |
3091 | vector::IteratorType par = vector::IteratorType::parallel; |
3092 | vector::IteratorType red = vector::IteratorType::reduction; |
3093 | AffineExpr n, w, f, c; |
3094 | bindDims(ctx, n, w, f, c); |
    lhs = promote(rewriter, loc, lhs, res.getType());
    rhs = promote(rewriter, loc, rhs, res.getType());
3097 | return rewriter.create<vector::ContractionOp>( |
3098 | loc, lhs, rhs, res, |
3099 | /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}}, |
3100 | /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red}); |
3101 | } |
3102 | |
3103 | // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel |
3104 | // convolution. |
3105 | Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc, |
3106 | Value lhs, Value rhs, Value res) { |
3107 | return rewriter.create<vector::OuterProductOp>( |
3108 | loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD); |
3109 | } |
3110 | |
3111 | // Create a reduction: lhs{n, w, c} -> res{n, w, c} |
3112 | Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs, |
3113 | Value res) { |
3114 | if (isPoolExt) |
3115 | lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0); |
3116 | return rewriter |
3117 | .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType()) |
3118 | ->getResult(0); |
3119 | } |
3120 | |
3121 | /// Generate a vector implementation for: |
3122 | /// ``` |
3123 | /// Op def: ( n, w, c, kw) |
3124 | /// Iters: ({Par(), Par(), Par(), Red()}) |
3125 | /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}} |
3126 | /// ``` |
3127 | /// kw is always unrolled. |
  /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is > 1.
3130 | FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize, |
3131 | bool channelDimScalableFlag, |
3132 | bool flatten) { |
3133 | if (!valid) |
3134 | return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv" ); |
3135 | |
3136 | bool scalableChDim = false; |
3137 | bool useMasking = false; |
3138 | int64_t nSize, wSize, cSize, kwSize; |
3139 | // kernel{kw, c} |
3140 | bindShapeDims(rhsShapedType, kwSize, cSize); |
3141 | if (ShapedType::isDynamic(cSize)) { |
3142 | assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0" ); |
3143 | cSize = channelDimVecSize; |
3144 | // Scalable vectors are only used when both conditions are met: |
3145 | // 1. channel dim is dynamic |
3146 | // 2. channelDimScalableFlag is set |
3147 | scalableChDim = channelDimScalableFlag; |
3148 | useMasking = true; |
3149 | } |
3150 | |
3151 | assert(!(useMasking && flatten) && |
3152 | "Unsupported flattened conv with dynamic shapes" ); |
3153 | |
3154 | // out{n, w, c} |
3155 | bindShapeDims(resShapedType, nSize, wSize); |
3156 | |
3157 | vector::TransferWriteOp write; |
3158 | Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
3159 | |
3160 | // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1. |
3161 | // When strideW == 1, we can batch the contiguous loads and avoid |
3162 | // unrolling |
3163 | int64_t wSizeStep = strideW == 1 ? wSize : 1; |
3164 | |
3165 | Type lhsEltType = lhsShapedType.getElementType(); |
3166 | Type rhsEltType = rhsShapedType.getElementType(); |
3167 | Type resEltType = resShapedType.getElementType(); |
3168 | VectorType lhsType = VectorType::get( |
3169 | {nSize, |
3170 | // iw = ow * sw + kw * dw - 1 |
3171 | // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14) |
3172 | ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1, |
3173 | cSize}, |
3174 | lhsEltType, /*scalableDims=*/{false, false, scalableChDim}); |
3175 | VectorType rhsType = |
3176 | VectorType::get({kwSize, cSize}, rhsEltType, |
3177 | /*scalableDims=*/{false, scalableChDim}); |
3178 | VectorType resType = |
3179 | VectorType::get({nSize, wSize, cSize}, resEltType, |
3180 | /*scalableDims=*/{false, false, scalableChDim}); |
3181 | |
3182 | // Masks the input xfer Op along the channel dim, iff the corresponding |
3183 | // scalable flag is set. |
3184 | auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape, |
3185 | ArrayRef<bool> scalableDims, |
3186 | Operation *opToMask) { |
3187 | if (!useMasking) |
3188 | return opToMask; |
3189 | auto maskType = |
3190 | VectorType::get(maskShape, rewriter.getI1Type(), scalableDims); |
3191 | |
3192 | SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer( |
3193 | cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter); |
3194 | |
3195 | Value maskOp = |
3196 | rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims); |
3197 | |
3198 | return mlir::vector::maskOperation(rewriter, opToMask, maskOp); |
3199 | }; |
3200 | |
3201 | // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, |
3202 | // 0]. |
3203 | Value lhs = rewriter.create<vector::TransferReadOp>( |
3204 | loc, lhsType, lhsShaped, ValueRange{zero, zero, zero}); |
3205 | auto maybeMaskedLhs = maybeMaskXferOp( |
3206 | lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp()); |
3207 | |
3208 | // Read rhs slice of size {kw, c} @ [0, 0]. |
3209 | Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped, |
3210 | ValueRange{zero, zero}); |
3211 | auto maybeMaskedRhs = maybeMaskXferOp( |
3212 | rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp()); |
3213 | |
3214 | // Read res slice of size {n, w, c} @ [0, 0, 0]. |
3215 | Value res = rewriter.create<vector::TransferReadOp>( |
3216 | loc, resType, resShaped, ValueRange{zero, zero, zero}); |
3217 | auto maybeMaskedRes = maybeMaskXferOp( |
3218 | resType.getShape(), resType.getScalableDims(), res.getDefiningOp()); |
3219 | |
3220 | //===------------------------------------------------------------------===// |
3221 | // Begin vector-only rewrite part |
3222 | //===------------------------------------------------------------------===// |
3223 | // Unroll along kw and read slices of lhs and rhs. |
3224 | SmallVector<Value> lhsVals, rhsVals, resVals; |
3225 | auto inOutSliceSizes = SmallVector<int64_t>{nSize, wSizeStep, cSize}; |
3226 | auto inOutStrides = SmallVector<int64_t>{1, 1, 1}; |
3227 | |
3228 | // Extract lhs slice of size {n, wSizeStep, c} |
3229 | // @ [0, sw * w + dw * kw, 0]. |
3230 | for (int64_t kw = 0; kw < kwSize; ++kw) { |
3231 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
3232 | lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>( |
3233 | loc, maybeMaskedLhs->getResult(0), |
3234 | /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0}, |
3235 | inOutSliceSizes, inOutStrides)); |
3236 | } |
3237 | } |
3238 | // Extract rhs slice of size {c} @ [kw]. |
3239 | for (int64_t kw = 0; kw < kwSize; ++kw) { |
3240 | rhsVals.push_back(rewriter.create<vector::ExtractOp>( |
3241 | loc, maybeMaskedRhs->getResult(0), |
3242 | /*offsets=*/ArrayRef<int64_t>{kw})); |
3243 | } |
3244 | // Extract res slice: {n, wSizeStep, c} @ [0, w, 0]. |
3245 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
3246 | resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>( |
3247 | loc, maybeMaskedRes->getResult(0), |
3248 | /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes, |
3249 | inOutStrides)); |
3250 | } |
3251 | |
3252 | auto linearIndex = [&](int64_t kw, int64_t w) { |
3253 | return kw * (wSize / wSizeStep) + w; |
3254 | }; |
3255 | |
3256 | // Note - the scalable flags are ignored as flattening combined with |
3257 | // scalable vectorization is not supported. |
3258 | auto inOutFlattenSliceSizes = |
3259 | SmallVector<int64_t>{nSize, wSizeStep * cSize}; |
3260 | auto lhsTypeAfterFlattening = |
3261 | VectorType::get(inOutFlattenSliceSizes, lhsEltType); |
3262 | auto resTypeAfterFlattening = |
3263 | VectorType::get(inOutFlattenSliceSizes, resEltType); |
3264 | |
3265 | // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c} |
3266 | for (int64_t kw = 0; kw < kwSize; ++kw) { |
3267 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
3268 | Value lhsVal = lhsVals[linearIndex(kw, w)]; |
3269 | Value resVal = resVals[w]; |
3270 | if (flatten) { |
3271 | // Flatten the input and output vectors (collapse the channel |
3272 | // dimension) |
3273 | lhsVal = rewriter.create<vector::ShapeCastOp>( |
3274 | loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]); |
3275 | resVal = rewriter.create<vector::ShapeCastOp>( |
3276 | loc, resTypeAfterFlattening, resVals[w]); |
3277 | } |
3278 | resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal, |
3279 | rhsVals[kw], resVal, flatten); |
3280 | if (flatten) { |
3281 | // Un-flatten the output vector (restore the channel dimension) |
3282 | resVals[w] = rewriter.create<vector::ShapeCastOp>( |
3283 | loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]); |
3284 | } |
3285 | } |
3286 | } |
3287 | |
    // It's possible we failed to create the FMA.
    if (!llvm::all_of(resVals, [](Value v) { return v; })) {
3290 | // Manually revert (in reverse order) to avoid leaving a bad IR state. |
3291 | for (auto &collection : |
3292 | {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}}) |
3293 | for (Value v : collection) |
3294 | rewriter.eraseOp(v.getDefiningOp()); |
3295 | return rewriter.notifyMatchFailure(op, "failed to create FMA" ); |
3296 | } |
3297 | |
3298 | // Write back res slice: {n, wSizeStep, c} @ [0, w, 0]. |
3299 | // This does not depend on kw. |
3300 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
3301 | maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>( |
3302 | loc, resVals[w], maybeMaskedRes->getResult(0), |
3303 | /*offsets=*/ArrayRef<int64_t>{0, w, 0}, |
3304 | /*strides=*/ArrayRef<int64_t>{1, 1, 1}); |
3305 | } |
3306 | //===------------------------------------------------------------------===// |
3307 | // End vector-only rewrite part |
3308 | //===------------------------------------------------------------------===// |
3309 | |
3310 | // Write back res slice of size {n, w, c} @ [0, 0, 0]. |
3311 | Operation *resOut = rewriter.create<vector::TransferWriteOp>( |
3312 | loc, maybeMaskedRes->getResult(0), resShaped, |
3313 | ValueRange{zero, zero, zero}); |
3314 | return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(), |
3315 | resOut); |
3316 | } |
3317 | |
3318 | /// Lower: |
3319 | /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false) |
3320 | /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true) |
3321 | /// to MulAcc. |
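  ///
  /// In the flattened case the filter is first replicated along w with a
  /// shuffle before being broadcast; e.g. (an illustrative worked example)
  /// with c = 3 and w = 2 the shuffle mask is [0, 1, 2, 0, 1, 2], turning a
  /// filter [f0, f1, f2] into [f0, f1, f2, f0, f1, f2].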
3322 | Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc, |
3323 | Value lhs, Value rhs, Value res, |
3324 | bool flatten) { |
3325 | auto rhsTy = cast<ShapedType>(rhs.getType()); |
3326 | auto resTy = cast<ShapedType>(res.getType()); |
3327 | |
3328 | // TODO(suderman): Change this to use a vector.ima intrinsic. |
    lhs = promote(rewriter, loc, lhs, resTy);
3330 | |
3331 | if (flatten) { |
      // NOTE: The following logic won't work for scalable vectors. For this
3333 | // reason, "flattening" is not supported when shapes are dynamic (this |
3334 | // should be captured by one of the pre-conditions). |
3335 | |
3336 | // There are two options for handling the filter: |
3337 | // * shape_cast(broadcast(filter)) |
3338 | // * broadcast(shuffle(filter)) |
3339 | // Opt for the option without shape_cast to simplify the codegen. |
3340 | auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0]; |
3341 | auto resSize = cast<VectorType>(res.getType()).getShape()[1]; |
3342 | |
      SmallVector<int64_t, 16> indices;
      for (int i = 0; i < resSize / rhsSize; ++i) {
        for (int j = 0; j < rhsSize; ++j)
          indices.push_back(j);
      }

      rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
3350 | } |
3351 | // Broadcast the filter to match the output vector |
3352 | rhs = rewriter.create<vector::BroadcastOp>( |
3353 | loc, resTy.clone(rhsTy.getElementType()), rhs); |
3354 | |
    rhs = promote(rewriter, loc, rhs, resTy);
3356 | |
3357 | if (!lhs || !rhs) |
3358 | return nullptr; |
3359 | |
3360 | if (isa<FloatType>(resTy.getElementType())) |
3361 | return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res); |
3362 | |
3363 | auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs); |
3364 | return rewriter.create<arith::AddIOp>(loc, mul, res); |
3365 | } |
3366 | |
3367 | /// Entry point for non-channeled convolution: |
3368 | /// {{w + kw}, {kw}, {w}} |
3369 | FailureOr<Operation *> generateNonChanneledConv() { |
3370 | AffineExpr w, kw; |
3371 | bindDims(ctx, w, kw); |
3372 | if (!iters({Par(), Red()})) |
3373 | return rewriter.notifyMatchFailure(op, |
3374 | "failed to match conv::W 1-par 1-red" ); |
3375 | |
3376 | // No transposition needed. |
3377 | if (layout({/*lhsIndex*/ {w + kw}, |
3378 | /*rhsIndex*/ {kw}, |
3379 | /*resIndex*/ {w}})) |
      return conv(Conv1DOpOrder::W);
3381 | |
3382 | return rewriter.notifyMatchFailure(op, "not a conv::W layout" ); |
3383 | } |
3384 | |
3385 | /// Entry point that transposes into the common form: |
3386 | /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}} |
3387 | FailureOr<Operation *> generateNwcConv() { |
3388 | AffineExpr n, w, f, kw, c; |
3389 | bindDims(ctx, n, w, f, kw, c); |
3390 | if (!iters({Par(), Par(), Par(), Red(), Red()})) |
3391 | return rewriter.notifyMatchFailure( |
3392 | op, "failed to match conv::Nwc 3-par 2-red" ); |
3393 | |
3394 | // No transposition needed. |
3395 | if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c}, |
3396 | /*rhsIndex*/ {kw, c, f}, |
3397 | /*resIndex*/ {n, w, f}})) |
      return conv(Conv1DOpOrder::Nwc);
3399 | |
3400 | return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout" ); |
3401 | } |
3402 | |
3403 | /// Entry point that transposes into the common form: |
3404 | /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}} |
3405 | FailureOr<Operation *> generateNcwConv() { |
3406 | AffineExpr n, w, f, kw, c; |
3407 | bindDims(ctx, n, f, w, c, kw); |
3408 | if (!iters({Par(), Par(), Par(), Red(), Red()})) |
3409 | return rewriter.notifyMatchFailure( |
3410 | op, "failed to match conv::Ncw 3-par 2-red" ); |
3411 | |
3412 | if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw}, |
3413 | /*rhsIndex*/ {f, c, kw}, |
3414 | /*resIndex*/ {n, f, w}})) |
      return conv(Conv1DOpOrder::Ncw);
3416 | |
3417 | return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout" ); |
3418 | } |
3419 | |
3420 | /// Entry point that transposes into the common form: |
3421 | /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling |
3422 | FailureOr<Operation *> generateNwcPooling() { |
3423 | AffineExpr n, w, c, kw; |
3424 | bindDims(ctx, n, w, c, kw); |
3425 | if (!iters({Par(), Par(), Par(), Red()})) |
3426 | return rewriter.notifyMatchFailure(op, |
3427 | "failed to match pooling 3-par 1-red" ); |
3428 | |
3429 | // No transposition needed. |
3430 | if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c}, |
3431 | /*rhsIndex*/ {kw}, |
3432 | /*resIndex*/ {n, w, c}})) |
      return conv(Conv1DOpOrder::Nwc);
3434 | |
3435 | return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout" ); |
3436 | } |
3437 | |
3438 | /// Entry point that transposes into the common form: |
3439 | /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling |
3440 | FailureOr<Operation *> generateNcwPooling() { |
3441 | AffineExpr n, w, c, kw; |
3442 | bindDims(ctx, n, c, w, kw); |
3443 | if (!iters({Par(), Par(), Par(), Red()})) |
3444 | return rewriter.notifyMatchFailure(op, |
3445 | "failed to match pooling 3-par 1-red" ); |
3446 | |
3447 | if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw}, |
3448 | /*rhsIndex*/ {kw}, |
3449 | /*resIndex*/ {n, c, w}})) |
      return conv(Conv1DOpOrder::Ncw);
3451 | |
3452 | return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout" ); |
3453 | } |
3454 | |
3455 | /// Entry point that transposes into the common form: |
3456 | /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}} |
3457 | FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0, |
3458 | bool vecChDimScalableFlag = false, |
3459 | bool flatten = false) { |
3460 | AffineExpr n, w, c, kw; |
3461 | bindDims(ctx, n, w, c, kw); |
3462 | if (!iters({Par(), Par(), Par(), Red()})) |
3463 | return rewriter.notifyMatchFailure( |
3464 | op, "failed to match depthwise::Nwc conv 3-par 1-red" ); |
3465 | |
3466 | // No transposition needed. |
3467 | if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c}, |
3468 | /*rhsIndex*/ {kw, c}, |
3469 | /*resIndex*/ {n, w, c}})) |
      return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
3471 | |
3472 | return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout" ); |
3473 | } |
3474 | |
3475 | private: |
3476 | enum OperKind { Conv, Pool }; |
3477 | bool valid = false; |
3478 | OperKind oper = Conv; |
3479 | StringAttr redOp; |
3480 | StringAttr poolExtOp; |
3481 | bool isPoolExt = false; |
3482 | int strideW, dilationW; |
3483 | Value lhsShaped, rhsShaped, resShaped; |
3484 | ShapedType lhsShapedType, rhsShapedType, resShapedType; |
3485 | |
3486 | // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops. |
3487 | // Returns true iff it is a valid conv/pooling op. |
3488 | // If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction |
  // + yield) and rhs is not used) then it is the body of a pooling op.
3490 | // If conv, check for single `mul` predecessor. The `mul` operands must be |
3491 | // block arguments or extension of block arguments. |
3492 | // Otherwise, check for one or zero `ext` predecessor. The `ext` operands |
3493 | // must be block arguments or extension of block arguments. |
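  //
  // For instance (illustrative region bodies):
  //   conv:    %m = arith.mulf %in, %filter : f32
  //            %r = arith.addf %out, %m : f32
  //            linalg.yield %r : f32
  //     -> the reduction (addf) has one block-argument operand and a `mul`
  //        feeder whose operands are block arguments.
  //   pooling: %r = arith.maximumf %out, %in : f32
  //            linalg.yield %r : f32
  //     -> the reduction has two block-argument operands.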
3494 | bool setOperKind(Operation *reduceOp) { |
3495 | int numBlockArguments = |
        llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
3497 | switch (numBlockArguments) { |
3498 | case 1: { |
      // Will be convolution if feeder is a MulOp.
      // Otherwise, it can still be a pooling op.
      auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                         llvm::IsaPred<BlockArgument>);
      Operation *feedOp = (*feedValIt).getDefiningOp();
      if (isCastOfBlockArgument(feedOp)) {
3505 | oper = Pool; |
3506 | isPoolExt = true; |
3507 | poolExtOp = feedOp->getName().getIdentifier(); |
3508 | } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) && |
3509 | llvm::all_of(feedOp->getOperands(), [](Value v) { |
3510 | if (isa<BlockArgument>(v)) |
3511 | return true; |
3512 | if (Operation *op = v.getDefiningOp()) |
3513 | return isCastOfBlockArgument(op); |
3514 | return false; |
3515 | }))) { |
3516 | return false; |
3517 | } |
3518 | return true; |
3519 | } |
3520 | case 2: |
3521 | // Must be pooling |
3522 | oper = Pool; |
3523 | isPoolExt = false; |
3524 | return true; |
3525 | default: |
3526 | return false; |
3527 | } |
3528 | } |
3529 | }; |
3530 | } // namespace |
3531 | |
3532 | /// Helper function to vectorize a LinalgOp with convolution semantics. |
3533 | // TODO: extend the generic vectorization to support windows and drop this. |
3534 | static FailureOr<Operation *> vectorizeConvolution( |
3535 | RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes, |
3536 | ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) { |
3537 | // The ConvolutionOpInterface gives us guarantees of existence for |
3538 | // strides/dilations. However, we do not need to rely on those, we can |
3539 | // simply use them if present, otherwise use the default and let the generic |
3540 | // conv. matcher in the ConvGenerator succeed or fail. |
3541 | auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides" ); |
3542 | auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations" ); |
3543 | auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1; |
3544 | auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1; |
3545 | Conv1DGenerator e(rewriter, op, stride, dilation); |
3546 | auto res = e.generateNonChanneledConv(); |
3547 | if (succeeded(res)) |
3548 | return res; |
3549 | res = e.generateNwcConv(); |
3550 | if (succeeded(res)) |
3551 | return res; |
3552 | res = e.generateNcwConv(); |
3553 | if (succeeded(res)) |
3554 | return res; |
3555 | res = e.generateNwcPooling(); |
3556 | if (succeeded(res)) |
3557 | return res; |
3558 | res = e.generateNcwPooling(); |
3559 | if (succeeded(res)) |
3560 | return res; |
3561 | |
3562 | // Only depthwise 1D NWC convs are left - these can be vectorized using masks |
3563 | // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e. |
3564 | // masked/scalable) is the channel dim (i.e. the trailing dim). |
3565 | uint64_t vecChDimSize = ShapedType::kDynamic; |
3566 | bool vecChDimScalableFlag = false; |
3567 | if (!inputVecSizes.empty()) { |
3568 | // Only use the input vector size corresponding to the channel dim. Other |
3569 | // vector dims will be inferred from the Ops. |
3570 | assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) || |
3571 | isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) && |
3572 | "Not a 1D depthwise conv!" ); |
3573 | size_t chDimIdx = |
3574 | TypeSwitch<Operation *, size_t>(op) |
3575 | .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; }) |
3576 | .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; }); |
3577 | |
3578 | vecChDimSize = inputVecSizes[chDimIdx]; |
3579 | vecChDimScalableFlag = inputScalableVecDims[chDimIdx]; |
3580 | } |
  return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
                               flatten1DDepthwiseConv);
3583 | } |
3584 | |
3585 | struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> { |
3586 | using OpInterfaceRewritePattern::OpInterfaceRewritePattern; |
3587 | |
3588 | LogicalResult matchAndRewrite(LinalgOp op, |
3589 | PatternRewriter &rewriter) const override { |
3590 | FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op); |
    if (failed(resultOrFail))
3592 | return failure(); |
3593 | Operation *newOp = *resultOrFail; |
3594 | if (newOp->getNumResults() == 0) { |
      rewriter.eraseOp(op.getOperation());
3596 | return success(); |
3597 | } |
3598 | assert(newOp->getNumResults() == 1 && "expected single result" ); |
    rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
3600 | return success(); |
3601 | } |
3602 | }; |
3603 | |
3604 | void mlir::linalg::populateConvolutionVectorizationPatterns( |
3605 | RewritePatternSet &patterns, PatternBenefit benefit) { |
  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
3607 | } |
3608 | |