1 | //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the linalg dialect Vectorization transformations. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | #include "mlir/Dialect/Affine/Utils.h" |
13 | |
14 | #include "mlir/Analysis/SliceAnalysis.h" |
15 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
17 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
18 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
19 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
20 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
21 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
22 | #include "mlir/Dialect/Tensor/Utils/Utils.h" |
23 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
24 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
25 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
26 | #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" |
27 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h" |
28 | #include "mlir/IR/AffineExpr.h" |
29 | #include "mlir/IR/Builders.h" |
30 | #include "mlir/IR/BuiltinTypeInterfaces.h" |
31 | #include "mlir/IR/BuiltinTypes.h" |
32 | #include "mlir/IR/OpDefinition.h" |
33 | #include "mlir/IR/PatternMatch.h" |
34 | #include "mlir/Support/LLVM.h" |
35 | #include "mlir/Transforms/RegionUtils.h" |
36 | #include "llvm/ADT/STLExtras.h" |
37 | #include "llvm/ADT/Sequence.h" |
38 | #include "llvm/ADT/SmallVector.h" |
39 | #include "llvm/ADT/TypeSwitch.h" |
40 | #include "llvm/ADT/iterator_range.h" |
41 | #include "llvm/Support/Debug.h" |
42 | #include "llvm/Support/MathExtras.h" |
43 | #include "llvm/Support/raw_ostream.h" |
44 | #include <optional> |
45 | #include <type_traits> |
46 | |
47 | using namespace mlir; |
48 | using namespace mlir::linalg; |
49 | |
50 | #define DEBUG_TYPE "linalg-vectorization" |
51 | |
52 | #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") |
53 | #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") |
54 | |
55 | /// Try to vectorize `convOp` as a convolution. |
56 | static FailureOr<Operation *> |
57 | vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, |
58 | ArrayRef<int64_t> inputVecSizes = {}, |
59 | ArrayRef<bool> inputVecScalableFlags = {}, |
60 | bool flatten1DDepthwiseConv = false); |
61 | |
62 | /// Vectorize tensor::InsertSliceOp with: |
63 | /// * vector::TransferReadOp + vector::TransferWriteOp |
64 | /// The vector sizes are either: |
65 | /// * user-provided in `inputVectorSizes`, or |
66 | /// * inferred from the static dims in the input and output tensors. |
67 | /// Bails out if: |
68 | /// * vector sizes are not user-provided, and |
69 | /// * at least one dim is dynamic (in both the input and output tensors). |
70 | /// |
71 | /// Before: |
72 | /// !t_in_type = tensor<1x2x3xf32> |
73 | /// !t_out_type = tensor<9x8x7x1x2x3xf32> |
74 | /// !v_type = vector<1x2x3xf32> |
75 | /// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type |
76 | /// into !t_out_type |
77 | /// After: |
78 | /// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type |
79 | /// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type |
80 | static LogicalResult |
81 | vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp, |
82 | ArrayRef<int64_t> inputVectorSizes, |
83 | SmallVectorImpl<Value> &newResults); |
84 | |
85 | /// Returns the effective Pad value for the input op, provided it's a scalar. |
86 | /// |
87 | /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If |
88 | /// this Op performs padding, retrieve the padding value provided that it's |
89 | /// a scalar and static/fixed for all the padded values. Returns an empty value |
90 | /// otherwise. |
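 | /// |
 | /// For illustration (hypothetical IR, names invented here), for a pad op like: |
 | ///   %padded = tensor.pad %src low[0, 1] high[1, 0] { |
 | ///   ^bb0(%i: index, %j: index): |
 | ///     tensor.yield %cst : f32 |
 | ///   } : tensor<4x4xf32> to tensor<5x5xf32> |
 | /// every padded element takes the value %cst (defined outside the region), so |
 | /// %cst would be returned as the effective pad value. |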
91 | static Value getStaticPadVal(Operation *op); |
92 | |
93 | /// Return the unique instance of OpType in `block` if it is indeed unique. |
94 | /// Returns null if zero or more than one instance exists. |
95 | template <typename OpType> |
96 | static OpType getSingleOpOfType(Block &block) { |
97 | OpType res; |
98 | block.walk([&](OpType op) { |
99 | if (res) { |
100 | res = nullptr; |
101 | return WalkResult::interrupt(); |
102 | } |
103 | res = op; |
104 | return WalkResult::advance(); |
105 | }); |
106 | return res; |
107 | } |
108 | |
109 | /// Helper function to extract the input slices after filter is unrolled along |
110 | /// kw. |
111 | static SmallVector<Value> |
112 | extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, |
113 | int64_t nSize, int64_t wSize, int64_t cSize, |
114 | int64_t kwSize, int strideW, int dilationW, |
115 | int64_t wSizeStep, bool isSingleChanneled) { |
116 | SmallVector<Value> result; |
117 | if (isSingleChanneled) { |
118 | // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled |
119 | // convolution. |
120 | SmallVector<int64_t> sizes = {wSizeStep}; |
121 | SmallVector<int64_t> strides = {1}; |
122 | for (int64_t kw = 0; kw < kwSize; ++kw) { |
123 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
124 | result.push_back(rewriter.create<vector::ExtractStridedSliceOp>( |
125 | loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides)); |
126 | } |
127 | } |
128 | } else { |
129 | // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0] |
130 | // for channeled convolution. |
131 | SmallVector<int64_t> sizes = {nSize, wSizeStep, cSize}; |
132 | SmallVector<int64_t> strides = {1, 1, 1}; |
133 | for (int64_t kw = 0; kw < kwSize; ++kw) { |
134 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
135 | result.push_back(rewriter.create<vector::ExtractStridedSliceOp>( |
136 | loc, input, |
137 | /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0}, |
138 | sizes, strides)); |
139 | } |
140 | } |
141 | } |
142 | return result; |
143 | } |
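 | // For illustration, with hypothetical sizes nSize = 1, wSize = 4, cSize = 3, |
 | // kwSize = 2, strideW = 1, dilationW = 1 and wSizeStep = 4, the channeled |
 | // branch above would emit slices such as: |
 | //   vector.extract_strided_slice %input |
 | //     {offsets = [0, 1, 0], sizes = [1, 4, 3], strides = [1, 1, 1]} |
 | //     : vector<1x5x3xf32> to vector<1x4x3xf32> |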
144 | |
145 | /// Helper function to extract the filter slices after filter is unrolled along |
146 | /// kw. |
147 | static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter, |
148 | Location loc, Value filter, |
149 | int64_t kwSize) { |
150 | SmallVector<Value> result; |
151 | // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for |
152 | // non-channeled convolution] @ [kw]. |
153 | for (int64_t kw = 0; kw < kwSize; ++kw) { |
154 | result.push_back(rewriter.create<vector::ExtractOp>( |
155 | loc, filter, /*offsets=*/ArrayRef<int64_t>{kw})); |
156 | } |
157 | return result; |
158 | } |
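 | // For illustration, with a hypothetical filter vector of type |
 | // vector<2x3x8xf32> (kw = 2, c = 3, f = 8), the loop above would emit: |
 | //   vector.extract %filter[0] : vector<3x8xf32> from vector<2x3x8xf32> |
 | //   vector.extract %filter[1] : vector<3x8xf32> from vector<2x3x8xf32> |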
159 | |
160 | /// Helper function to extract the result slices after filter is unrolled along |
161 | /// kw. |
162 | static SmallVector<Value> |
163 | extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, |
164 | int64_t nSize, int64_t wSize, int64_t fSize, |
165 | int64_t wSizeStep, bool isSingleChanneled) { |
166 | SmallVector<Value> result; |
167 | if (isSingleChanneled) { |
168 | // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution. |
169 | SmallVector<int64_t> sizes = {wSizeStep}; |
170 | SmallVector<int64_t> strides = {1}; |
171 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
172 | result.push_back(rewriter.create<vector::ExtractStridedSliceOp>( |
173 | loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides)); |
174 | } |
175 | } else { |
176 | // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled |
177 | // convolution. |
178 | SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize}; |
179 | SmallVector<int64_t> strides = {1, 1, 1}; |
180 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
181 | result.push_back(rewriter.create<vector::ExtractStridedSliceOp>( |
182 | loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides)); |
183 | } |
184 | } |
185 | return result; |
186 | } |
187 | |
188 | /// Helper function to insert the computed result slices. |
189 | static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, |
190 | Value res, int64_t wSize, int64_t wSizeStep, |
191 | SmallVectorImpl<Value> &resVals, |
192 | bool isSingleChanneled) { |
193 | |
194 | if (isSingleChanneled) { |
195 | // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution. |
196 | // This does not depend on kw. |
197 | SmallVector<int64_t> strides = {1}; |
198 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
199 | res = rewriter.create<vector::InsertStridedSliceOp>( |
200 | loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides); |
201 | } |
202 | } else { |
203 | // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled |
204 | // convolution. This does not depend on kw. |
205 | SmallVector<int64_t> strides = {1, 1, 1}; |
206 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
207 | res = rewriter.create<vector::InsertStridedSliceOp>( |
208 | loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, |
209 | strides); |
210 | } |
211 | } |
212 | return res; |
213 | } |
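 | // For illustration, with hypothetical values wSize = 8, wSizeStep = 4 and a |
 | // channeled result of type vector<1x8x16xf32>, the loop above would emit: |
 | //   vector.insert_strided_slice %slice, %res |
 | //     {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x4x16xf32> into vector<1x8x16xf32> |
 | //   vector.insert_strided_slice %slice, %res |
 | //     {offsets = [0, 4, 0], strides = [1, 1, 1]} : vector<1x4x16xf32> into vector<1x8x16xf32> |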
214 | |
215 | /// Contains the vectorization state and related methods used across the |
216 | /// vectorization process of a given operation. |
217 | struct VectorizationState { |
218 | VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {} |
219 | |
220 | /// Initializes the vectorization state, including the computation of the |
221 | /// canonical vector shape for vectorization. |
222 | LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, |
223 | ArrayRef<int64_t> inputVectorSizes, |
224 | ArrayRef<bool> inputScalableVecDims); |
225 | |
226 | /// Returns the canonical vector shape used to vectorize the iteration space. |
227 | ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; } |
228 | |
229 | /// Returns the vector dimensions that are scalable in the canonical vector |
230 | /// shape. |
231 | ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; } |
232 | |
233 | /// Returns a vector type of the provided `elementType` with the canonical |
234 | /// vector shape and the corresponding fixed/scalable dimensions bit. If |
235 | /// `dimPermutation` is provided, the canonical vector dimensions are permuted |
236 | /// accordingly. |
237 | VectorType getCanonicalVecType( |
238 | Type elementType, |
239 | std::optional<AffineMap> dimPermutation = std::nullopt) const { |
240 | SmallVector<int64_t> vectorShape; |
241 | SmallVector<bool> scalableDims; |
242 | if (dimPermutation.has_value()) { |
243 | vectorShape = |
244 | applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape); |
245 | scalableDims = |
246 | applyPermutationMap<bool>(*dimPermutation, scalableVecDims); |
247 | } else { |
248 | vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end()); |
249 | scalableDims.append(scalableVecDims.begin(), scalableVecDims.end()); |
250 | } |
251 | |
252 | return VectorType::get(vectorShape, elementType, scalableDims); |
253 | } |
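 | // For example (hypothetical values): with canonicalVecShape = [8, 16, 4], |
 | // scalableVecDims = [false, false, true] and a permutation map |
 | // (d0, d1, d2) -> (d2, d0), this returns vector<[4]x8xf32> for an f32 |
 | // element type. |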
254 | |
255 | /// Masks an operation with the canonical vector mask if the operation needs |
256 | /// masking. Returns the masked operation or the original operation if masking |
257 | /// is not needed. If provided, the canonical mask for this operation is |
258 | /// permuted using `maybeIndexingMap`. |
259 | Operation * |
260 | maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp, |
261 | std::optional<AffineMap> maybeIndexingMap = std::nullopt); |
262 | |
263 | private: |
264 | /// Initializes the iteration space static sizes using the Linalg op |
265 | /// information. This may become more complicated in the future. |
266 | void initIterSpaceStaticSizes(LinalgOp linalgOp) { |
267 | iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges()); |
268 | } |
269 | |
270 | /// Generates 'arith.constant' and 'tensor/memref.dim' operations for |
271 | /// all the static and dynamic dimensions of the iteration space to be |
272 | /// vectorized and store them in `iterSpaceValueSizes`. |
273 | LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter, |
274 | LinalgOp linalgOp); |
275 | |
276 | /// Create or retrieve an existing mask value to mask `opToMask` in the |
277 | /// canonical vector iteration space. If `maybeMaskingMap` is provided, the |
278 | /// mask is permuted using that permutation map. If a new mask is created, it |
279 | /// will be cached for future users. |
280 | Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask, |
281 | LinalgOp linalgOp, |
282 | std::optional<AffineMap> maybeMaskingMap); |
283 | |
284 | /// Check whether this permutation map can be used for masking. At the |
285 | /// moment we only make sure that there are no broadcast dimensions, but this |
286 | /// might change if indexing maps evolve. |
287 | bool isValidMaskingMap(AffineMap maskingMap) { |
288 | return maskingMap.getBroadcastDims().size() == 0; |
289 | } |
290 | |
291 | /// Turn the input indexing map into a valid masking map. |
292 | /// |
293 | /// The input indexing map may contain "zero" results, e.g.: |
294 | /// (d0, d1, d2, d3) -> (d2, d1, d0, 0) |
295 | /// Applying such maps to canonical vector shapes like this one: |
296 | /// (1, 16, 16, 4) |
297 | /// would yield an invalid vector shape like this: |
298 | /// (16, 16, 1, 0) |
299 | /// Instead, drop the broadcasting dims that make no sense for masking perm. |
300 | /// maps: |
301 | /// (d0, d1, d2, d3) -> (d2, d1, d0) |
302 | /// This way, the corresponding vector/mask type will be: |
303 | /// vector<16x16x1xty> |
304 | /// rather than this invalid Vector type: |
305 | /// vector<16x16x1x0xty> |
306 | AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) { |
307 | return indexingMap.dropZeroResults(); |
308 | } |
309 | |
310 | // Holds the compile-time static sizes of the iteration space to vectorize. |
311 | // Dynamic dimensions are represented using ShapedType::kDynamic. |
312 | SmallVector<int64_t> iterSpaceStaticSizes; |
313 | |
314 | /// Holds the value sizes of the iteration space to vectorize. Static |
315 | /// dimensions are represented by 'arith.constant' and dynamic |
316 | /// dimensions by 'tensor/memref.dim'. |
317 | SmallVector<Value> iterSpaceValueSizes; |
318 | |
319 | /// Holds the canonical vector shape used to vectorize the iteration space. |
320 | SmallVector<int64_t> canonicalVecShape; |
321 | |
322 | /// Holds the vector dimensions that are scalable in the canonical vector |
323 | /// shape. |
324 | SmallVector<bool> scalableVecDims; |
325 | |
326 | /// Holds the active masks for permutations of the canonical vector iteration |
327 | /// space. |
328 | DenseMap<AffineMap, Value> activeMaskCache; |
329 | |
330 | /// Global vectorization guard for the incoming rewriter. It's initialized |
331 | /// when the vectorization state is initialized. |
332 | OpBuilder::InsertionGuard rewriterGuard; |
333 | }; |
334 | |
335 | LogicalResult |
336 | VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter, |
337 | LinalgOp linalgOp) { |
338 | // TODO: Support 0-d vectors. |
339 | for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) { |
340 | if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) { |
341 | // Create constant index op for static dimensions. |
342 | iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>( |
343 | linalgOp.getLoc(), iterSpaceStaticSizes[vecDim])); |
344 | continue; |
345 | } |
346 | |
347 | // Find an operand defined on this dimension of the iteration space to |
348 | // extract the runtime dimension size. |
349 | Value operand; |
350 | unsigned operandDimPos; |
351 | if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand, |
352 | operandDimPos))) |
353 | return failure(); |
354 | |
355 | Value dynamicDim = linalgOp.hasPureTensorSemantics() |
356 | ? (Value)rewriter.create<tensor::DimOp>( |
357 | linalgOp.getLoc(), operand, operandDimPos) |
358 | : (Value)rewriter.create<memref::DimOp>( |
359 | linalgOp.getLoc(), operand, operandDimPos); |
360 | iterSpaceValueSizes.push_back(dynamicDim); |
361 | } |
362 | |
363 | return success(); |
364 | } |
365 | |
366 | /// Initializes the vectorization state, including the computation of the |
367 | /// canonical vector shape for vectorization. |
368 | // TODO: Move this to the constructor when we can remove the failure cases. |
369 | LogicalResult |
370 | VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp, |
371 | ArrayRef<int64_t> inputVectorSizes, |
372 | ArrayRef<bool> inputScalableVecDims) { |
373 | // Initialize the insertion point. |
374 | rewriter.setInsertionPoint(linalgOp); |
375 | |
376 | if (!inputVectorSizes.empty()) { |
377 | // Get the canonical vector shape from the input vector sizes provided. This |
378 | // path should be taken to vectorize code with dynamic shapes and when using |
379 | // vector sizes greater than the iteration space sizes. |
380 | canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end()); |
381 | scalableVecDims.append(inputScalableVecDims.begin(), |
382 | inputScalableVecDims.end()); |
383 | } else { |
384 | // Compute the canonical vector shape from the operation shape. If there are |
385 | // dynamic shapes, the operation won't be vectorized. We assume all the |
386 | // vector dimensions are fixed. |
387 | canonicalVecShape = linalgOp.getStaticLoopRanges(); |
388 | scalableVecDims.append(linalgOp.getNumLoops(), false); |
389 | } |
390 | |
391 | LDBG("Canonical vector shape: "); |
392 | LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs())); |
393 | LLVM_DEBUG(llvm::dbgs() << "\n"); |
394 | LDBG("Scalable vector dims: "); |
395 | LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs())); |
396 | LLVM_DEBUG(llvm::dbgs() << "\n"); |
397 | |
398 | if (ShapedType::isDynamicShape(canonicalVecShape)) |
399 | return failure(); |
400 | |
401 | // Initialize iteration space static sizes. |
402 | initIterSpaceStaticSizes(linalgOp); |
403 | |
404 | // Generate 'arith.constant' and 'tensor/memref.dim' operations for |
405 | // all the static and dynamic dimensions of the iteration space, needed to |
406 | // compute a mask during vectorization. |
407 | if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp))) |
408 | return failure(); |
409 | |
410 | return success(); |
411 | } |
412 | |
413 | /// Create or retrieve an existing mask value to mask `opToMask` in the |
414 | /// canonical vector iteration space. If `maybeMaskingMap` is provided, the |
415 | /// mask is permuted using that permutation map. If a new mask is created, it |
416 | /// will be cached for future users. |
417 | Value VectorizationState::getOrCreateMaskFor( |
418 | RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp, |
419 | std::optional<AffineMap> maybeMaskingMap) { |
420 | |
421 | assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) && |
422 | "Ill-formed masking map."); |
423 | |
424 | // No mask is needed if the operation is not maskable. |
425 | auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask); |
426 | if (!maskableOp) |
427 | return Value(); |
428 | |
429 | assert(!maskableOp.isMasked() && |
430 | "Masking an operation that is already masked"); |
431 | |
432 | // If no masking map was provided, use an identity map with the loop dims. |
433 | assert((!maybeMaskingMap || *maybeMaskingMap) && |
434 | "Unexpected null mask permutation map"); |
435 | AffineMap maskingMap = |
436 | maybeMaskingMap ? *maybeMaskingMap |
437 | : AffineMap::getMultiDimIdentityMap( |
438 | linalgOp.getNumLoops(), rewriter.getContext()); |
439 | |
440 | LDBG("Masking map: "<< maskingMap << "\n"); |
441 | |
442 | // Return the active mask for the masking map of this operation if it was |
443 | // already created. |
444 | auto activeMaskIt = activeMaskCache.find(maskingMap); |
445 | if (activeMaskIt != activeMaskCache.end()) { |
446 | Value mask = activeMaskIt->second; |
447 | LDBG("Reusing mask: " << mask << "\n"); |
448 | return mask; |
449 | } |
450 | |
451 | // Compute permuted projection of the iteration space to be masked and the |
452 | // corresponding mask shape. If the resulting iteration space dimensions are |
453 | // static and identical to the mask shape, masking is not needed for this |
454 | // operation. |
455 | // TODO: Improve this check. Only projected permutation indexing maps are |
456 | // supported. |
457 | SmallVector<int64_t> permutedStaticSizes = |
458 | applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes); |
459 | auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap); |
460 | auto maskShape = maskType.getShape(); |
461 | |
462 | LDBG("Mask shape: "); |
463 | LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs())); |
464 | LLVM_DEBUG(llvm::dbgs() << "\n"); |
465 | |
466 | if (permutedStaticSizes == maskShape) { |
467 | LDBG("Masking is not needed for masking map: "<< maskingMap << "\n"); |
468 | activeMaskCache[maskingMap] = Value(); |
469 | return Value(); |
470 | } |
471 | |
472 | // Permute the iteration space value sizes to compute the mask upper bounds. |
473 | SmallVector<Value> upperBounds = |
474 | applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes)); |
475 | assert(!maskShape.empty() && !upperBounds.empty() && |
476 | "Masked 0-d vectors are not supported yet"); |
477 | |
478 | // Create the mask based on the dimension values. |
479 | Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(), |
480 | maskType, upperBounds); |
481 | LDBG("Creating new mask: "<< mask << "\n"); |
482 | activeMaskCache[maskingMap] = mask; |
483 | return mask; |
484 | } |
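 | // For illustration (hypothetical shapes): for a dynamically shaped op with a |
 | // canonical mask shape of 8x16 and runtime sizes %d0 and %d1, the code above |
 | // would create: |
 | //   %mask = vector.create_mask %d0, %d1 : vector<8x16xi1> |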
485 | |
486 | Operation * |
487 | VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask, |
488 | LinalgOp linalgOp, |
489 | std::optional<AffineMap> maybeIndexingMap) { |
490 | LDBG("Trying to mask: "<< *opToMask << "\n"); |
491 | |
492 | std::optional<AffineMap> maybeMaskingMap = std::nullopt; |
493 | if (maybeIndexingMap) |
494 | maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap); |
495 | |
496 | // Create or retrieve mask for this operation. |
497 | Value mask = |
498 | getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap); |
499 | |
500 | if (!mask) { |
501 | LDBG("No mask required\n"); |
502 | return opToMask; |
503 | } |
504 | |
505 | // Wrap the operation with a new `vector.mask` and update D-U chain. |
506 | assert(opToMask && "Expected a valid operation to mask"); |
507 | auto maskOp = cast<vector::MaskOp>( |
508 | mlir::vector::maskOperation(rewriter, opToMask, mask)); |
509 | Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back(); |
510 | |
511 | for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults())) |
512 | rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx), |
513 | maskOpTerminator); |
514 | |
515 | LDBG("Masked operation: "<< *maskOp << "\n"); |
516 | return maskOp; |
517 | } |
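 | // For illustration (hypothetical operands): masking a transfer_read with the |
 | // mask created above would produce IR along the lines of: |
 | //   %0 = vector.mask %mask { |
 | //     vector.transfer_read %src[%c0, %c0], %pad : tensor<?x?xf32>, vector<8x16xf32> |
 | //   } : vector<8x16xi1> -> vector<8x16xf32> |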
518 | |
519 | /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a |
520 | /// projectedPermutation, compress the unused dimensions to serve as a |
521 | /// permutation_map for a vector transfer operation. |
522 | /// For example, given a linalg op such as: |
523 | /// |
524 | /// ``` |
525 | /// %0 = linalg.generic { |
526 | /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>, |
527 | /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)> |
528 | /// } |
529 | /// ins(%0 : tensor<2x3x4xf32>) |
530 | /// outs(%1 : tensor<5x6xf32>) |
531 | /// ``` |
532 | /// |
533 | /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine |
534 | /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second |
535 | /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`. |
536 | static AffineMap reindexIndexingMap(AffineMap map) { |
537 | assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) && |
538 | "expected projected permutation"); |
539 | auto res = compressUnusedDims(map); |
540 | assert(res.getNumDims() == |
541 | (res.getNumResults() - res.getNumOfZeroResults()) && |
542 | "expected reindexed map with same number of dims and results"); |
543 | return res; |
544 | } |
545 | |
546 | /// Helper enum to represent conv1d input traversal order. |
547 | enum class Conv1DOpOrder { |
548 | W, // Corresponds to non-channeled 1D convolution operation. |
549 | Ncw, // Corresponds to operation that traverses the input in (n, c, w) order. |
550 | Nwc // Corresponds to operation that traverses the input in (n, w, c) order. |
551 | }; |
552 | |
553 | /// Helper data structure to represent the result of vectorization. |
554 | /// In certain specific cases, like terminators, we do not want to propagate/replace the results. |
555 | enum VectorizationStatus { |
556 | /// Op failed to vectorize. |
557 | Failure = 0, |
558 | /// Op vectorized and custom function took care of replacement logic |
559 | NoReplace, |
560 | /// Op vectorized into a new Op whose results will replace original Op's |
561 | /// results. |
562 | NewOp |
563 | // TODO: support values if Op vectorized to Many-Ops whose results we need to |
564 | // aggregate for replacement. |
565 | }; |
566 | struct VectorizationResult { |
567 | /// Return status from vectorizing the current op. |
568 | enum VectorizationStatus status = VectorizationStatus::Failure; |
569 | /// New vectorized operation to replace the current op. |
570 | /// Replacement behavior is specified by `status`. |
571 | Operation *newOp; |
572 | }; |
573 | |
574 | std::optional<vector::CombiningKind> |
575 | mlir::linalg::getCombinerOpKind(Operation *combinerOp) { |
576 | using ::mlir::vector::CombiningKind; |
577 | |
578 | if (!combinerOp) |
579 | return std::nullopt; |
580 | return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp) |
581 | .Case<arith::AddIOp, arith::AddFOp>( |
582 | [&](auto op) { return CombiningKind::ADD; }) |
583 | .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; }) |
584 | .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; }) |
585 | .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; }) |
586 | .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; }) |
587 | .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; }) |
588 | .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; }) |
589 | .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; }) |
590 | .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; }) |
591 | .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; }) |
592 | .Case<arith::MulIOp, arith::MulFOp>( |
593 | [&](auto op) { return CombiningKind::MUL; }) |
594 | .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; }) |
595 | .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; }) |
596 | .Default([&](auto op) { return std::nullopt; }); |
597 | } |
598 | |
599 | /// Check whether `outputOperand` is a reduction with a single combiner |
600 | /// operation. Return the combiner operation of the reduction. Return |
601 | /// nullptr otherwise. Multiple reduction operations would impose an |
602 | /// ordering between reduction dimensions and are currently unsupported in |
603 | /// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) != |
604 | /// max(min(X)) |
605 | // TODO: use in LinalgOp verification, there is a circular dependency atm. |
606 | static Operation *matchLinalgReduction(OpOperand *outputOperand) { |
607 | auto linalgOp = cast<LinalgOp>(outputOperand->getOwner()); |
608 | unsigned outputPos = |
609 | outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs(); |
610 | // Only single combiner operations are supported for now. |
611 | SmallVector<Operation *, 4> combinerOps; |
612 | if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) || |
613 | combinerOps.size() != 1) |
614 | return nullptr; |
615 | |
616 | // Return the combiner operation. |
617 | return combinerOps[0]; |
618 | } |
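 | // For illustration, a hypothetical sum reduction that this matches: |
 | //   linalg.generic {indexing_maps = [...], |
 | //                   iterator_types = ["parallel", "reduction"]} |
 | //       ins(%in : tensor<8x16xf32>) outs(%acc : tensor<8xf32>) { |
 | //   ^bb0(%a: f32, %b: f32): |
 | //     %s = arith.addf %a, %b : f32 |
 | //     linalg.yield %s : f32 |
 | //   } -> tensor<8xf32> |
 | // Here the single combiner returned for the init operand is the arith.addf. |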
619 | |
620 | /// Broadcast `value` to a vector of `shape` if possible. Return value |
621 | /// otherwise. |
622 | static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) { |
623 | auto dstVecType = dyn_cast<VectorType>(dstType); |
624 | // If no shape to broadcast to, just return `value`. |
625 | if (dstVecType.getRank() == 0) |
626 | return value; |
627 | if (vector::isBroadcastableTo(value.getType(), dstVecType) != |
628 | vector::BroadcastableToResult::Success) |
629 | return value; |
630 | Location loc = b.getInsertionPoint()->getLoc(); |
631 | return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value); |
632 | } |
633 | |
634 | /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This |
635 | /// assumes that `reductionOp` has two operands and one of them is the reduction |
636 | /// initial value. |
637 | // Note: this is a true builder that notifies the OpBuilder listener. |
638 | // TODO: Consider moving as a static helper on the ReduceOp. |
639 | static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, |
640 | Value valueToReduce, Value acc, |
641 | ArrayRef<bool> dimsToMask) { |
642 | auto maybeKind = getCombinerOpKind(reduceOp); |
643 | assert(maybeKind && "Failed precondition: could not get reduction kind"); |
644 | return b.create<vector::MultiDimReductionOp>( |
645 | reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind); |
646 | } |
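 | // For illustration (hypothetical types): reducing dimension 1 of a |
 | // vector<4x8xf32> with an addf combiner produces: |
 | //   %red = vector.multi_reduction <add>, %value, %acc [1] |
 | //     : vector<4x8xf32> to vector<4xf32> |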
647 | |
648 | static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) { |
649 | return llvm::to_vector( |
650 | llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator)); |
651 | } |
652 | |
653 | /// Check if `op` is a linalg.reduce or a linalg.generic that has at least one |
654 | /// reduction iterator. |
655 | static bool hasReductionIterator(LinalgOp &op) { |
656 | return isa<linalg::ReduceOp>(op) || |
657 | (isa<linalg::GenericOp>(op) && |
658 | llvm::any_of(op.getIteratorTypesArray(), isReductionIterator)); |
659 | } |
660 | |
661 | /// Build a vector.transfer_write of `value` into `outputOperand` at indices set |
662 | /// to all `0`; where `outputOperand` is an output operand of the LinalgOp |
663 | /// currently being vectorized. If `dest` has null rank, build a memref.store. |
664 | /// Return the produced value or null if no value is produced. |
665 | // Note: this is a true builder that notifies the OpBuilder listener. |
666 | // TODO: Consider moving as a static helper on the ReduceOp. |
667 | static Value buildVectorWrite(RewriterBase &rewriter, Value value, |
668 | OpOperand *outputOperand, |
669 | VectorizationState &state) { |
670 | Location loc = value.getLoc(); |
671 | auto linalgOp = cast<LinalgOp>(outputOperand->getOwner()); |
672 | AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand); |
673 | |
674 | // Compute the vector type of the value to store. This type should be an |
675 | // identity or projection of the canonical vector type without any permutation |
676 | // applied, given that any permutation in a transfer write happens as part of |
677 | // the write itself. |
678 | AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap( |
679 | opOperandMap.getContext(), opOperandMap.getNumInputs(), |
680 | [&](AffineDimExpr dimExpr) -> bool { |
681 | return llvm::is_contained(opOperandMap.getResults(), dimExpr); |
682 | }); |
683 | auto vectorType = state.getCanonicalVecType( |
684 | getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap); |
685 | |
686 | Operation *write; |
687 | if (vectorType.getRank() > 0) { |
688 | AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap)); |
689 | SmallVector<Value> indices(linalgOp.getRank(outputOperand), |
690 | rewriter.create<arith::ConstantIndexOp>(loc, 0)); |
691 | value = broadcastIfNeeded(rewriter, value, vectorType); |
692 | assert(value.getType() == vectorType && "Incorrect type"); |
693 | write = rewriter.create<vector::TransferWriteOp>( |
694 | loc, value, outputOperand->get(), indices, writeMap); |
695 | } else { |
696 | // 0-d case is still special: do not invert the reindexing writeMap. |
697 | if (!isa<VectorType>(value.getType())) |
698 | value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value); |
699 | assert(value.getType() == vectorType && "Incorrect type"); |
700 | write = rewriter.create<vector::TransferWriteOp>( |
701 | loc, value, outputOperand->get(), ValueRange{}); |
702 | } |
703 | |
704 | write = state.maskOperation(rewriter, write, linalgOp, opOperandMap); |
705 | |
706 | // If masked, set in-bounds to true. Masking guarantees that the access will |
707 | // be in-bounds. |
708 | if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) { |
709 | auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp()); |
710 | SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true); |
711 | maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds)); |
712 | } |
713 | |
714 | LDBG("vectorized op: "<< *write << "\n"); |
715 | if (!write->getResults().empty()) |
716 | return write->getResult(idx: 0); |
717 | return Value(); |
718 | } |
719 | |
720 | // Custom vectorization precondition function type. This is intended to be used |
721 | // with CustomVectorizationHook. Returns success if the corresponding custom |
722 | // hook can vectorize the op. |
723 | using CustomVectorizationPrecondition = |
724 | std::function<LogicalResult(Operation *, bool)>; |
725 | |
726 | // Custom vectorization function type. Produce a vector form of Operation* |
727 | // assuming all its vectorized operands are already in the IRMapping. |
728 | // Return nullptr if the Operation cannot be vectorized. |
729 | using CustomVectorizationHook = |
730 | std::function<VectorizationResult(Operation *, const IRMapping &)>; |
731 | |
732 | /// Helper function to vectorize the terminator of a `linalgOp`. New result |
733 | /// vector values are appended to `newResults`. Return |
734 | /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it |
735 | /// should not try to map produced operations and instead return the results |
736 | /// using the `newResults` vector making them available to the vectorization |
737 | /// algorithm for RAUW. This function is meant to be used as a |
738 | /// CustomVectorizationHook. |
739 | static VectorizationResult |
740 | vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, |
741 | const IRMapping &bvm, VectorizationState &state, |
742 | LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) { |
743 | auto yieldOp = dyn_cast<linalg::YieldOp>(op); |
744 | if (!yieldOp) |
745 | return VectorizationResult{VectorizationStatus::Failure, nullptr}; |
746 | for (const auto &output : llvm::enumerate(yieldOp.getValues())) { |
747 | // TODO: Scan for an opportunity for reuse. |
748 | // TODO: use a map. |
749 | Value vectorValue = bvm.lookup(output.value()); |
750 | Value newResult = |
751 | buildVectorWrite(rewriter, vectorValue, |
752 | linalgOp.getDpsInitOperand(output.index()), state); |
753 | if (newResult) |
754 | newResults.push_back(newResult); |
755 | } |
756 | |
757 | return VectorizationResult{VectorizationStatus::NoReplace, nullptr}; |
758 | } |
759 | |
760 | /// Helper function to vectorize the index operations of a `linalgOp`. Return |
761 | /// VectorizationStatus::NewOp to signal the vectorization algorithm that it |
762 | /// should map the produced operations. This function is meant to be used as a |
763 | /// CustomVectorizationHook. |
764 | static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter, |
765 | VectorizationState &state, |
766 | Operation *op, |
767 | LinalgOp linalgOp) { |
768 | IndexOp indexOp = dyn_cast<linalg::IndexOp>(op); |
769 | if (!indexOp) |
770 | return VectorizationResult{VectorizationStatus::Failure, nullptr}; |
771 | auto loc = indexOp.getLoc(); |
772 | // Compute the static loop sizes of the index op. |
773 | ArrayRef<int64_t> targetShape = state.getCanonicalVecShape(); |
774 | auto dim = indexOp.getDim(); |
775 | // Compute a one-dimensional index vector for the index op dimension. |
776 | auto indexVectorType = |
777 | VectorType::get({targetShape[dim]}, rewriter.getIndexType(), |
778 | state.getScalableVecDims()[dim]); |
779 | auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType); |
780 | // Return the one-dimensional index vector if it lives in the trailing |
781 | // dimension of the iteration space since the vectorization algorithm in this |
782 | // case can handle the broadcast. |
783 | if (dim == targetShape.size() - 1) |
784 | return VectorizationResult{VectorizationStatus::NewOp, indexSteps}; |
785 | // Otherwise permute the targetShape to move the index dimension last, |
786 | // broadcast the one-dimensional index vector to the permuted shape, and |
787 | // finally transpose the broadcasted index vector to undo the permutation. |
788 | auto permPattern = |
789 | llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size())); |
790 | std::swap(permPattern[dim], permPattern.back()); |
791 | auto permMap = |
792 | AffineMap::getPermutationMap(permPattern, linalgOp.getContext()); |
793 | |
794 | auto broadCastOp = rewriter.create<vector::BroadcastOp>( |
795 | loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap), |
796 | indexSteps); |
797 | SmallVector<int64_t> transposition = |
798 | llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops())); |
799 | std::swap(transposition.back(), transposition[dim]); |
800 | auto transposeOp = |
801 | rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition); |
802 | return VectorizationResult{VectorizationStatus::NewOp, transposeOp}; |
803 | } |
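 | // For illustration (hypothetical 4x8 iteration space): `linalg.index 1` (the |
 | // trailing dim) vectorizes to just: |
 | //   %step = vector.step : vector<8xindex> |
 | // whereas `linalg.index 0` additionally needs the broadcast + transpose: |
 | //   %step = vector.step : vector<4xindex> |
 | //   %bcast = vector.broadcast %step : vector<4xindex> to vector<8x4xindex> |
 | //   %idx = vector.transpose %bcast, [1, 0] : vector<8x4xindex> to vector<4x8xindex> |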
804 | |
805 | /// Helper function to check if the tensor.extract can be vectorized by the |
806 | /// custom hook vectorizeTensorExtract. |
807 | static LogicalResult |
808 | tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) { |
809 | tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op); |
810 | if (!extractOp) |
811 | return failure(); |
812 | |
813 | if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract) |
814 | return failure(); |
815 | |
816 | // Check the index type, but only for non 0-d tensors (for which we do need |
817 | // access indices). |
818 | if (not extractOp.getIndices().empty()) { |
819 | if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType())) |
820 | return failure(); |
821 | } |
822 | |
823 | if (!llvm::all_of(extractOp->getResultTypes(), |
824 | VectorType::isValidElementType)) { |
825 | return failure(); |
826 | } |
827 | |
828 | return success(); |
829 | } |
830 | |
831 | /// Calculates the offsets (`$index_vec`) for `vector.gather` operations |
832 | /// generated from `tensor.extract`. The offset is calculated as follows |
833 | /// (example using scalar values): |
834 | /// |
835 | /// offset = extractOp.indices[0] |
836 | /// for (i = 1; i < numIndices; i++) |
837 | /// offset = extractOp.dimSize[i] * offset + extractOp.indices[i]; |
838 | /// |
839 | /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to: |
840 | /// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3 |
841 | static Value calculateGatherOffset(RewriterBase &rewriter, |
842 | VectorizationState &state, |
843 | tensor::ExtractOp extractOp, |
844 | const IRMapping &bvm) { |
845 | // The vector of indices for GatherOp should be shaped as the output vector. |
846 | auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType()); |
847 | auto loc = extractOp.getLoc(); |
848 | |
849 | Value offset = broadcastIfNeeded( |
850 | rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType); |
851 | |
852 | const size_t numIndices = extractOp.getIndices().size(); |
853 | for (size_t i = 1; i < numIndices; i++) { |
854 | Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i); |
855 | |
856 | auto dimSize = broadcastIfNeeded( |
857 | rewriter, |
858 | rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx), |
859 | indexVecType); |
860 | |
861 | offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize); |
862 | |
863 | auto extractOpIndex = broadcastIfNeeded( |
864 | rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType); |
865 | |
866 | offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset); |
867 | } |
868 | |
869 | return offset; |
870 | } |
871 | |
872 | enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather }; |
873 | |
874 | /// Find the index of the trailing non-unit dim in linalgOp. This hook is used |
875 | /// when checking whether `tensor.extract` Op (within a `linalg.generic` Op) |
876 | /// represents a contiguous load operation. |
877 | /// |
878 | /// Note that when calling this hook, it is assumed that the output vector is |
879 | /// effectively 1D. Other cases (i.e. reading n-D vectors) should've been |
880 | /// labelled as a gather load before entering this method. |
881 | /// |
882 | /// Following on from the above, it is assumed that: |
883 | /// * for statically shaped loops, when no masks are used, only one dim is != |
884 | /// 1 (that's what the shape of the output vector is based on). |
885 | /// * for dynamically shaped loops, there might be more non-unit dims |
886 | /// as the output vector type is user-specified. |
887 | /// |
888 | /// TODO: Statically shaped loops + vector masking |
889 | static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) { |
890 | SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges(); |
891 | assert( |
892 | (linalgOp.hasDynamicShape() || |
893 | llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) && |
894 | "For statically shaped Linalg Ops, only one " |
895 | "non-unit loop dim is expected"); |
896 | assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse."); |
897 | |
898 | size_t idx = loopRanges.size() - 1; |
899 | for (; idx != 0; idx--) |
900 | if (loopRanges[idx] != 1) |
901 | break; |
902 | |
903 | return idx; |
904 | } |
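 | // For example (hypothetical static loop ranges): (1, 16, 1) -> returns 1, |
 | // (1, 1, 8) -> returns 2, and (4, 1, 1) -> returns 0. |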
905 | |
906 | /// Checks whether `val` can be used for calculating a loop invariant index. |
907 | static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val, |
908 | VectorType resType) { |
909 | |
910 | assert(((llvm::count_if(resType.getShape(), |
911 | [](int64_t dimSize) { return dimSize > 1; }) == 1)) && |
912 | "n-D vectors are not yet supported"); |
913 | |
914 | // Blocks outside _this_ linalg.generic are effectively loop invariant. |
915 | // However, analysing block arguments for _this_ linalg.generic Op is a bit |
916 | // tricky. Just bail out in the latter case. |
917 | // TODO: We could try analysing the corresponding affine map here. |
918 | auto *block = linalgOp.getBlock(); |
919 | if (isa<BlockArgument>(val)) |
920 | return llvm::all_of(block->getArguments(), |
921 | [&val](Value v) { return (v != val); }); |
922 | |
923 | Operation *defOp = val.getDefiningOp(); |
924 | assert(defOp && "This is neither a block argument nor an operation result"); |
925 | |
926 | // IndexOp is loop invariant as long as its result remains constant across |
927 | // iterations. Note that for dynamic shapes, the corresponding dim will also |
928 | // be conservatively treated as != 1. |
929 | if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) { |
930 | return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1; |
931 | } |
932 | |
933 | auto *ancestor = block->findAncestorOpInBlock(*defOp); |
934 | |
935 | // Values defined outside `linalgOp` are loop invariant. |
936 | if (!ancestor) |
937 | return true; |
938 | |
939 | // Values defined inside `linalgOp`, which are constant, are loop invariant. |
940 | if (isa<arith::ConstantOp>(ancestor)) |
941 | return true; |
942 | |
943 | bool result = true; |
944 | for (auto op : ancestor->getOperands()) |
945 | result &= isLoopInvariantIdx(linalgOp, op, resType); |
946 | |
947 | return result; |
948 | } |
949 | |
950 | /// Check whether `val` could be used for calculating the trailing index for a |
951 | /// contiguous load operation. |
952 | /// |
953 | /// There are currently 3 types of values that are allowed here: |
954 | /// 1. loop-invariant values, |
955 | /// 2. values that increment by 1 with every loop iteration, |
956 | /// 3. results of basic arithmetic operations (linear and continuous) |
957 | /// involving 1., 2. and 3. |
958 | /// This method returns True if indeed only such values are used in calculating |
959 | /// `val.` |
960 | /// |
961 | /// Additionally, the trailing index for a contiguous load operation should |
962 | /// increment by 1 with every loop iteration, i.e. be based on: |
963 | /// * `linalg.index <dim>` , |
964 | /// where <dim> is the trailing non-unit dim of the iteration space (this way, |
965 | /// `linalg.index <dim>` increments by 1 with every loop iteration). |
966 | /// `foundIndexOp` is updated to `true` when such Op is found. |
967 | static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, |
968 | bool &foundIndexOp, VectorType resType) { |
969 | |
970 | assert(((llvm::count_if(resType.getShape(), |
971 | [](int64_t dimSize) { return dimSize > 1; }) == 1)) && |
972 | "n-D vectors are not yet supported"); |
973 | |
974 | // Blocks outside _this_ linalg.generic are effectively loop invariant. |
975 | // However, analysing block arguments for _this_ linalg.generic Op is a bit |
976 | // tricky. Just bail out in the latter case. |
977 | // TODO: We could try analysing the corresponding affine map here. |
978 | auto *block = linalgOp.getBlock(); |
979 | if (isa<BlockArgument>(val)) |
980 | return llvm::all_of(block->getArguments(), |
981 | [&val](Value v) { return (v != val); }); |
982 | |
983 | Operation *defOp = val.getDefiningOp(); |
984 | assert(defOp && "This is neither a block argument nor an operation result"); |
985 | |
986 | if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) { |
987 | auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp); |
988 | |
989 | foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne); |
990 | return true; |
991 | } |
992 | |
993 | auto *ancestor = block->findAncestorOpInBlock(*defOp); |
994 | |
995 | if (!ancestor) |
996 | return false; |
997 | |
998 | // Conservatively reject Ops that could lead to indices with stride other |
999 | // than 1. |
1000 | if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor)) |
1001 | return false; |
1002 | |
1003 | bool result = false; |
1004 | for (auto op : ancestor->getOperands()) |
1005 | result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType); |
1006 | |
1007 | return result; |
1008 | } |
1009 | |
1010 | /// Infer the memory access pattern for the input ExtractOp |
1011 | /// |
1012 | /// Based on the ExtractOp result shape and the access indices, decides whether |
1013 | /// this Op corresponds to a contiguous load (including a broadcast of a scalar) |
1014 | /// or a gather load. When analysing the ExtractOp indices (to identify |
1015 | /// contiguous loads), this method looks for "loop" invariant indices (e.g. |
1016 | /// block arguments) and indices that change linearly (e.g. via `linalg.index` |
1017 | /// Op). |
1018 | /// |
1019 | /// Note that it is always safe to use gather load operations for contiguous |
1020 | /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume |
1021 | /// that `extractOp` is a gather load. |
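 | /// |
 | /// For illustration (a hypothetical example): inside a linalg.generic with a |
 | /// 1x8 iteration space, |
 | ///   %v = tensor.extract %src[%c0, %i] : tensor<4x8xf32> |
 | /// with %c0 loop-invariant and %i coming from `linalg.index 1` is classified |
 | /// as a contiguous load, whereas an index loaded from another tensor would |
 | /// make it a gather. |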
1022 | static VectorMemoryAccessKind |
1023 | getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, |
1024 | LinalgOp &linalgOp, VectorType resType) { |
1025 | |
1026 | auto inputShape = cast<ShapedType>(extractOp.getTensor().getType()); |
1027 | |
1028 | // 0. Is this a 0-D vector? If yes then this is a scalar broadcast. |
1029 | if (inputShape.getShape().empty()) |
1030 | return VectorMemoryAccessKind::ScalarBroadcast; |
1031 | |
1032 | // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false |
1033 | // otherwise. |
1034 | bool isOutput1DVector = |
1035 | (llvm::count_if(resType.getShape(), |
1036 | [](int64_t dimSize) { return dimSize > 1; }) == 1); |
1037 | // 1. Assume that it's a gather load when reading non-1D vector. |
1038 | if (!isOutput1DVector) |
1039 | return VectorMemoryAccessKind::Gather; |
1040 | |
1041 | bool leadingIdxsLoopInvariant = true; |
1042 | |
1043 | // 2. Analyze the leading indices of `extractOp`. |
1044 | // Look at the way each index is calculated and decide whether it is suitable |
1045 | // for a contiguous load, i.e. whether it's loop invariant. If not, it's a |
1046 | // gather load. |
1047 | auto indices = extractOp.getIndices(); |
1048 | auto leadIndices = indices.drop_back(1); |
1049 | |
1050 | for (auto [i, indexVal] : llvm::enumerate(leadIndices)) { |
1051 | if (inputShape.getShape()[i] == 1) |
1052 | continue; |
1053 | |
1054 | leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType); |
1055 | } |
1056 | |
1057 | if (!leadingIdxsLoopInvariant) { |
1058 | LDBG("Found gather load: "<< extractOp); |
1059 | return VectorMemoryAccessKind::Gather; |
1060 | } |
1061 | |
1062 | // 3. Analyze the trailing index for `extractOp`. |
1063 | // At this point we know that the leading indices are loop invariant. This |
1064 | // means that this is potentially a scalar or a contiguous load. We can decide |
1065 | // based on the trailing idx. |
1066 | auto extractOpTrailingIdx = indices.back(); |
1067 | |
1068 | // 3a. Scalar broadcast load |
1069 | // If the trailing index is loop invariant then this is a scalar load. |
1070 | if (leadingIdxsLoopInvariant && |
1071 | isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) { |
1072 | LDBG("Found scalar broadcast load: "<< extractOp); |
1073 | |
1074 | return VectorMemoryAccessKind::ScalarBroadcast; |
1075 | } |
1076 | |
1077 | // 3b. Contiguous loads |
1078 | // The trailing `extractOp` index should increment with every loop iteration. |
1079 | // This effectively means that it must be based on the trailing loop index. |
1080 | // This is what the following bool captures. |
1081 | bool foundIndexOp = false; |
1082 | bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, |
1083 | foundIndexOp, resType); |
1084 | // TODO: Support generating contiguous loads for column vectors - that will |
1085 | // require adding a permutation map to transfer_read Ops. |
1086 | bool isRowVector = resType.getShape().back() != 1; |
1087 | isContiguousLoad &= (foundIndexOp && isRowVector); |
1088 | |
1089 | if (isContiguousLoad) { |
1090 | LDBG("Found contigous load: "<< extractOp); |
1091 | return VectorMemoryAccessKind::Contiguous; |
1092 | } |
1093 | |
1094 | // 4. Fallback case - gather load. |
1095 | LDBG("Found gather load: "<< extractOp); |
1096 | return VectorMemoryAccessKind::Gather; |
1097 | } |
1098 | |
1099 | /// Helper function to vectorize the tensor.extract operations. Returns |
1100 | /// VectorizationStatus::NewOp to signal the vectorization algorithm that it |
1101 | /// should map the produced operations. This function is meant to be used as a |
1102 | /// CustomVectorizationHook. |
1103 | static VectorizationResult |
1104 | vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, |
1105 | Operation *op, LinalgOp linalgOp, const IRMapping &bvm) { |
1106 | tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op); |
1107 | if (!extractOp) |
1108 | return VectorizationResult{VectorizationStatus::Failure, nullptr}; |
1109 | auto loc = extractOp.getLoc(); |
1110 | |
1111 | // Compute the static loop sizes of the extract op. |
1112 | auto resultType = state.getCanonicalVecType(extractOp.getResult().getType()); |
1113 | auto maskConstantOp = rewriter.create<arith::ConstantOp>( |
1114 | loc, |
1115 | DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()), |
1116 | /*value=*/true)); |
1117 | auto passThruConstantOp = |
1118 | rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType)); |
1119 | |
1120 | // Base indices are currently set to 0. We will need to re-visit if more |
1121 | // generic scenarios are to be supported. |
1122 | SmallVector<Value> baseIndices( |
1123 | extractOp.getIndices().size(), |
1124 | rewriter.create<arith::ConstantIndexOp>(loc, 0)); |
1125 | |
1126 | VectorMemoryAccessKind memAccessKind = |
1127 | getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType); |
1128 | |
1129 | // 1. Handle gather access |
1130 | if (memAccessKind == VectorMemoryAccessKind::Gather) { |
1131 | Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm); |
1132 | |
1133 | // Generate the gather load |
1134 | Operation *gatherOp = rewriter.create<vector::GatherOp>( |
1135 | loc, resultType, extractOp.getTensor(), baseIndices, offset, |
1136 | maskConstantOp, passThruConstantOp); |
1137 | gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp); |
1138 | |
1139 | LDBG("Vectorised as gather load: " << extractOp << "\n"); |
1140 | return VectorizationResult{VectorizationStatus::NewOp, gatherOp}; |
1141 | } |
1142 | |
1143 | // 2. Handle: |
1144 | // a. scalar loads + broadcast, |
1145 | // b. contiguous loads. |
1146 | // Both cases use vector.transfer_read. |
1147 | |
1148 | // Collect indices for `vector.transfer_read`. At this point, the indices will |
1149 | // either be scalars or would have been broadcast to vectors matching the |
1150 | // result type. For indices that are vectors, there are two options: |
1151 | // * for non-trailing indices, all elements are identical (contiguous |
1152 | // loads are identified by looking for non-trailing indices that are |
1153 | // invariant with respect to the corresponding linalg.generic), or |
1154 | // * for trailing indices, the index vector will contain values with stride |
1155 | // one, but for `vector.transfer_read` only the first (i.e. 0th) index is |
1156 | // needed. |
1157 | // This means that |
1158 | // * for scalar indices - just re-use it, |
1159 | // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom |
1160 | // (0th) element and use that. |
1161 | SmallVector<Value> transferReadIdxs; |
1162 | for (size_t i = 0; i < extractOp.getIndices().size(); i++) { |
1163 | Value idx = bvm.lookup(extractOp.getIndices()[i]); |
1164 | if (idx.getType().isIndex()) { |
1165 | transferReadIdxs.push_back(idx); |
1166 | continue; |
1167 | } |
1168 | |
1169 | auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>( |
1170 | loc, |
1171 | VectorType::get(resultType.getShape().back(), rewriter.getIndexType(), |
1172 | resultType.getScalableDims().back()), |
1173 | idx); |
1174 | transferReadIdxs.push_back( |
1175 | rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0)); |
1176 | } |
1177 | |
1178 | // `tensor.extract_element` is always in-bounds, hence the following holds. |
1179 | auto dstRank = resultType.getRank(); |
1180 | auto srcRank = extractOp.getTensor().getType().getRank(); |
1181 | SmallVector<bool> inBounds(dstRank, true); |
1182 | |
1183 | // 2a. Handle scalar broadcast access. |
1184 | if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) { |
1185 | MLIRContext *ctx = rewriter.getContext(); |
1186 | SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx)); |
1187 | auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx); |
1188 | |
1189 | auto transferReadOp = rewriter.create<vector::TransferReadOp>( |
1190 | loc, resultType, extractOp.getTensor(), transferReadIdxs, |
1191 | permutationMap, inBounds); |
1192 | |
1193 | // Mask this broadcasting xfer_read here rather than relying on the generic |
1194 | // path (the generic path assumes identity masking map, which wouldn't be |
1195 | // valid here). |
1196 | SmallVector<int64_t> readMaskShape = {1}; |
1197 | auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type()); |
1198 | auto allTrue = rewriter.create<vector::ConstantMaskOp>( |
1199 | loc, readMaskType, vector::ConstantMaskKind::AllTrue); |
1200 | auto *maskedReadOp = |
1201 | mlir::vector::maskOperation(rewriter, transferReadOp, allTrue); |
1202 | |
1203 | LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n"); |
1204 | return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp}; |
1205 | } |
1206 | |
1207 | // 2b. Handle contiguous access. |
1208 | auto permutationMap = AffineMap::getMinorIdentityMap( |
1209 |       srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1210 | |
1211 | int32_t rankDiff = dstRank - srcRank; |
1212 | // When dstRank > srcRank, broadcast the source tensor to the unitary leading |
1213 | // dims so that the ranks match. This is done by extending the map with 0s. |
1214 | // For example, for dstRank = 3, srcRank = 2, the following map created |
1215 | // above: |
1216 | // (d0, d1) --> (d0, d1) |
1217 | // is extended as: |
1218 | // (d0, d1) --> (0, d0, d1) |
1219 | while (rankDiff > 0) { |
1220 | permutationMap = permutationMap.insertResult( |
1221 |         mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1222 | rankDiff--; |
1223 | } |
1224 | |
1225 | auto transferReadOp = rewriter.create<vector::TransferReadOp>( |
1226 | loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap, |
1227 | inBounds); |
1228 | |
1229 |   LDBG("Vectorised as contiguous load: " << extractOp);
1230 | return VectorizationResult{VectorizationStatus::NewOp, transferReadOp}; |
1231 | } |
1232 | |
1233 | /// Emit reduction operations if the shape of the value to reduce is different
1234 | /// from the result shape.
1235 | // Note: this is a true builder that notifies the OpBuilder listener. |
1236 | // TODO: Consider moving as a static helper on the ReduceOp. |
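// For illustration (shapes are hypothetical): if `reduceValue` was vectorized
// to vector<4x8xf32> while the init value maps to vector<4xf32>, a
// vector.multi_reduction over the dims reported by getDimsToReduce is built;
// if the two shapes already match, nullptr is returned and nothing is emitted.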
1237 | static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, |
1238 | Value reduceValue, Value initialValue, |
1239 | const IRMapping &bvm) { |
1240 |   Value reduceVec = bvm.lookup(reduceValue);
1241 |   Value outputVec = bvm.lookup(initialValue);
1242 | auto reduceType = dyn_cast<VectorType>(reduceVec.getType()); |
1243 | auto outputType = dyn_cast<VectorType>(outputVec.getType()); |
1244 |   // Reduce only if needed as the value may already have been reduced for
1245 |   // contraction vectorization.
1246 | if (!reduceType || |
1247 | (outputType && reduceType.getShape() == outputType.getShape())) |
1248 | return nullptr; |
1249 | SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp); |
1250 |   return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1251 | } |
1252 | |
1253 | /// Generic vectorization for a single operation `op`, given already vectorized |
1254 | /// operands carried by `bvm`. Vectorization occurs as follows: |
1255 | /// 1. Try to apply any of the `customVectorizationHooks` and return its |
1256 | /// result on success. |
1257 | /// 2. Clone any constant in the current scope without vectorization: each |
1258 | /// consumer of the constant will later determine the shape to which the |
1259 | /// constant needs to be broadcast to. |
1260 | /// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose |
1261 | /// of the `customVectorizationHooks` to cover such cases. |
1262 | /// 4. Clone `op` in vector form to a vector of shape prescribed by the first |
1263 | /// operand of maximal rank. Other operands have smaller rank and are |
1264 | /// broadcast accordingly. It is assumed this broadcast is always legal, |
1265 | /// otherwise, it means one of the `customVectorizationHooks` is incorrect. |
1266 | /// |
1267 | /// This function assumes all operands of `op` have been vectorized and are in |
1268 | /// the `bvm` mapping. As a consequence, this function is meant to be called on |
1269 | /// a topologically-sorted list of ops. |
1270 | /// This function does not update `bvm` but returns a VectorizationStatus that |
1271 | /// instructs the caller what `bvm` update needs to occur. |
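///
/// For illustration (types are hypothetical): an `arith.addf %a, %b : f32` in
/// the linalg body, where `bvm` maps %a to a vector<4x8xf32> and %b to a plain
/// f32, is rebuilt in step 4 as a vector `arith.addf` after %b is broadcast to
/// vector<4x8xf32>.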
1272 | static VectorizationResult |
1273 | vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, |
1274 | LinalgOp linalgOp, Operation *op, const IRMapping &bvm, |
1275 | ArrayRef<CustomVectorizationHook> customVectorizationHooks) { |
1276 |   LDBG("vectorize op " << *op << "\n");
1277 | |
1278 | // 1. Try to apply any CustomVectorizationHook. |
1279 | if (!customVectorizationHooks.empty()) { |
1280 | for (auto &customFunc : customVectorizationHooks) { |
1281 | VectorizationResult result = customFunc(op, bvm); |
1282 | if (result.status == VectorizationStatus::Failure) |
1283 | continue; |
1284 | return result; |
1285 | } |
1286 | } |
1287 | |
1288 | // 2. Constant ops don't get vectorized but rather broadcasted at their users. |
1289 |   // Clone so that the constant is not confined to the linalgOp block.
1290 |   if (isa<arith::ConstantOp, func::ConstantOp>(op))
1291 |     return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
1292 | |
1293 | // 3. Only ElementwiseMappable are allowed in the generic vectorization. |
1294 | if (!OpTrait::hasElementwiseMappableTraits(op)) |
1295 |     return VectorizationResult{VectorizationStatus::Failure, nullptr};
1296 | |
1297 |   // 4. Check if the operation is a reduction.
1298 | SmallVector<std::pair<Value, Value>> reductionOperands; |
1299 | for (Value operand : op->getOperands()) { |
1300 |     auto blockArg = dyn_cast<BlockArgument>(operand);
1301 | if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() || |
1302 | blockArg.getArgNumber() < linalgOp.getNumDpsInputs()) |
1303 | continue; |
1304 | SmallVector<Operation *> reductionOps; |
1305 | Value reduceValue = matchReduction( |
1306 | linalgOp.getRegionOutputArgs(), |
1307 | blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps); |
1308 | if (!reduceValue) |
1309 | continue; |
1310 |     reductionOperands.push_back(std::make_pair(reduceValue, operand));
1311 | } |
1312 | if (!reductionOperands.empty()) { |
1313 | assert(reductionOperands.size() == 1); |
1314 | Operation *reduceOp = |
1315 | reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first, |
1316 | reductionOperands[0].second, bvm); |
1317 | if (reduceOp) |
1318 |       return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
1319 | } |
1320 | |
1321 | // 5. Generic vectorization path for ElementwiseMappable ops. |
1322 | // a. Get the first max ranked shape. |
1323 | VectorType firstMaxRankedType; |
1324 | for (Value operand : op->getOperands()) { |
1325 |     auto vecOperand = bvm.lookup(operand);
1326 | assert(vecOperand && "Vector operand couldn't be found"); |
1327 | |
1328 | auto vecType = dyn_cast<VectorType>(vecOperand.getType()); |
1329 | if (vecType && (!firstMaxRankedType || |
1330 | firstMaxRankedType.getRank() < vecType.getRank())) |
1331 | firstMaxRankedType = vecType; |
1332 | } |
1333 | // b. Broadcast each op if needed. |
1334 | SmallVector<Value> vecOperands; |
1335 | for (Value scalarOperand : op->getOperands()) { |
1336 |     Value vecOperand = bvm.lookup(scalarOperand);
1337 | assert(vecOperand && "Vector operand couldn't be found"); |
1338 | |
1339 | if (firstMaxRankedType) { |
1340 | auto vecType = VectorType::get(firstMaxRankedType.getShape(), |
1341 | getElementTypeOrSelf(vecOperand.getType()), |
1342 | firstMaxRankedType.getScalableDims()); |
1343 |       vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1344 |     } else {
1345 |       vecOperands.push_back(vecOperand);
1346 | } |
1347 | } |
1348 |   // c. For elementwise ops, the result has the shape of firstMaxRankedType.
1349 | SmallVector<Type> resultTypes; |
1350 | for (Type resultType : op->getResultTypes()) { |
1351 | resultTypes.push_back( |
1352 | firstMaxRankedType |
1353 | ? VectorType::get(firstMaxRankedType.getShape(), resultType, |
1354 | firstMaxRankedType.getScalableDims()) |
1355 | : resultType); |
1356 | } |
1357 | // d. Build and return the new op. |
1358 | return VectorizationResult{ |
1359 |       VectorizationStatus::NewOp,
1360 |       rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1361 | resultTypes, op->getAttrs())}; |
1362 | } |
1363 | |
1364 | /// Generic vectorization function that rewrites the body of a `linalgOp` into |
1365 | /// vector form. Generic vectorization proceeds as follows: |
1366 | /// 1. Verify the `linalgOp` has one non-empty region. |
1367 | /// 2. Values defined above the region are mapped to themselves and will be |
1368 | /// broadcasted on a per-need basis by their consumers. |
1369 | /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d |
1370 | /// load). |
1371 | /// TODO: Reuse opportunities for RAR dependencies. |
1372 | /// 4a. Register CustomVectorizationHook for YieldOp to capture the results. |
1373 | /// 4b. Register CustomVectorizationHook for IndexOp to access the
1374 | /// iteration indices. |
1375 | /// 5. Iteratively call vectorizeOneOp on the region operations. |
1376 | /// |
1377 | /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is |
1378 | /// performed to the maximal common vector size implied by the `linalgOp` |
1379 | /// iteration space. This eager broadcasting is introduced in the |
1380 | /// permutation_map of the vector.transfer_read operations. The eager |
1381 | /// broadcasting makes it trivial to determine where broadcasts, transposes and
1382 | /// reductions should occur, without any bookkeeping. The tradeoff is that, in |
1383 | /// the absence of good canonicalizations, the amount of work increases. |
1384 | /// This is not deemed a problem as we expect canonicalizations and foldings to |
1385 | /// aggressively clean up the useless work. |
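///
/// A rough sketch of the outcome (types are illustrative only): an elementwise
/// `linalg.generic` on tensor<8x16xf32> operands becomes vector.transfer_read
/// ops producing vector<8x16xf32> values, a vectorized body, and a
/// vector.transfer_write emitted by the YieldOp hook, with masks added via
/// `state` when the iteration space is dynamic.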
1386 | static LogicalResult |
1387 | vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, |
1388 | LinalgOp linalgOp, |
1389 | SmallVectorImpl<Value> &newResults) { |
1390 | LDBG("Vectorizing operation as linalg generic\n"); |
1391 | Block *block = linalgOp.getBlock(); |
1392 | |
1393 | // 2. Values defined above the region can only be broadcast for now. Make them |
1394 | // map to themselves. |
1395 | IRMapping bvm; |
1396 | SetVector<Value> valuesSet; |
1397 | mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet); |
1398 |   bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1399 | |
1400 | if (linalgOp.getNumDpsInits() == 0) |
1401 | return failure(); |
1402 | |
1403 | // 3. Turn all BBArgs into vector.transfer_read / load. |
1404 | Location loc = linalgOp.getLoc(); |
1405 |   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1406 | for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) { |
1407 | BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand); |
1408 | if (linalgOp.isScalar(opOperand)) { |
1409 | bvm.map(bbarg, opOperand->get()); |
1410 | continue; |
1411 | } |
1412 | |
1413 | // 3.a. Convert the indexing map for this input/output to a transfer read |
1414 | // permutation map and masking map. |
1415 | AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand); |
1416 | |
1417 | AffineMap readMap; |
1418 | VectorType readType; |
1419 | Type elemType = getElementTypeOrSelf(opOperand->get()); |
1420 | if (linalgOp.isDpsInput(opOperand)) { |
1421 | // 3.a.i. For input reads we use the canonical vector shape. |
1422 | readMap = inverseAndBroadcastProjectedPermutation(indexingMap); |
1423 | readType = state.getCanonicalVecType(elemType); |
1424 | } else { |
1425 | // 3.a.ii. For output reads (iteration-carried dependence, e.g., |
1426 | // reductions), the vector shape is computed by mapping the canonical |
1427 | // vector shape to the output domain and back to the canonical domain. |
1428 | readMap = inversePermutation(reindexIndexingMap(indexingMap)); |
1429 | readType = |
1430 | state.getCanonicalVecType(elemType, readMap.compose(indexingMap)); |
1431 | } |
1432 | |
1433 | SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero); |
1434 | |
1435 | Operation *read = rewriter.create<vector::TransferReadOp>( |
1436 | loc, readType, opOperand->get(), indices, readMap); |
1437 | read = state.maskOperation(rewriter, read, linalgOp, indexingMap); |
1438 | Value readValue = read->getResult(0); |
1439 | |
1440 | // 3.b. If masked, set in-bounds to true. Masking guarantees that the access |
1441 | // will be in-bounds. |
1442 | if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) { |
1443 | SmallVector<bool> inBounds(readType.getRank(), true); |
1444 | cast<vector::TransferReadOp>(maskOp.getMaskableOp()) |
1445 | .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds)); |
1446 | } |
1447 | |
1448 | // 3.c. Not all ops support 0-d vectors, extract the scalar for now. |
1449 | // TODO: remove this. |
1450 | if (readType.getRank() == 0) |
1451 | readValue = rewriter.create<vector::ExtractOp>(loc, readValue, |
1452 | ArrayRef<int64_t>()); |
1453 | |
1454 |     LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
1455 | << "\n"); |
1456 | bvm.map(bbarg, readValue); |
1457 | bvm.map(opOperand->get(), readValue); |
1458 | } |
1459 | |
1460 | SmallVector<CustomVectorizationHook> hooks; |
1461 | // 4a. Register CustomVectorizationHook for yieldOp. |
1462 | CustomVectorizationHook vectorizeYield = |
1463 | [&](Operation *op, const IRMapping &bvm) -> VectorizationResult { |
1464 | return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults); |
1465 | }; |
1466 |   hooks.push_back(vectorizeYield);
1467 | |
1468 | // 4b. Register CustomVectorizationHook for indexOp. |
1469 | CustomVectorizationHook vectorizeIndex = |
1470 | [&](Operation *op, const IRMapping &bvm) -> VectorizationResult { |
1471 | return vectorizeLinalgIndex(rewriter, state, op, linalgOp); |
1472 | }; |
1473 |   hooks.push_back(vectorizeIndex);
1474 | |
1475 | // 4c. Register CustomVectorizationHook for extractOp. |
1476 | CustomVectorizationHook vectorizeExtract = |
1477 | [&](Operation *op, const IRMapping &bvm) -> VectorizationResult { |
1478 | return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm); |
1479 | }; |
1480 |   hooks.push_back(vectorizeExtract);
1481 | |
1482 | // 5. Iteratively call `vectorizeOneOp` to each op in the slice. |
1483 | for (Operation &op : block->getOperations()) { |
1484 | VectorizationResult result = |
1485 | vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks); |
1486 | if (result.status == VectorizationStatus::Failure) { |
1487 |       LDBG("failed to vectorize: " << op << "\n");
1488 | return failure(); |
1489 | } |
1490 | if (result.status == VectorizationStatus::NewOp) { |
1491 | Operation *maybeMaskedOp = |
1492 | state.maskOperation(rewriter, result.newOp, linalgOp); |
1493 |       LDBG("New vector op: " << *maybeMaskedOp << "\n");
1494 | bvm.map(op.getResults(), maybeMaskedOp->getResults()); |
1495 | } |
1496 | } |
1497 | |
1498 | return success(); |
1499 | } |
1500 | |
1501 | /// Given a linalg::PackOp, return the `dest` shape before any packing |
1502 | /// permutations. |
1503 | static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp, |
1504 | ArrayRef<int64_t> destShape) { |
1505 | return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp)); |
1506 | } |
1507 | |
1508 | /// Determines whether a mask for xfer_write is trivially "all true" |
1509 | /// |
1510 | /// Given all the inputs required to generate a mask (mask sizes and shapes), |
1511 | /// and an xfer_write operation (write indices and the destination tensor |
1512 | /// shape), determines whether the corresponding mask would be trivially |
1513 | /// foldable (i.e., trivially "all true"). |
1514 | /// |
1515 | /// Use this method to avoid generating spurious masks and relying on
1516 | /// vectorization post-processing to remove them. |
1517 | /// |
1518 | /// Pre-conditions for a mask to be trivially foldable: |
1519 | /// * All involved shapes (mask + destination tensor) are static. |
1520 | /// * All write indices are constant. |
1521 | /// * All mask sizes are constant (including `arith.constant`). |
1522 | /// |
1523 | /// If the pre-conditions are met, the method checks for each destination |
1524 | /// dimension `d`: |
1525 | /// (1) destDimSize[rankDiff + d] <= maskShape[d] |
1526 | /// (2) destDimSize[rankDiff + d] <= writeIndex[d] + maskSize[d] |
1527 | /// |
1528 | /// rankDiff = rank(dest) - rank(mask). |
1529 | /// |
1530 | /// This method takes a conservative view: it may return false even if the mask |
1531 | /// is technically foldable. |
1532 | /// |
1533 | /// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the shape |
1534 | /// of the dest tensor): |
1535 | /// %c0 = arith.constant 0 : index |
1536 | /// %mask = vector.create_mask 5, 1 |
1537 | /// vector.mask %mask { |
1538 | ///     vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
1539 | /// {in_bounds = [true, true]} |
1540 | /// : vector<5x1xi32>, tensor<5x1xi32> |
1541 | /// } |
1542 | /// |
1543 | /// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape, |
1544 | /// mask is required to avoid out-of-bounds write): |
1545 | /// %c0 = arith.constant 0 : index |
1546 | /// %mask = vector.create_mask 5, 1 |
1547 | /// vector.mask %mask { |
1548 | /// vector.transfer_write %vecToStore_2, %dest[%c0, %c0] |
1549 | /// {in_bounds = [true, true]} |
1550 | /// : vector<8x1xi32>, tensor<5x1xi32> |
1551 | /// } |
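///
/// EXAMPLE 3 (worked check, numbers are illustrative): with
/// destShape = [8, 16] and maskShape = [16] (so rankDiff = 1), maskSize = 16
/// and writeIndex = 0, both (1) 16 <= 16 and (2) 16 <= 0 + 16 hold, so the
/// mask is trivially an all-true mask.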
1552 | /// |
1553 | /// TODO: Re-use in createReadOrMaskedRead |
1554 | static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes, |
1555 | SmallVector<Value> &writeIdxs, |
1556 | ArrayRef<int64_t> destShape, |
1557 | ArrayRef<int64_t> maskShape) { |
1558 | // Masking is unavoidable in the case of dynamic tensors. |
1559 | if (ShapedType::isDynamicShape(destShape)) |
1560 | return false; |
1561 | |
1562 | // Collect all constant mask sizes. |
1563 | SmallVector<int64_t, 4> cstMaskSizes; |
1564 |   for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
1565 |     if (auto intSize = getConstantIntValue(dimSize)) {
1566 |       cstMaskSizes.push_back(*intSize);
1567 | } |
1568 | } |
1569 | |
1570 | // If any of the mask sizes is non-constant, bail out. |
1571 | if (cstMaskSizes.size() != maskShape.size()) |
1572 | return false; |
1573 | |
1574 | // Collect all constant write indices. |
1575 | SmallVector<int64_t, 4> cstWriteIdxs; |
1576 |   for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
1577 |     APSInt intVal;
1578 |     if (matchPattern(idx, m_ConstantInt(&intVal))) {
1579 |       cstWriteIdxs.push_back(intVal.getSExtValue());
1580 | } |
1581 | } |
1582 | |
1583 | // If any of the write indices is non-constant, bail out. |
1584 | if (cstWriteIdxs.size() != destShape.size()) |
1585 | return false; |
1586 | |
1587 | // Go over all destination dims and check (1) and (2). Take into account that: |
1588 | // * The number of mask sizes will match the rank of the vector to store. |
1589 | // This could be lower than the rank of the destination tensor. |
1590 | // * Mask sizes could be larger than the corresponding mask shape (hence |
1591 | // `clamp`). |
1592 | // TODO: The 2nd item should be rejected by the verifier. |
1593 | int64_t rankDiff = destShape.size() - cstMaskSizes.size(); |
1594 |   for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
1595 |     if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
1596 |         /*(2)*/ destShape[rankDiff + i] <
1597 |                     (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
1598 | cstWriteIdxs[i])) |
1599 | return false; |
1600 | } |
1601 | |
1602 | return true; |
1603 | } |
1604 | |
1605 | /// Creates an optionally masked TransferWriteOp |
1606 | /// |
1607 | /// Generates the following operation: |
1608 | /// %res = vector.transfer_write %vecToStore into %dest |
1609 | /// |
1610 | /// If shape(vecToStore) != shape(dest), masking is used to ensure correctness: |
1611 | /// |
1612 | /// %mask = vector.create_mask(%destShape) : %vecToStoreShape |
1613 | /// %res = vector.mask %mask { |
1614 | /// vector.transfer_write %vecToStore into %dest |
1615 | /// } |
1616 | /// |
1617 | /// The mask shape is identical to `vecToStore` (with the element type == |
1618 | /// i1), and the mask values are based on the shape of the `dest` tensor. |
1619 | /// |
1620 | /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute |
1621 | /// is used instead of masking: |
1622 | /// |
1623 | /// %write = vector.transfer_write %vecToStore into %dest |
1624 | /// in_bounds_flags = (...) |
1625 | /// %res = vector.transfer_write %input into %dest |
1626 | /// {in_bounds = in_bounds_flags} |
1627 | /// |
1628 | /// Finally, `writeIndices` specifies the offsets to use. If empty, all indices |
1629 | /// are set to 0. |
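///
/// For illustration (types are hypothetical): storing a vector<8xf32> into a
/// tensor<?xf32> whose runtime size is %d would produce, roughly:
///
///   %mask = vector.create_mask %d : vector<8xi1>
///   %res = vector.mask %mask {
///     vector.transfer_write %vecToStore, %dest[%c0]
///         : vector<8xf32>, tensor<?xf32>
///   }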
1630 | static Operation * |
1631 | createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore, |
1632 | Value dest, SmallVector<Value> writeIndices = {}, |
1633 | bool useInBoundsInsteadOfMasking = false) { |
1634 | |
1635 | ShapedType destType = cast<ShapedType>(dest.getType()); |
1636 | int64_t destRank = destType.getRank(); |
1637 | auto destShape = destType.getShape(); |
1638 | |
1639 | VectorType vecToStoreType = cast<VectorType>(vecToStore.getType()); |
1640 | int64_t vecToStoreRank = vecToStoreType.getRank(); |
1641 | auto vecToStoreShape = vecToStoreType.getShape(); |
1642 | |
1643 | // Compute the in_bounds attribute |
1644 | SmallVector<bool> inBoundsVal(vecToStoreRank, true); |
1645 | if (useInBoundsInsteadOfMasking) { |
1646 | // Update the inBounds attribute. |
1647 | // FIXME: This computation is too weak - it ignores the write indices. |
1648 | for (unsigned i = 0; i < vecToStoreRank; i++) |
1649 | inBoundsVal[i] = |
1650 | (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) && |
1651 | !ShapedType::isDynamic(destShape[destRank - vecToStoreRank + i]); |
1652 | } |
1653 | |
1654 | // If missing, initialize the write indices to 0. |
1655 |   assert((writeIndices.empty() ||
1656 |           writeIndices.size() == static_cast<size_t>(destRank)) &&
1657 |          "Invalid number of write indices!");
1658 | if (writeIndices.empty()) { |
1659 |     auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1660 | writeIndices.assign(destRank, zero); |
1661 | } |
1662 | |
1663 | // Generate the xfer_write Op |
1664 | Operation *write = |
1665 | builder.create<vector::TransferWriteOp>(loc, |
1666 | /*vector=*/vecToStore, |
1667 | /*source=*/dest, |
1668 | /*indices=*/writeIndices, |
1669 | /*inBounds=*/inBoundsVal); |
1670 | |
1671 | // If masking is disabled, exit. |
1672 | if (useInBoundsInsteadOfMasking) |
1673 | return write; |
1674 | |
1675 | // Check if masking is needed. If not, exit. |
1676 | if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank))) |
1677 | return write; |
1678 | |
1679 | // Compute the mask and mask the write Op. |
1680 | auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type()); |
1681 | |
1682 | SmallVector<OpFoldResult> destSizes = |
1683 |       tensor::getMixedSizes(builder, loc, dest);
1684 | SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank, |
1685 | destSizes.end()); |
1686 | |
1687 | if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape, |
1688 | vecToStoreShape)) |
1689 | return write; |
1690 | |
1691 | Value maskForWrite = |
1692 | builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes); |
1693 |   return mlir::vector::maskOperation(builder, write, maskForWrite);
1694 | } |
1695 | |
1696 | /// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant |
1697 | /// padding value and (3) input vector sizes into: |
1698 | /// |
1699 | /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds |
1700 | /// |
1701 | /// As in the following example: |
1702 | /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2] |
1703 | /// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32> |
1704 | /// |
1705 | /// This pack would be vectorized to: |
1706 | /// |
1707 | /// %load = vector.mask %mask { |
1708 | /// vector.transfer_read %arg0[%c0, %c0, %c0], %cst |
1709 | /// {in_bounds = [true, true, true]} : |
1710 | /// tensor<32x7x16xf32>, vector<32x8x16xf32> |
1711 | /// } : vector<32x8x16xi1> -> vector<32x8x16xf32> |
1712 | /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32> |
1713 | /// to vector<32x4x2x1x16xf32> |
1714 | /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2] |
1715 | /// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32> |
1716 | /// %write = vector.transfer_write %transpose, |
1717 | /// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0] |
1718 | /// {in_bounds = [true, true, true, true, true]} |
1719 | /// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32> |
1720 | /// |
1721 | /// If the (3) input vector sizes are not provided, the vector sizes are |
1722 | /// determined by the result tensor shape and the `in_bounds` |
1723 | /// attribute is used instead of masking to mark out-of-bounds accesses. |
1724 | /// |
1725 | /// NOTE: The input vector sizes specify the dimensions corresponding to the |
1726 | /// outer dimensions of the output tensor. The remaining dimensions are |
1727 | /// computed based on, e.g., the static inner tiles. |
1728 | /// Supporting dynamic inner tiles will require the user to specify the |
1729 | /// missing vector sizes. This is left as a TODO. |
1730 | static LogicalResult |
1731 | vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp, |
1732 | ArrayRef<int64_t> inputVectorSizes, |
1733 | SmallVectorImpl<Value> &newResults) { |
1734 | // TODO: Introduce a parent class that will handle the insertion point update. |
1735 | OpBuilder::InsertionGuard g(rewriter); |
1736 | rewriter.setInsertionPoint(packOp); |
1737 | |
1738 | Location loc = packOp.getLoc(); |
1739 | auto padValue = packOp.getPaddingValue(); |
1740 | if (!padValue) { |
1741 | padValue = rewriter.create<arith::ConstantOp>( |
1742 | loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType())); |
1743 | } |
1744 | ReifiedRankedShapedTypeDims reifiedReturnShapes; |
1745 | LogicalResult status = |
1746 | cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation()) |
1747 | .reifyResultShapes(rewriter, reifiedReturnShapes); |
1748 | (void)status; // prevent unused variable warning on non-assert builds. |
1749 | assert(succeeded(status) && "failed to reify result shapes"); |
1750 | |
1751 | // If the input vector sizes are not provided, then the vector sizes are |
1752 | // determined by the result tensor shape. In case the vector sizes aren't |
1753 | // provided, we update the inBounds attribute instead of masking. |
1754 | bool useInBoundsInsteadOfMasking = false; |
1755 | if (inputVectorSizes.empty()) { |
1756 | ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape(); |
1757 |     inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1758 | useInBoundsInsteadOfMasking = true; |
1759 | } |
1760 | |
1761 | // Create masked TransferReadOp. |
1762 | SmallVector<int64_t> inputShape(inputVectorSizes); |
1763 | auto innerTiles = packOp.getStaticInnerTiles(); |
1764 | auto innerDimsPos = packOp.getInnerDimsPos(); |
1765 | auto outerDimsPerm = packOp.getOuterDimsPerm(); |
1766 | if (!outerDimsPerm.empty()) |
1767 | applyPermutationToVector(inputShape, |
1768 | invertPermutationVector(outerDimsPerm)); |
1769 | for (auto [idx, size] : enumerate(innerTiles)) |
1770 | inputShape[innerDimsPos[idx]] *= size; |
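  // E.g., for the pack in the example above (values are illustrative):
  // inputShape starts as [32, 4, 1]; scaling positions 2 and 1 by the inner
  // tiles 16 and 2 yields the read shape [32, 8, 16], i.e. the (padded) source
  // shape.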
1771 | auto maskedRead = vector::createReadOrMaskedRead( |
1772 |       rewriter, loc, packOp.getSource(), inputShape, padValue,
1773 | useInBoundsInsteadOfMasking); |
1774 | |
1775 | // Create ShapeCastOp. |
1776 | SmallVector<int64_t> destShape(inputVectorSizes); |
1777 | destShape.append(innerTiles.begin(), innerTiles.end()); |
1778 | auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape), |
1779 | packOp.getDestType().getElementType()); |
1780 | auto shapeCastOp = |
1781 | rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead); |
1782 | |
1783 | // Create TransposeOp. |
1784 | auto destPermutation = |
1785 | invertPermutationVector(getPackInverseDestPerm(packOp)); |
1786 | auto transposeOp = rewriter.create<vector::TransposeOp>( |
1787 | loc, shapeCastOp.getResult(), destPermutation); |
1788 | |
1789 | // Create TransferWriteOp. |
1790 | Value dest = rewriter.create<tensor::EmptyOp>( |
1791 | loc, reifiedReturnShapes[0], |
1792 | transposeOp.getResult().getType().getElementType()); |
1793 | Operation *write = |
1794 | createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest); |
1795 |   newResults.push_back(write->getResult(0));
1796 | return success(); |
1797 | } |
1798 | |
1799 | /// Vectorize a `linalg::UnPackOp` to these 4 Ops: |
1800 | ///   vector::TransferReadOp - Reads a vector from the source tensor
1801 | ///   vector::TransposeOp - Transposes the read vector
1802 | ///   vector::ShapeCastOp - Reshapes the data based on the target shape
1803 | ///   vector::TransferWriteOp - Writes the result vector back to the
1804 | ///   destination tensor
1805 | /// If the vector sizes are not provided: |
1806 | /// * the vector sizes are determined by the input operand and attributes, |
1807 | /// * update the inBounds attribute instead of masking. |
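///
/// A rough sketch (types are illustrative only): unpacking
/// tensor<8x8x32x16xf32> (inner tiles [32, 16] on dims [0, 1]) back into
/// tensor<256x128xf32> becomes a transfer_read of the packed source, a
/// transpose that moves each inner tile next to its outer dim, a shape_cast
/// that collapses the pairs, and a transfer_write into a tensor.empty.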
1808 | static LogicalResult |
1809 | vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, |
1810 | ArrayRef<int64_t> inputVectorSizes, |
1811 | SmallVectorImpl<Value> &newResults) { |
1812 | |
1813 | // TODO: Introduce a parent class that will handle the insertion point update. |
1814 | OpBuilder::InsertionGuard g(rewriter); |
1815 | rewriter.setInsertionPoint(unpackOp); |
1816 | |
1817 | RankedTensorType unpackTensorType = unpackOp.getSourceType(); |
1818 | |
1819 | ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos(); |
1820 | ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles(); |
1821 | ArrayRef<int64_t> sourceShape = unpackTensorType.getShape(); |
1822 | bool useInBoundsInsteadOfMasking = false; |
1823 | ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm(); |
1824 | |
1825 | auto destSize = unpackOp.getDestRank(); |
1826 | |
1827 | if (!inputVectorSizes.empty()) |
1828 | assert(inputVectorSizes.size() == destSize && |
1829 | "Incorrect number of input vector sizes"); |
1830 | |
1831 | // vectorSizes is the shape of the vector that will be used to do final |
1832 | // write on the destination tensor. It is set like this: Let's say the |
1833 | // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M. |
1834 | // Thus: |
1835 | // 1. vectorSizes = sourceShape.take_front(N) |
1836 | // 2. if outer_dims_perms is present: do that permutation on vectorSizes. |
1837 | // 3. multiply all the locations in vectorSize pointed by innerDimPos by the |
1838 | // innerTiles attribute value. |
1839 | SmallVector<int64_t> vectorSizes(inputVectorSizes); |
1840 | if (vectorSizes.empty()) { |
1841 |     llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
1842 |     if (!outerDimsPerm.empty())
1843 |       applyPermutationToVector(vectorSizes, outerDimsPerm);
1844 | for (auto [i, pos] : llvm::enumerate(innerDimPos)) |
1845 | vectorSizes[pos] *= innerTiles[i]; |
1846 | |
1847 | useInBoundsInsteadOfMasking = true; |
1848 | } |
1849 | |
1850 | // readVectorSizes is the size of tensor used to read and apply mask. It is |
1851 | // set like this: Let's say the vectorSize (VS) array is size 'N' and |
1852 | // the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of |
1853 | // size M-N |
1854 | // Thus: |
1855 | // - initially: readVectorSizes = vectorInputSizes |
1856 | // - Divide all the readMaskShape locations pointed by innerDimPos |
1857 | // by the innerTileSize attribute value. |
1858 | // - if outer_dims_perms is present: do that permutation on readVectorSizes. |
1859 | // - Append the remaining shape from SS |
1860 |   // E.g. let's say unpackTensorType.getShape() = <8x8x32x16>,
1861 | // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512, |
1862 | // 128] and outer_dims_perm is [1, 0] then read shape is: |
1863 | // ReadVectorSizes(initial): [512, 128] |
1864 | // Final Value(after innerDim Adjustment): [512/32, 128/16] |
1865 | // = [16, 8] |
1866 | // After applying outer_dims_perm: [8, 16] |
1867 | // After appending the rest of the sourceShape: [8, 16, 32, 16] |
1868 | |
1869 | SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end()); |
1870 | |
1871 | for (auto [index, size] : enumerate(innerTiles)) { |
1872 | readVectorSizes[innerDimPos[index]] = |
1873 | llvm::divideCeil(readVectorSizes[innerDimPos[index]], size); |
1874 | } |
1875 | if (!outerDimsPerm.empty()) { |
1876 |     applyPermutationToVector(readVectorSizes, outerDimsPerm);
1877 | } |
1878 |   readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
1879 |                          sourceShape.end());
1880 | |
1881 | ReifiedRankedShapedTypeDims reifiedRetShapes; |
1882 | LogicalResult status = |
1883 | cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation()) |
1884 | .reifyResultShapes(rewriter, reifiedRetShapes); |
1885 | if (status.failed()) { |
1886 |     LDBG("Unable to reify result shapes of " << unpackOp);
1887 | return failure(); |
1888 | } |
1889 | Location loc = unpackOp->getLoc(); |
1890 | |
1891 | auto padValue = rewriter.create<arith::ConstantOp>( |
1892 | loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType())); |
1893 | |
1894 | // Read result, mask if necessary. If transferReadOp shape is not equal |
1895 | // to shape of source, then a mask is necessary. |
1896 | Value readResult = vector::createReadOrMaskedRead( |
1897 |       rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1898 | /*useInBoundsInsteadOfMasking=*/false); |
1899 | |
1900 | PackingMetadata packMetadata; |
1901 | SmallVector<int64_t> lastDimToInsertPosPerm = |
1902 | getUnPackInverseSrcPerm(unpackOp, packMetadata); |
1903 | ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType()); |
1904 | SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape()); |
1905 | mlir::Type stripMineElemType = maskedOpShapedType.getElementType(); |
1906 |   applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
1907 | RankedTensorType stripMineTensorType = |
1908 | RankedTensorType::get(stripMineShape, stripMineElemType); |
1909 | // Transpose the appropriate rows to match output. |
1910 | vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>( |
1911 | loc, readResult, lastDimToInsertPosPerm); |
1912 | |
1913 | // Collapse the vector to the size required by result. |
1914 | RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType( |
1915 | stripMineTensorType, packMetadata.reassociations); |
1916 | mlir::VectorType vecCollapsedType = |
1917 | VectorType::get(collapsedType.getShape(), collapsedType.getElementType()); |
1918 | vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>( |
1919 | loc, vecCollapsedType, transposeOp->getResult(0)); |
1920 | |
1921 | // writeVectorSizes had to match the shapecast shape for dynamic sizes, |
1922 | // otherwise the validator complains that the mask size is invalid. |
1923 | SmallVector<int64_t> writeVectorSizes( |
1924 | unpackOp.getDestType().hasStaticShape() |
1925 | ? vectorSizes |
1926 | : shapeCastOp.getResultVectorType().getShape()); |
1927 | Value dest = rewriter.create<tensor::EmptyOp>( |
1928 | loc, reifiedRetShapes[0], |
1929 | shapeCastOp.getResult().getType().getElementType()); |
1930 | Operation *write = createWriteOrMaskedWrite( |
1931 | rewriter, loc, shapeCastOp.getResult(), dest, |
1932 | /*writeIndices=*/{}, useInBoundsInsteadOfMasking); |
1933 |   newResults.push_back(write->getResult(0));
1934 | return success(); |
1935 | } |
1936 | |
1937 | /// Vectorize a `padOp` with (1) static result type, (2) constant padding value |
1938 | /// and (3) all-zero lowPad to |
1939 | /// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`. |
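///
/// For illustration (shapes are hypothetical): padding tensor<5xf32> to
/// tensor<8xf32> with %pad would become, roughly:
///
///   %read = vector.mask %m {
///     vector.transfer_read %src[%c0], %pad : tensor<5xf32>, vector<8xf32>
///   }
///   %write = vector.transfer_write %read, %empty[%c0]
///       {in_bounds = [true]} : vector<8xf32>, tensor<8xf32>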
1940 | static LogicalResult |
1941 | vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, |
1942 | ArrayRef<int64_t> inputVectorSizes, |
1943 | SmallVectorImpl<Value> &newResults) { |
1944 | auto padValue = padOp.getConstantPaddingValue(); |
1945 | Location loc = padOp.getLoc(); |
1946 | |
1947 | // TODO: Introduce a parent class that will handle the insertion point update. |
1948 | OpBuilder::InsertionGuard g(rewriter); |
1949 | rewriter.setInsertionPoint(padOp); |
1950 | |
1951 | ReifiedRankedShapedTypeDims reifiedReturnShapes; |
1952 | LogicalResult status = |
1953 | cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation()) |
1954 | .reifyResultShapes(rewriter, reifiedReturnShapes); |
1955 | (void)status; // prevent unused variable warning on non-assert builds |
1956 | assert(succeeded(status) && "failed to reify result shapes"); |
1957 | auto maskedRead = vector::createReadOrMaskedRead( |
1958 |       rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
1959 | /*useInBoundsInsteadOfMasking=*/false); |
1960 | |
1961 | // Create Xfer write Op |
1962 | Value dest = rewriter.create<tensor::EmptyOp>( |
1963 | loc, reifiedReturnShapes[0], padOp.getResultType().getElementType()); |
1964 | Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest); |
1965 |   newResults.push_back(write->getResult(0));
1966 | return success(); |
1967 | } |
1968 | |
1969 | // TODO: probably need some extra checks for reduction followed by consumer |
1970 | // ops that may not commute (e.g. linear reduction + non-linear instructions). |
1971 | static LogicalResult reductionPreconditions(LinalgOp op) { |
1972 | if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) { |
1973 | LDBG("reduction precondition failed: no reduction iterator\n"); |
1974 | return failure(); |
1975 | } |
1976 | for (OpOperand &opOperand : op.getDpsInitsMutable()) { |
1977 | AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand); |
1978 | if (indexingMap.isPermutation()) |
1979 | continue; |
1980 | |
1981 | Operation *reduceOp = matchLinalgReduction(&opOperand); |
1982 | if (!reduceOp || !getCombinerOpKind(reduceOp)) { |
1983 | LDBG("reduction precondition failed: reduction detection failed\n"); |
1984 | return failure(); |
1985 | } |
1986 | } |
1987 | return success(); |
1988 | } |
1989 | |
1990 | static LogicalResult |
1991 | vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, |
1992 | bool flatten1DDepthwiseConv) { |
1993 | if (flatten1DDepthwiseConv) { |
1994 | LDBG("Vectorization of flattened convs with dynamic shapes is not " |
1995 | "supported\n"); |
1996 | return failure(); |
1997 | } |
1998 | |
1999 | if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) { |
2000 | LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n"); |
2001 | return failure(); |
2002 | } |
2003 | |
2004 | // Support dynamic shapes in 1D depthwise convolution, but only in the |
2005 | // _channel_ dimension. |
2006 | Value lhs = conv.getDpsInputOperand(0)->get(); |
2007 | ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape(); |
2008 |   auto shapeWithoutCh = lhsShape.drop_back(1);
2009 | if (ShapedType::isDynamicShape(shapeWithoutCh)) { |
2010 | LDBG("Dynamically-shaped op vectorization precondition failed: only " |
2011 | "channel dim can be dynamic\n"); |
2012 | return failure(); |
2013 | } |
2014 | |
2015 | return success(); |
2016 | } |
2017 | |
2018 | static LogicalResult |
2019 | vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, |
2020 | bool flatten1DDepthwiseConv) { |
2021 | if (isa<ConvolutionOpInterface>(op.getOperation())) |
2022 | return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv); |
2023 | |
2024 | if (hasReductionIterator(op)) |
2025 | return reductionPreconditions(op); |
2026 | |
2027 | // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops, |
2028 | // linalg.copy ops and ops that implement ContractionOpInterface for now. |
2029 | if (!isElementwise(op) && |
2030 | !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>( |
2031 | op.getOperation())) |
2032 | return failure(); |
2033 | |
2034 | LDBG("Dynamically-shaped op meets vectorization pre-conditions\n"); |
2035 | return success(); |
2036 | } |
2037 | |
2038 | /// Need to check if the inner-tiles are static/constant. |
2039 | static LogicalResult |
2040 | vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp, |
2041 | ArrayRef<int64_t> inputVectorSizes) { |
2042 | |
2043 | if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) { |
2044 |         return !getConstantIntValue(res).has_value();
2045 |       })) {
2046 |     LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
2047 | return failure(); |
2048 | } |
2049 | ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape(); |
2050 | bool satisfyEmptyCond = inputVectorSizes.empty() && |
2051 | unpackOp.getDestType().hasStaticShape() && |
2052 | unpackOp.getSourceType().hasStaticShape(); |
2053 | if (!satisfyEmptyCond && |
2054 |       failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
2055 | return failure(); |
2056 | |
2057 | return success(); |
2058 | } |
2059 | |
2060 | static LogicalResult |
2061 | vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp, |
2062 | ArrayRef<int64_t> inputVectorSizes) { |
2063 | |
2064 | TypedValue<RankedTensorType> source = sliceOp.getSource(); |
2065 | auto sourceType = source.getType(); |
2066 | if (!VectorType::isValidElementType(sourceType.getElementType())) |
2067 | return failure(); |
2068 | |
2069 | // Get the pad value. |
2070 | // TransferReadOp (which is used to vectorize InsertSliceOp), requires a |
2071 | // scalar padding value. Note that: |
2072 | // * for in-bounds accesses, |
2073 | // the value is actually irrelevant. There are 2 cases in which xfer.read |
2074 | // accesses are known to be in-bounds: |
2075 | // 1. The source shape is static (output vector sizes would be based on |
2076 | // the source shape and hence all memory accesses would be in-bounds), |
2077 | // 2. Masking is used, i.e. the output vector sizes are user-provided. In |
2078 | // this case it is safe to assume that all memory accesses are in-bounds. |
2079 | // |
2080 | // When the value is not known and not needed, use 0. Otherwise, bail out. |
2081 | Value padValue = getStaticPadVal(sliceOp); |
2082 | bool isOutOfBoundsRead = |
2083 | !sourceType.hasStaticShape() && inputVectorSizes.empty(); |
2084 | |
2085 | if (!padValue && isOutOfBoundsRead) { |
2086 | LDBG("Failed to get a pad value for out-of-bounds read access\n"); |
2087 | return failure(); |
2088 | } |
2089 | return success(); |
2090 | } |
2091 | |
2092 | namespace { |
2093 | enum class ConvOperationKind { Conv, Pool }; |
2094 | } // namespace |
2095 | |
2096 | static bool isCastOfBlockArgument(Operation *op) { |
2097 | return isa<CastOpInterface>(op) && op->getNumOperands() == 1 && |
2098 | isa<BlockArgument>(op->getOperand(0)); |
2099 | } |
2100 | |
2101 | // Returns the ConvOperationKind of the op using reduceOp of the generic |
2102 | // payload. If it is neither a convolution nor a pooling, it returns |
2103 | // std::nullopt. |
2104 | // |
2105 | // If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction |
2106 | // + yield) and rhs is not used) then it is the body of a pooling |
2107 | // If conv, check for single `mul` predecessor. The `mul` operands must be |
2108 | // block arguments or extension of block arguments. |
2109 | // Otherwise, check for one or zero `ext` predecessor. The `ext` operands |
2110 | // must be block arguments or extension of block arguments. |
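//
// For illustration (ops are representative, not exhaustive): a body of the
// form
//   %m = arith.mulf %in, %filter : f32
//   %s = arith.addf %acc, %m : f32
//   linalg.yield %s : f32
// is classified as Conv (the addf has one block-argument operand and is fed by
// a mul), whereas
//   %s = arith.maximumf %acc, %in : f32
//   linalg.yield %s : f32
// has two block-argument operands and is classified as Pool.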
2111 | static std::optional<ConvOperationKind> |
2112 | getConvOperationKind(Operation *reduceOp) { |
2113 | int numBlockArguments = |
2114 |       llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
2115 | |
2116 | switch (numBlockArguments) { |
2117 | case 1: { |
2118 | // Will be convolution if feeder is a MulOp. |
2119 | // A strength reduced version of MulOp for i1 type is AndOp which is also |
2120 | // supported. Otherwise, it can be pooling. This strength reduction logic |
2121 | // is in `buildBinaryFn` helper in the Linalg dialect. |
2122 |     auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
2123 |                                        llvm::IsaPred<BlockArgument>);
2124 | assert(feedValIt != reduceOp->operand_end() && |
2125 | "Expected a non-block argument operand"); |
2126 | Operation *feedOp = (*feedValIt).getDefiningOp(); |
2127 |     if (isCastOfBlockArgument(feedOp)) {
2128 | return ConvOperationKind::Pool; |
2129 | } |
2130 | |
2131 | if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) || |
2132 | (isa<arith::AndIOp>(feedOp) && |
2133 | feedOp->getResultTypes()[0].isInteger(1))) && |
2134 | llvm::all_of(feedOp->getOperands(), [](Value v) { |
2135 | if (isa<BlockArgument>(v)) |
2136 | return true; |
2137 | if (Operation *op = v.getDefiningOp()) |
2138 | return isCastOfBlockArgument(op); |
2139 | return false; |
2140 | }))) { |
2141 | return std::nullopt; |
2142 | } |
2143 | |
2144 | return ConvOperationKind::Conv; |
2145 | } |
2146 | case 2: |
2147 | // Must be pooling |
2148 | return ConvOperationKind::Pool; |
2149 | default: |
2150 | return std::nullopt; |
2151 | } |
2152 | } |
2153 | |
2154 | static bool isSupportedPoolKind(vector::CombiningKind kind) { |
2155 | switch (kind) { |
2156 | case vector::CombiningKind::ADD: |
2157 | case vector::CombiningKind::MAXNUMF: |
2158 | case vector::CombiningKind::MAXIMUMF: |
2159 | case vector::CombiningKind::MAXSI: |
2160 | case vector::CombiningKind::MAXUI: |
2161 | case vector::CombiningKind::MINNUMF: |
2162 | case vector::CombiningKind::MINIMUMF: |
2163 | case vector::CombiningKind::MINSI: |
2164 | case vector::CombiningKind::MINUI: |
2165 | return true; |
2166 | default: |
2167 | return false; |
2168 | } |
2169 | } |
2170 | |
2171 | static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) { |
2172 | auto getOperandType = [&](auto operand) { |
2173 | return dyn_cast<ShapedType>((operand->get()).getType()); |
2174 | }; |
2175 | ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0)); |
2176 | ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1)); |
2177 | ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0)); |
2178 | // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR |
2179 | // (non-channeled convolution -> LHS and RHS both have single dimensions). |
2180 | // Note that this also ensures 2D and 3D convolutions are rejected. |
2181 | if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) && |
2182 | (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1)) |
2183 | return failure(); |
2184 | |
2185 | Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0)); |
2186 | if (!reduceOp) |
2187 | return failure(); |
2188 | |
2189 | auto maybeOper = getConvOperationKind(reduceOp); |
2190 | if (!maybeOper.has_value()) |
2191 | return failure(); |
2192 | |
2193 | auto maybeKind = getCombinerOpKind(reduceOp); |
2194 | // Typically convolution will have a `Add` CombiningKind but for i1 type it |
2195 | // can get strength reduced to `OR` which is also supported. This strength |
2196 | // reduction logic is in `buildBinaryFn` helper in the Linalg dialect. |
2197 | if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD && |
2198 | *maybeKind != vector::CombiningKind::OR) && |
2199 | (*maybeOper != ConvOperationKind::Pool || |
2200 | !isSupportedPoolKind(*maybeKind)))) { |
2201 | return failure(); |
2202 | } |
2203 | |
2204 | auto rhsRank = rhsShapedType.getRank(); |
2205 | if (*maybeOper == ConvOperationKind::Pool) { |
2206 | if (rhsRank != 1) |
2207 | return failure(); |
2208 | } else { |
2209 | if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3) |
2210 | return failure(); |
2211 | } |
2212 | |
2213 | return success(); |
2214 | } |
2215 | |
2216 | static LogicalResult vectorizeLinalgOpPrecondition( |
2217 | LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes, |
2218 | bool vectorizeNDExtract, bool flatten1DDepthwiseConv) { |
2219 | // tensor with dimension of 0 cannot be vectorized. |
2220 | if (llvm::is_contained(linalgOp.getStaticShape(), 0)) |
2221 | return failure(); |
2222 | // Check API contract for input vector sizes. |
2223 | if (!inputVectorSizes.empty() && |
2224 |       failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
2225 | inputVectorSizes))) |
2226 | return failure(); |
2227 | |
2228 | if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition( |
2229 | linalgOp, flatten1DDepthwiseConv))) { |
2230 | LDBG("Dynamically-shaped op failed vectorization pre-conditions\n"); |
2231 | return failure(); |
2232 | } |
2233 | |
2234 | SmallVector<CustomVectorizationPrecondition> customPreconditions; |
2235 | |
2236 | // Register CustomVectorizationPrecondition for extractOp. |
2237 |   customPreconditions.push_back(tensorExtractVectorizationPrecondition);
2238 | |
2239 | // All types in the body should be a supported element type for VectorType. |
2240 | for (Operation &innerOp : linalgOp->getRegion(0).front()) { |
2241 | // Check if any custom hook can vectorize the inner op. |
2242 | if (llvm::any_of( |
2243 | customPreconditions, |
2244 | [&](const CustomVectorizationPrecondition &customPrecondition) { |
2245 | return succeeded( |
2246 | customPrecondition(&innerOp, vectorizeNDExtract)); |
2247 | })) { |
2248 | continue; |
2249 | } |
2250 | if (!llvm::all_of(innerOp.getOperandTypes(), |
2251 | VectorType::isValidElementType)) { |
2252 | return failure(); |
2253 | } |
2254 | if (!llvm::all_of(innerOp.getResultTypes(), |
2255 | VectorType::isValidElementType)) { |
2256 | return failure(); |
2257 | } |
2258 | } |
2259 | if (isElementwise(linalgOp)) |
2260 | return success(); |
2261 | |
2262 | // TODO: isaConvolutionOpInterface that can also infer from generic |
2263 | // features. But we will still need stride/dilation attributes that will be |
2264 | // annoying to reverse-engineer... |
2265 | if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) |
2266 | return vectorizeConvOpPrecondition(linalgOp); |
2267 | |
2268 | // TODO: the common vector shape is equal to the static loop sizes only when |
2269 | // all indexing maps are projected permutations. For convs and stencils the |
2270 | // logic will need to evolve. |
2271 | if (!allIndexingsAreProjectedPermutation(linalgOp)) { |
2272 | LDBG("precondition failed: not projected permutations\n"); |
2273 | return failure(); |
2274 | } |
2275 | if (failed(reductionPreconditions(linalgOp))) { |
2276 | LDBG("precondition failed: reduction preconditions\n"); |
2277 | return failure(); |
2278 | } |
2279 | return success(); |
2280 | } |
2281 | |
2282 | static LogicalResult |
2283 | vectorizePackOpPrecondition(linalg::PackOp packOp, |
2284 | ArrayRef<int64_t> inputVectorSizes) { |
2285 | auto padValue = packOp.getPaddingValue(); |
2286 | Attribute cstAttr; |
2287 |   if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
2288 |     LDBG("pad value is not constant: " << packOp << "\n");
2289 | return failure(); |
2290 | } |
2291 | ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape(); |
2292 | bool satisfyEmptyCond = true; |
2293 | if (inputVectorSizes.empty()) { |
2294 | if (!packOp.getDestType().hasStaticShape() || |
2295 | !packOp.getSourceType().hasStaticShape()) |
2296 | satisfyEmptyCond = false; |
2297 | } |
2298 | |
2299 | if (!satisfyEmptyCond && |
2300 | failed(vector::isValidMaskedInputVector( |
2301 |           resultTensorShape.take_front(packOp.getSourceRank()),
2302 | inputVectorSizes))) |
2303 | return failure(); |
2304 | |
2305 | if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) { |
2306 |         return !getConstantIntValue(v).has_value();
2307 |       })) {
2308 |     LDBG("inner_tiles must be constant: " << packOp << "\n");
2309 | return failure(); |
2310 | } |
2311 | |
2312 | return success(); |
2313 | } |
2314 | |
2315 | static LogicalResult |
2316 | vectorizePadOpPrecondition(tensor::PadOp padOp, |
2317 | ArrayRef<int64_t> inputVectorSizes) { |
2318 | auto padValue = padOp.getConstantPaddingValue(); |
2319 | if (!padValue) { |
2320 |     LDBG("pad value is not constant: " << padOp << "\n");
2321 | return failure(); |
2322 | } |
2323 | |
2324 | ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape(); |
2325 |   if (failed(vector::isValidMaskedInputVector(resultTensorShape,
2326 | inputVectorSizes))) |
2327 | return failure(); |
2328 | |
2329 | // Padding with non-zero low pad values is not supported, unless the |
2330 | // corresponding result dim is 1 as this would require shifting the results to |
2331 | // the right for the low padded dims by the required amount of low padding. |
2332 | // However, we do support low padding if the dims being low padded have result |
2333 | // sizes of 1. The reason is when we have a low pad on a unit result dim, the |
2334 | // input size of that dimension will be dynamically zero (as the sum of the |
2335 | // low pad and input dim size has to be one) and hence we will create a zero |
2336 | // mask as the lowering logic just makes the mask one for the input dim size - |
2337 | // which is zero here. Hence we will load the pad value which is what we want |
2338 | // in this case. If the low pad is dynamically zero then the lowering is |
2339 | // correct as well as no shifts are necessary. |
2340 | if (llvm::any_of(llvm::enumerate(padOp.getLow()), [&](const auto &en) { |
2341 | Value padValue = en.value(); |
2342 | unsigned pos = en.index(); |
2343 |         std::optional<int64_t> pad = getConstantIntValue(padValue);
2344 | return (!pad.has_value() || pad.value() != 0) && |
2345 | resultTensorShape[pos] != 1; |
2346 | })) { |
2347 |     LDBG("low pad must all be zero for all non-unit dims: " << padOp << "\n");
2348 | return failure(); |
2349 | } |
2350 | |
2351 | return success(); |
2352 | } |
2353 | |
2354 | /// Preconditions for scalable vectors. This is quite restrictive - it models |
2355 | /// the fact that in practice we would only make selected dimensions scalable. |
2356 | static LogicalResult |
2357 | vectorizeScalableVectorPrecondition(Operation *op, |
2358 | ArrayRef<int64_t> inputVectorSizes, |
2359 | ArrayRef<bool> inputScalableVecDims) { |
2360 | assert(inputVectorSizes.size() == inputScalableVecDims.size() && |
2361 | "Number of input vector sizes and scalable dims doesn't match"); |
2362 | |
2363 | size_t numOfScalableDims = |
2364 |       llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
2365 | |
2366 | if (numOfScalableDims == 0) |
2367 | return success(); |
2368 | |
2369 | auto linalgOp = dyn_cast<LinalgOp>(op); |
2370 | |
2371 | // Cond 1: There's been no need for scalable vectorisation of |
2372 | // non-linalg Ops so far |
2373 | if (!linalgOp) |
2374 | return failure(); |
2375 | |
2376 | // Cond 2: There's been no need for more than 2 scalable dims so far |
2377 | if (numOfScalableDims > 2) |
2378 | return failure(); |
2379 | |
2380 | // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that |
2381 | // it matches one of the supported cases: |
2382 | // 1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim |
2383 | // (*). |
2384 | // 2. Exactly 2 dims are scalable and those are the _last two adjacent_ |
2385 | // parallel dims. |
2386 | // 3. Exactly 1 reduction dim is scalable and that's the last (innermost) |
2387 | // dim. |
2388 | // The 2nd restriction above means that only Matmul-like Ops are supported |
2389 | // when 2 dims are scalable, e.g. : |
2390 | // * iterators = [parallel, parallel, reduction] |
2391 | // * scalable flags = [true, true, false] |
2392 | // |
2393 |   //    (*) Unit dims get folded away in practice.
2394 | // TODO: Relax these conditions as good motivating examples are identified. |
2395 | |
2396 | // Find the first scalable flag. |
2397 | bool seenNonUnitParallel = false; |
2398 | auto iterators = linalgOp.getIteratorTypesArray(); |
2399 | SmallVector<bool> scalableFlags(inputScalableVecDims); |
2400 | int64_t idx = scalableFlags.size() - 1; |
2401 | while (!scalableFlags[idx]) { |
2402 | bool isNonUnitDim = (inputVectorSizes[idx] != 1); |
2403 | seenNonUnitParallel |= |
2404 | (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim); |
2405 | |
2406 | iterators.pop_back(); |
2407 | scalableFlags.pop_back(); |
2408 | --idx; |
2409 | } |
2410 | |
2411 | // Analyze the iterator corresponding to the first scalable dim. |
2412 | switch (iterators.back()) { |
2413 | case utils::IteratorType::reduction: { |
2414 | // Check 3. above is met. |
2415 | if (iterators.size() != inputVectorSizes.size()) { |
2416 | LDBG("Non-trailing reduction dim requested for scalable " |
2417 | "vectorization\n"); |
2418 | return failure(); |
2419 | } |
2420 | if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) { |
2421 | LDBG("Scalable vectorization of the reduction dim in Matmul-like ops " |
2422 | "is not supported\n"); |
2423 | return failure(); |
2424 | } |
2425 | break; |
2426 | } |
2427 | case utils::IteratorType::parallel: { |
2428 | // Check 1. and 2. above are met. |
2429 | if (seenNonUnitParallel) { |
2430 | LDBG("Inner parallel dim not requested for scalable " |
2431 | "vectorization\n"); |
2432 | return failure(); |
2433 | } |
2434 | break; |
2435 | } |
2436 | } |
2437 | |
2438 | // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
2439 | // supported, for which we expect the following config:
2440 | // * iterators = [parallel, parallel, reduction] |
2441 | // * scalable flags = [true, true, false] |
2442 | if (numOfScalableDims == 2) { |
2443 | // Disallow the case below, which breaks 3. above:
2444 | // * iterators = [..., parallel, reduction] |
2445 | // * scalable flags = [..., true, true] |
2446 | if (iterators.back() == utils::IteratorType::reduction) { |
2447 | LDBG("Higher dim than the trailing reduction dim requested for scalable " |
2448 | "vectorization\n"); |
2449 | return failure(); |
2450 | } |
2451 | scalableFlags.pop_back(); |
2452 | iterators.pop_back(); |
2453 | |
2454 | if (!scalableFlags.back() || |
2455 | (iterators.back() != utils::IteratorType::parallel)) |
2456 | return failure(); |
2457 | } |
2458 | |
2459 | // Don't let matmuls with extended semantics (i.e. user-defined indexing
2460 | // maps) through this transform.
2461 | if (linalgOp.hasUserDefinedMaps()) |
2462 | return failure(); |
2463 | |
2464 | // Cond 4: Only the following ops are supported in the |
2465 | // presence of scalable vectors |
2466 | return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) || |
2467 | isa<linalg::MatmulTransposeAOp>(op) || |
2468 | isa<linalg::DepthwiseConv1DNwcWcOp>(op) || |
2469 | isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp)); |
2470 | } |
2471 | |
2472 | LogicalResult mlir::linalg::vectorizeOpPrecondition( |
2473 | Operation *op, ArrayRef<int64_t> inputVectorSizes, |
2474 | ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract, |
2475 | bool flatten1DDepthwiseConv) { |
2476 | |
2477 | if (!hasVectorizationImpl(op)) |
2478 | return failure(); |
2479 | |
2480 | if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2481 | inputScalableVecDims))) |
2482 | return failure(); |
2483 | |
2484 | return TypeSwitch<Operation *, LogicalResult>(op) |
2485 | .Case<linalg::LinalgOp>([&](auto linalgOp) { |
2486 | return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes, |
2487 | vectorizeNDExtract, |
2488 | flatten1DDepthwiseConv); |
2489 | }) |
2490 | .Case<tensor::PadOp>([&](auto padOp) { |
2491 | return vectorizePadOpPrecondition(padOp, inputVectorSizes); |
2492 | }) |
2493 | .Case<linalg::PackOp>([&](auto packOp) { |
2494 | return vectorizePackOpPrecondition(packOp, inputVectorSizes); |
2495 | }) |
2496 | .Case<linalg::UnPackOp>([&](auto unpackOp) { |
2497 | return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes); |
2498 | }) |
2499 | .Case<tensor::InsertSliceOp>([&](auto sliceOp) { |
2500 | return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes); |
2501 | }) |
2502 | .Default([](auto) { return failure(); }); |
2503 | } |
2504 | |
2505 | /// Converts affine.apply Ops to arithmetic operations. |
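///
/// For example (illustrative), with the affine map `(d0) -> (d0 + 5)`:
///   %0 = affine.apply affine_map<(d0) -> (d0 + 5)>(%i)
/// is expanded to roughly:
///   %c5 = arith.constant 5 : index
///   %0 = arith.addi %i, %c5 : index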
2506 | static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) { |
2507 | OpBuilder::InsertionGuard g(rewriter); |
2508 | auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>(); |
2509 | |
2510 | for (auto op : make_early_inc_range(toReplace)) { |
2511 | rewriter.setInsertionPoint(op); |
2512 | auto expanded = affine::expandAffineExpr( |
2513 | rewriter, op->getLoc(), op.getAffineMap().getResult(0), |
2514 | op.getOperands().take_front(op.getAffineMap().getNumDims()), |
2515 | op.getOperands().take_back(op.getAffineMap().getNumSymbols())); |
2516 | rewriter.replaceOp(op, expanded); |
2517 | } |
2518 | } |
2519 | |
2520 | bool mlir::linalg::hasVectorizationImpl(Operation *op) { |
2521 | return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp, |
2522 | tensor::InsertSliceOp>(op); |
2523 | } |
2524 | |
2525 | /// Emit a suitable vector form for an operation. If provided, |
2526 | /// `inputVectorSizes` are used to vectorize this operation. |
2527 | /// `inputVectorSizes` must match the rank of the iteration space of the |
2528 | /// operation and the input vector sizes must be greater than or equal to |
2529 | /// their counterpart iteration space sizes, if static. `inputVectorSizes`
2530 | /// also allows the vectorization of operations with dynamic shapes.
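///
/// A minimal usage sketch (illustrative, with all flags spelled out; the
/// concrete vector sizes and op are placeholders):
///   // Vectorize `linalgOp` with fixed vector sizes 8x16 and no scalable dims.
///   if (failed(linalg::vectorize(rewriter, linalgOp,
///                                /*inputVectorSizes=*/{8, 16},
///                                /*inputScalableVecDims=*/{false, false},
///                                /*vectorizeNDExtract=*/false,
///                                /*flatten1DDepthwiseConv=*/false)))
///     return failure();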
2531 | LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op, |
2532 | ArrayRef<int64_t> inputVectorSizes, |
2533 | ArrayRef<bool> inputScalableVecDims, |
2534 | bool vectorizeNDExtract, |
2535 | bool flatten1DDepthwiseConv) { |
2536 | LDBG("Attempting to vectorize:\n"<< *op << "\n"); |
2537 | LDBG("Input vector sizes: "); |
2538 | LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs())); |
2539 | LLVM_DEBUG(llvm::dbgs() << "\n"); |
2540 | LDBG("Input scalable vector dims: "); |
2541 | LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs())); |
2542 | LLVM_DEBUG(llvm::dbgs() << "\n"); |
2543 | |
2544 | if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
2545 | vectorizeNDExtract, |
2546 | flatten1DDepthwiseConv))) { |
2547 | LDBG("Vectorization pre-conditions failed\n"); |
2548 | return failure(); |
2549 | } |
2550 | |
2551 | // Initialize vectorization state. |
2552 | VectorizationState state(rewriter); |
2553 | if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) { |
2554 | if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2555 | inputScalableVecDims))) { |
2556 | LDBG("Vectorization state couldn't be initialized\n"); |
2557 | return failure(); |
2558 | } |
2559 | } |
2560 | |
2561 | SmallVector<Value> results; |
2562 | auto vectorizeResult = |
2563 | TypeSwitch<Operation *, LogicalResult>(op) |
2564 | .Case<linalg::LinalgOp>([&](auto linalgOp) { |
2565 | // TODO: isaConvolutionOpInterface that can also infer from |
2566 | // generic features. Will require stride/dilation attributes |
2567 | // inference. |
2568 | if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) { |
2569 | FailureOr<Operation *> convOr = vectorizeConvolution( |
2570 | rewriter, linalgOp, inputVectorSizes, inputScalableVecDims, |
2571 | flatten1DDepthwiseConv); |
2572 | if (succeeded(convOr)) { |
2573 | llvm::append_range(results, (*convOr)->getResults()); |
2574 | return success(); |
2575 | } |
2576 | |
2577 | LDBG("Unsupported convolution can't be vectorized.\n"); |
2578 | return failure(); |
2579 | } |
2580 | |
2581 | LDBG("Vectorize generic by broadcasting to the canonical vector " |
2582 | "shape\n"); |
2583 | |
2584 | // Pre-process before proceeding. |
2585 | convertAffineApply(rewriter, linalgOp); |
2586 | |
2587 | // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted |
2588 | // to 'OpBuilder' when it is passed over to some methods like |
2589 | // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we |
2590 | // erase an op within these methods, the actual rewriter won't be |
2591 | // notified and we will end up with read-after-free issues! |
2592 | return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results); |
2593 | }) |
2594 | .Case<tensor::PadOp>([&](auto padOp) { |
2595 | return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes, |
2596 | results); |
2597 | }) |
2598 | .Case<linalg::PackOp>([&](auto packOp) { |
2599 | return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes, |
2600 | results); |
2601 | }) |
2602 | .Case<linalg::UnPackOp>([&](auto unpackOp) { |
2603 | return vectorizeAsTensorUnpackOp(rewriter, unpackOp, |
2604 | inputVectorSizes, results); |
2605 | }) |
2606 | .Case<tensor::InsertSliceOp>([&](auto sliceOp) { |
2607 | return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes, |
2608 | results); |
2609 | }) |
2610 | .Default([](auto) { return failure(); }); |
2611 | |
2612 | if (failed(vectorizeResult)) { |
2613 | LDBG("Vectorization failed\n"); |
2614 | return failure(); |
2615 | } |
2616 | |
2617 | if (!results.empty()) |
2618 | rewriter.replaceOp(op, results);
2619 | else |
2620 | rewriter.eraseOp(op); |
2621 | |
2622 | return success(); |
2623 | } |
2624 | |
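/// Vectorize a `memref.copy` between statically-shaped memrefs as a full-size
/// transfer_read + transfer_write pair. A minimal sketch of the emitted IR
/// (illustrative shapes and padding value):
///   %v = vector.transfer_read %src[%c0, %c0], %pad
///        : memref<4x8xf32>, vector<4x8xf32>
///   vector.transfer_write %v, %dst[%c0, %c0]
///        : vector<4x8xf32>, memref<4x8xf32>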
2625 | LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter, |
2626 | memref::CopyOp copyOp) { |
2627 | auto srcType = cast<MemRefType>(copyOp.getSource().getType()); |
2628 | auto dstType = cast<MemRefType>(copyOp.getTarget().getType()); |
2629 | if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) |
2630 | return failure(); |
2631 | |
2632 | auto srcElementType = getElementTypeOrSelf(srcType); |
2633 | auto dstElementType = getElementTypeOrSelf(dstType); |
2634 | if (!VectorType::isValidElementType(srcElementType) || |
2635 | !VectorType::isValidElementType(dstElementType)) |
2636 | return failure(); |
2637 | |
2638 | auto readType = VectorType::get(srcType.getShape(), srcElementType); |
2639 | auto writeType = VectorType::get(dstType.getShape(), dstElementType); |
2640 | |
2641 | Location loc = copyOp->getLoc(); |
2642 | Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2643 | SmallVector<Value> indices(srcType.getRank(), zero); |
2644 | |
2645 | Value readValue = rewriter.create<vector::TransferReadOp>( |
2646 | loc, readType, copyOp.getSource(), indices, |
2647 | rewriter.getMultiDimIdentityMap(srcType.getRank()));
2648 | if (cast<VectorType>(readValue.getType()).getRank() == 0) { |
2649 | readValue = |
2650 | rewriter.create<vector::ExtractOp>(loc, readValue, ArrayRef<int64_t>()); |
2651 | readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue); |
2652 | } |
2653 | Operation *writeValue = rewriter.create<vector::TransferWriteOp>( |
2654 | loc, readValue, copyOp.getTarget(), indices, |
2655 | rewriter.getMultiDimIdentityMap(srcType.getRank()));
2656 | rewriter.replaceOp(copyOp, writeValue->getResults()); |
2657 | return success(); |
2658 | } |
2659 | |
2660 | //----------------------------------------------------------------------------// |
2661 | // Misc. vectorization patterns. |
2662 | //----------------------------------------------------------------------------// |
2663 | /// Base pattern for rewriting tensor::PadOps whose result is consumed by a |
2664 | /// given operation type OpTy. |
2665 | template <typename OpTy> |
2666 | struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> { |
2667 | using OpRewritePattern<tensor::PadOp>::OpRewritePattern; |
2668 | |
2669 | LogicalResult matchAndRewrite(tensor::PadOp padOp, |
2670 | PatternRewriter &rewriter) const final { |
2671 | bool changed = false; |
2672 | // Insert users in vector, because some users may be replaced/removed. |
2673 | for (auto *user : llvm::to_vector<4>(padOp->getUsers())) |
2674 | if (auto op = dyn_cast<OpTy>(user)) |
2675 | changed |= rewriteUser(rewriter, padOp, op).succeeded(); |
2676 | return success(changed);
2677 | } |
2678 | |
2679 | protected: |
2680 | virtual LogicalResult rewriteUser(PatternRewriter &rewriter, |
2681 | tensor::PadOp padOp, OpTy op) const = 0; |
2682 | }; |
2683 | |
2684 | /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.: |
2685 | /// ``` |
2686 | /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32> |
2687 | /// %r = vector.transfer_read %0[%c0, %c0], %cst |
2688 | /// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32> |
2689 | /// ``` |
2690 | /// is rewritten to: |
2691 | /// ``` |
2692 | /// %r = vector.transfer_read %src[%c0, %c0], %padding |
2693 | /// {in_bounds = [true, true]} |
2694 | /// : tensor<?x?xf32>, vector<17x5xf32> |
2695 | /// ``` |
2696 | /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be |
2697 | /// sure that the original padding value %cst was never used. |
2698 | /// |
2699 | /// This rewrite is possible if: |
2700 | /// - `xferOp` has no out-of-bounds dims or mask. |
2701 | /// - Low padding is static 0. |
2702 | /// - Single, scalar padding value. |
2703 | struct PadOpVectorizationWithTransferReadPattern |
2704 | : public VectorizePadOpUserPattern<vector::TransferReadOp> { |
2705 | using VectorizePadOpUserPattern< |
2706 | vector::TransferReadOp>::VectorizePadOpUserPattern; |
2707 | |
2708 | LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, |
2709 | vector::TransferReadOp xferOp) const override { |
2710 | // Low padding must be static 0. |
2711 | if (!padOp.hasZeroLowPad()) |
2712 | return failure(); |
2713 | // Pad value must be a constant. |
2714 | auto padValue = padOp.getConstantPaddingValue(); |
2715 | if (!padValue) |
2716 | return failure(); |
2717 | // Padding value of existing `xferOp` is unused. |
2718 | if (xferOp.hasOutOfBoundsDim() || xferOp.getMask()) |
2719 | return failure(); |
2720 | |
2721 | rewriter.modifyOpInPlace(xferOp, [&]() { |
2722 | SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false); |
2723 | xferOp->setAttr(xferOp.getInBoundsAttrName(), |
2724 | rewriter.getBoolArrayAttr(inBounds)); |
2725 | xferOp.getBaseMutable().assign(padOp.getSource()); |
2726 | xferOp.getPaddingMutable().assign(padValue); |
2727 | }); |
2728 | |
2729 | return success(); |
2730 | } |
2731 | }; |
2732 | |
2733 | /// Rewrite use of tensor::PadOp result in TransferWriteOp. |
2734 | /// This pattern rewrites TransferWriteOps that write to a padded tensor |
2735 | /// value, where the same amount of padding is immediately removed again after |
2736 | /// the write. In such cases, the TransferWriteOp can write to the non-padded |
2737 | /// tensor value and apply out-of-bounds masking. E.g.: |
2738 | /// ``` |
2739 | /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1] |
2740 | /// : tensor<...> to tensor<?x?xf32> |
2741 | /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32> |
2742 | /// %2 = vector.transfer_write %vec, %1[...] |
2743 | /// : vector<17x5xf32>, tensor<17x5xf32> |
2744 | /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1] |
2745 | /// : tensor<17x5xf32> to tensor<?x?xf32> |
2746 | /// ``` |
2747 | /// is rewritten to: |
2748 | /// ``` |
2749 | /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1] |
2750 | /// : tensor<...> to tensor<?x?xf32> |
2751 | /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>, |
2752 | /// tensor<?x?xf32> |
2753 | /// ``` |
2754 | /// Note: It is important that the ExtractSliceOp %r resizes the result of the |
2755 | /// TransferWriteOp to the same size as the input of the TensorPadOp (or an |
2756 | /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ |
2757 | /// from %r's old dimensions. |
2758 | /// |
2759 | /// This rewrite is possible if: |
2760 | /// - Low padding is static 0. |
2761 | /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This |
2762 | /// ExtractSliceOp trims the same amount of padding that was added |
2763 | /// beforehand. |
2764 | /// - Single, scalar padding value. |
2765 | struct PadOpVectorizationWithTransferWritePattern |
2766 | : public VectorizePadOpUserPattern<vector::TransferWriteOp> { |
2767 | using VectorizePadOpUserPattern< |
2768 | vector::TransferWriteOp>::VectorizePadOpUserPattern; |
2769 | |
2770 | LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, |
2771 | vector::TransferWriteOp xferOp) const override { |
2772 | // TODO: support 0-d corner case. |
2773 | if (xferOp.getTransferRank() == 0) |
2774 | return failure(); |
2775 | |
2776 | // Low padding must be static 0. |
2777 | if (!padOp.hasZeroLowPad()) |
2778 | return failure(); |
2779 | // Pad value must be a constant. |
2780 | auto padValue = padOp.getConstantPaddingValue(); |
2781 | if (!padValue) |
2782 | return failure(); |
2783 | // TransferWriteOp result must be directly consumed by an ExtractSliceOp. |
2784 | if (!xferOp->hasOneUse()) |
2785 | return failure(); |
2786 | auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin()); |
2787 | if (!trimPadding) |
2788 | return failure(); |
2789 | // Only static zero offsets supported when trimming padding. |
2790 | if (!trimPadding.hasZeroOffset()) |
2791 | return failure(); |
2792 | // trimPadding must remove the amount of padding that was added earlier. |
2793 | if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2794 | return failure(); |
2795 | |
2796 | // Insert the new TransferWriteOp at position of the old TransferWriteOp. |
2797 | rewriter.setInsertionPoint(xferOp); |
2798 | |
2799 | SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false); |
2800 | auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( |
2801 | xferOp, padOp.getSource().getType(), xferOp.getVector(), |
2802 | padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(), |
2803 | xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds)); |
2804 | rewriter.replaceOp(trimPadding, newXferOp->getResult(0)); |
2805 | |
2806 | return success(); |
2807 | } |
2808 | |
2809 | /// Check if `beforePadding` and `afterTrimming` have the same tensor size, |
2810 | /// i.e., same dimensions. |
2811 | /// |
2812 | /// Dimensions may be static, dynamic or mix of both. In case of dynamic |
2813 | /// dimensions, this function tries to infer the (static) tensor size by |
2814 | /// looking at the defining op and utilizing op-specific knowledge. |
2815 | /// |
2816 | /// This is a conservative analysis. In case equal tensor sizes cannot be |
2817 | /// proven statically, this analysis returns `false` even though the tensor |
2818 | /// sizes may turn out to be equal at runtime. |
2819 | bool hasSameTensorSize(Value beforePadding, |
2820 | tensor::ExtractSliceOp afterTrimming) const { |
2821 | // If the input to tensor::PadOp is a CastOp, try with both CastOp |
2822 | // result and CastOp operand. |
2823 | if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>()) |
2824 | if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2825 | return true; |
2826 | |
2827 | auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType()); |
2828 | auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType()); |
2829 | // Only RankedTensorType supported. |
2830 | if (!t1 || !t2) |
2831 | return false; |
2832 | // Rank of both values must be the same. |
2833 | if (t1.getRank() != t2.getRank()) |
2834 | return false; |
2835 | |
2836 | // All static dimensions must be the same. Mixed cases (e.g., dimension |
2837 | // static in `t1` but dynamic in `t2`) are not supported. |
2838 | for (unsigned i = 0; i < t1.getRank(); ++i) { |
2839 | if (t1.isDynamicDim(i) != t2.isDynamicDim(i)) |
2840 | return false; |
2841 | if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i)) |
2842 | return false; |
2843 | } |
2844 | |
2845 | // Nothing more to check if all dimensions are static. |
2846 | if (t1.getNumDynamicDims() == 0) |
2847 | return true; |
2848 | |
2849 | // All dynamic sizes must be the same. The only supported case at the |
2850 | // moment is when `beforePadding` is an ExtractSliceOp (or a cast |
2851 | // thereof). |
2852 | |
2853 | // Apart from CastOp, only ExtractSliceOp is supported. |
2854 | auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>(); |
2855 | if (!beforeSlice) |
2856 | return false; |
2857 | |
2858 | assert(static_cast<size_t>(t1.getRank()) == |
2859 | beforeSlice.getMixedSizes().size()); |
2860 | assert(static_cast<size_t>(t2.getRank()) == |
2861 | afterTrimming.getMixedSizes().size()); |
2862 | |
2863 | for (unsigned i = 0; i < t1.getRank(); ++i) { |
2864 | // Skip static dimensions. |
2865 | if (!t1.isDynamicDim(i)) |
2866 | continue; |
2867 | auto size1 = beforeSlice.getMixedSizes()[i]; |
2868 | auto size2 = afterTrimming.getMixedSizes()[i]; |
2869 | |
2870 | // Case 1: Same value or same constant int. |
2871 | if (isEqualConstantIntOrValue(size1, size2)) |
2872 | continue; |
2873 | |
2874 | // Other cases: Take a deeper look at defining ops of values. |
2875 | auto v1 = llvm::dyn_cast_if_present<Value>(size1); |
2876 | auto v2 = llvm::dyn_cast_if_present<Value>(size2); |
2877 | if (!v1 || !v2) |
2878 | return false; |
2879 | |
2880 | // Case 2: Both values are identical AffineMinOps. (Should not happen if |
2881 | // CSE is run.) |
2882 | auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>(); |
2883 | auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>(); |
2884 | if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() && |
2885 | minOp1.getOperands() == minOp2.getOperands()) |
2886 | continue; |
2887 | |
2888 | // Add additional cases as needed. |
2889 | } |
2890 | |
2891 | // All tests passed. |
2892 | return true; |
2893 | } |
2894 | }; |
2895 | |
2896 | /// Returns the effective Pad value for the input op, provided it's a scalar. |
2897 | /// |
2898 | /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If |
2899 | /// this Op performs padding, retrieve the padding value provided that it's |
2900 | /// a scalar and static/fixed for all the padded values. Returns an empty value |
2901 | /// otherwise. |
2902 | /// |
2903 | /// TODO: This is used twice (when checking vectorization pre-conditions and |
2904 | /// when vectorizing). Cache results instead of re-running. |
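///
/// For example (illustrative), given:
///   %fill = linalg.fill ins(%pad : f32) outs(%t : tensor<8x16xf32>) -> tensor<8x16xf32>
/// calling this on the defining op of %fill returns %pad.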
2905 | static Value getStaticPadVal(Operation *op) { |
2906 | if (!op) |
2907 | return {}; |
2908 | |
2909 | // 1. vector.broadcast (f32 -> vector <...xf32>) - return the value that's |
2910 | // being broadcast, provided that it's a scalar. |
2911 | if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) { |
2912 | auto source = bcast.getSource(); |
2913 | if (llvm::dyn_cast<VectorType>(source.getType())) |
2914 | return {}; |
2915 | |
2916 | return source; |
2917 | } |
2918 | |
2919 | // 2. linalg.fill - use the scalar input value that was used to fill the
2920 | // output tensor.
2921 | if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) { |
2922 | return fill.getInputs()[0]; |
2923 | } |
2924 | |
2925 | // 3. tensor.generate - can't guarantee the value is fixed without analysing
2926 | // the body, bail out.
2927 | if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) { |
2928 | return {}; |
2929 | } |
2930 | |
2931 | // 4. vector.transfer_write - inspect the input vector that's written from. If
2932 | // it contains a single value that has been broadcast (e.g. via
2933 | // vector.broadcast), extract it, fail otherwise.
2934 | if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op)) |
2935 | return getStaticPadVal(xferWrite.getVector().getDefiningOp()); |
2936 | |
2937 | // 5. tensor.insert_slice - inspect the destination tensor. If it's larger |
2938 | // than the input tensor, then, provided it's constant, we'll extract the |
2939 | // value that was used to generate it (via e.g. linalg.fill), fail otherwise. |
2940 | // TODO: Clarify the semantics when the input tensor is larger than the |
2941 | // destination. |
2942 | if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op)) |
2943 | return getStaticPadVal(slice.getDest().getDefiningOp()); |
2944 | |
2945 | return {}; |
2946 | } |
2947 | |
2948 | static LogicalResult |
2949 | vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp, |
2950 | ArrayRef<int64_t> inputVectorSizes, |
2951 | SmallVectorImpl<Value> &newResults) { |
2952 | // TODO: Introduce a parent class that will handle the insertion point update. |
2953 | OpBuilder::InsertionGuard g(rewriter); |
2954 | rewriter.setInsertionPoint(sliceOp); |
2955 | |
2956 | TypedValue<RankedTensorType> source = sliceOp.getSource(); |
2957 | auto sourceType = source.getType(); |
2958 | auto resultType = sliceOp.getResultType(); |
2959 | |
2960 | Value padValue = getStaticPadVal(sliceOp); |
2961 | |
2962 | if (!padValue) { |
2963 | auto elemType = sourceType.getElementType(); |
2964 | padValue = rewriter.create<arith::ConstantOp>( |
2965 | sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType)); |
2966 | } |
2967 | |
2968 | // 2. Get the vector shape |
2969 | SmallVector<int64_t> vecShape; |
2970 | size_t rankDiff = resultType.getRank() - sourceType.getRank(); |
2971 | for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) { |
2972 | if (!inputVectorSizes.empty()) { |
2973 | vecShape.push_back(inputVectorSizes[i]);
2974 | } else if (!sourceType.isDynamicDim(i)) { |
2975 | vecShape.push_back(sourceType.getDimSize(i));
2976 | } else if (!resultType.isDynamicDim(i)) { |
2977 | // Source shape is not statically known, but result shape is. |
2978 | // Vectorize with size of result shape. This may be larger than the |
2979 | // source size. |
2980 | // FIXME: Using rankDiff implies that the source tensor is inserted at |
2981 | // the end of the destination tensor. However, that's not required. |
2982 | vecShape.push_back(resultType.getDimSize(rankDiff + i));
2983 | } else {
2984 | // Neither the source nor the result dim is static. Cannot vectorize
2985 | // the copy.
2986 | return failure(); |
2987 | } |
2988 | } |
2989 | auto vecType = VectorType::get(vecShape, sourceType.getElementType()); |
2990 | |
2991 | // 3. Generate TransferReadOp + TransferWriteOp |
2992 | auto loc = sliceOp.getLoc(); |
2993 | |
2994 | // Create read |
2995 | SmallVector<Value> readIndices( |
2996 | vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0)); |
2997 | Value read = mlir::vector::createReadOrMaskedRead( |
2998 | builder&: rewriter, loc: loc, source, inputVectorSizes: vecType.getShape(), padValue, |
2999 | /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty()); |
3000 | |
3001 | // Create write |
3002 | auto writeIndices = |
3003 | getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets()); |
3004 | Operation *write = |
3005 | createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(), |
3006 | writeIndices, inputVectorSizes.empty()); |
3007 | |
3008 | // 4. Finalize |
3009 | newResults.push_back(write->getResult(0));
3010 | |
3011 | return success(); |
3012 | } |
3013 | |
3014 | /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.: |
3015 | /// ``` |
3016 | /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32> |
3017 | /// %r = tensor.insert_slice %0 |
3018 | /// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1] |
3019 | /// : tensor<17x5xf32> into tensor<?x?x17x5xf32> |
3020 | /// ``` |
3021 | /// is rewritten to: |
3022 | /// ``` |
3023 | /// %0 = vector.transfer_read %src[%c0, %c0], %padding |
3024 | /// : tensor<?x?xf32>, vector<17x5xf32> |
3025 | /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0] |
3026 | /// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32> |
3027 | /// ``` |
3028 | /// |
3029 | /// This rewrite is possible if: |
3030 | /// - Low padding is static 0. |
3031 | /// - `padOp` result shape is static. |
3032 | /// - The entire padded tensor is inserted. |
3033 | /// (Implies that sizes of `insertOp` are all static.) |
3034 | /// - Only unit strides in `insertOp`. |
3035 | /// - Single, scalar padding value. |
3036 | /// - `padOp` result not used as destination. |
3037 | struct PadOpVectorizationWithInsertSlicePattern |
3038 | : public VectorizePadOpUserPattern<tensor::InsertSliceOp> { |
3039 | using VectorizePadOpUserPattern< |
3040 | tensor::InsertSliceOp>::VectorizePadOpUserPattern; |
3041 | |
3042 | LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, |
3043 | tensor::InsertSliceOp insertOp) const override { |
3044 | // Low padding must be static 0. |
3045 | if (!padOp.hasZeroLowPad()) |
3046 | return failure(); |
3047 | // Only unit stride supported. |
3048 | if (!insertOp.hasUnitStride()) |
3049 | return failure(); |
3050 | // Pad value must be a constant. |
3051 | auto padValue = padOp.getConstantPaddingValue(); |
3052 | if (!padValue) |
3053 | return failure(); |
3054 | // Dynamic shapes not supported. |
3055 | if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape()) |
3056 | return failure(); |
3057 | // Pad result not used as destination. |
3058 | if (insertOp.getDest() == padOp.getResult()) |
3059 | return failure(); |
3060 | |
3061 | auto vecType = VectorType::get(padOp.getType().getShape(), |
3062 | padOp.getType().getElementType()); |
3063 | unsigned vecRank = vecType.getRank(); |
3064 | unsigned tensorRank = insertOp.getType().getRank(); |
3065 | |
3066 | // Check if sizes match: Insert the entire tensor into most minor dims. |
3067 | // (No permutations allowed.) |
3068 | SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1); |
3069 | expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end()); |
3070 | if (!llvm::all_of( |
3071 | llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) { |
3072 | return getConstantIntValue(std::get<0>(it)) == std::get<1>(it); |
3073 | })) |
3074 | return failure(); |
3075 | |
3076 | // Insert the TransferReadOp and TransferWriteOp at the position of the |
3077 | // InsertSliceOp. |
3078 | rewriter.setInsertionPoint(insertOp); |
3079 | |
3080 | // Generate TransferReadOp: Read entire source tensor and add high |
3081 | // padding. |
3082 | SmallVector<Value> readIndices( |
3083 | vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0)); |
3084 | auto read = rewriter.create<vector::TransferReadOp>( |
3085 | padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue); |
3086 | |
3087 | // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at |
3088 | // specified offsets. Write is fully in-bounds because an InsertSliceOp's
3089 | // source must fit into the destination at the specified offsets. |
3090 | auto writeIndices = getValueOrCreateConstantIndexOp( |
3091 | rewriter, padOp.getLoc(), insertOp.getMixedOffsets()); |
3092 | SmallVector<bool> inBounds(vecRank, true); |
3093 | rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( |
3094 | insertOp, read, insertOp.getDest(), writeIndices, |
3095 | ArrayRef<bool>{inBounds}); |
3096 | |
3097 | return success(); |
3098 | } |
3099 | }; |
3100 | |
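// Typical usage sketch (illustrative; the rewrite driver is up to the caller):
//   RewritePatternSet patterns(ctx);
//   linalg::populatePadOpVectorizationPatterns(patterns);
//   // ... then apply `patterns` with a greedy pattern rewrite driver.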
3101 | void mlir::linalg::populatePadOpVectorizationPatterns( |
3102 | RewritePatternSet &patterns, PatternBenefit baseBenefit) { |
3103 | patterns.add<PadOpVectorizationWithTransferReadPattern, |
3104 | PadOpVectorizationWithTransferWritePattern, |
3105 | PadOpVectorizationWithInsertSlicePattern>( |
3106 | patterns.getContext(), baseBenefit.getBenefit() + 1);
3107 | } |
3108 | |
3109 | //----------------------------------------------------------------------------// |
3110 | // Forwarding patterns |
3111 | //----------------------------------------------------------------------------// |
3112 | |
3113 | /// Check whether there is any interleaved use of any `values` between |
3114 | /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value |
3115 | /// is in a different block. |
3116 | static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, |
3117 | ValueRange values) { |
3118 | if (firstOp->getBlock() != secondOp->getBlock() || |
3119 | !firstOp->isBeforeInBlock(secondOp)) {
3120 | LDBG("interleavedUses precondition failed, firstOp: "
3121 | << *firstOp << ", second op: " << *secondOp << "\n");
3122 | return true; |
3123 | } |
3124 | for (auto v : values) { |
3125 | for (auto &u : v.getUses()) { |
3126 | Operation *owner = u.getOwner(); |
3127 | if (owner == firstOp || owner == secondOp) |
3128 | continue; |
3129 | // TODO: this is too conservative, use dominance info in the future. |
3130 | if (owner->getBlock() == firstOp->getBlock() && |
3131 | (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
3132 | continue;
3133 | LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
3134 | << ", second op: " << *secondOp << "\n");
3135 | return true; |
3136 | } |
3137 | } |
3138 | return false; |
3139 | } |
3140 | |
3141 | /// Return the unique subview use of `v` if it is indeed unique, null |
3142 | /// otherwise. |
3143 | static memref::SubViewOp getSubViewUseIfUnique(Value v) { |
3144 | memref::SubViewOp subViewOp; |
3145 | for (auto &u : v.getUses()) { |
3146 | if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) { |
3147 | if (subViewOp) |
3148 | return memref::SubViewOp(); |
3149 | subViewOp = newSubViewOp; |
3150 | } |
3151 | } |
3152 | return subViewOp; |
3153 | } |
3154 | |
3155 | /// TODO: use interfaces, side-effects and aliasing analysis as appropriate, |
3156 | /// when available. |
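///
/// A minimal sketch of the IR being matched (illustrative):
///   linalg.fill ins(%cst : f32) outs(%alloc : memref<?x?xf32>)
///   %sv = memref.subview %alloc[...] : memref<?x?xf32> to memref<?x?xf32, ...>
///   memref.copy %in, %sv
///   %v = vector.transfer_read %alloc[...], %cst ...
/// The transfer_read is rewritten to read directly from %in, conservatively
/// marked as potentially out-of-bounds, and the fill/copy are erased.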
3157 | LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite( |
3158 | vector::TransferReadOp xferOp, PatternRewriter &rewriter) const { |
3159 | |
3160 | // TODO: support mask. |
3161 | if (xferOp.getMask()) |
3162 | return rewriter.notifyMatchFailure(xferOp, "unsupported mask"); |
3163 | |
3164 | // Transfer into `view`. |
3165 | Value viewOrAlloc = xferOp.getBase(); |
3166 | if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() && |
3167 | !viewOrAlloc.getDefiningOp<memref::AllocOp>()) |
3168 | return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc"); |
3169 | |
3170 | // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`. |
3171 | memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc); |
3172 | if (!subViewOp) |
3173 | return rewriter.notifyMatchFailure(xferOp, "no subview found"); |
3174 | Value subView = subViewOp.getResult(); |
3175 | |
3176 | // Find the copy into `subView` without interleaved uses. |
3177 | memref::CopyOp copyOp; |
3178 | for (auto &u : subView.getUses()) { |
3179 | if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) { |
3180 | assert(isa<MemRefType>(newCopyOp.getTarget().getType())); |
3181 | if (newCopyOp.getTarget() != subView) |
3182 | continue; |
3183 | if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView})) |
3184 | continue; |
3185 | copyOp = newCopyOp; |
3186 | break; |
3187 | } |
3188 | } |
3189 | if (!copyOp) |
3190 | return rewriter.notifyMatchFailure(xferOp, "no copy found"); |
3191 | |
3192 | // Find the fill into `viewOrAlloc` without interleaved uses before the |
3193 | // copy. |
3194 | FillOp maybeFillOp; |
3195 | for (auto &u : viewOrAlloc.getUses()) { |
3196 | if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) { |
3197 | assert(isa<MemRefType>(newFillOp.output().getType())); |
3198 | if (newFillOp.output() != viewOrAlloc) |
3199 | continue; |
3200 | if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView})) |
3201 | continue; |
3202 | maybeFillOp = newFillOp; |
3203 | break; |
3204 | } |
3205 | } |
3206 | // Ensure padding matches. |
3207 | if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value()) |
3208 | return rewriter.notifyMatchFailure(xferOp, |
3209 | "padding value does not match fill"); |
3210 | |
3211 | // `in` is the source that memref.copy reads from; forward the read to it.
3212 | Value in = copyOp.getSource(); |
3213 | |
3214 | // memref.copy + linalg.fill can be used to create a padded local buffer.
3215 | // The in-bounds guarantee is only valid on this padded buffer, so when
3216 | // forwarding to vector.transfer_read the `in_bounds` attribute must be
3217 | // reset conservatively (all dims potentially out-of-bounds).
3218 | auto vectorType = xferOp.getVectorType(); |
3219 | Value res = rewriter.create<vector::TransferReadOp>( |
3220 | xferOp.getLoc(), vectorType, in, xferOp.getIndices(), |
3221 | xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(), |
3222 | rewriter.getBoolArrayAttr( |
3223 | SmallVector<bool>(vectorType.getRank(), false))); |
3224 | |
3225 | if (maybeFillOp) |
3226 | rewriter.eraseOp(maybeFillOp);
3227 | rewriter.eraseOp(copyOp);
3228 | rewriter.replaceOp(xferOp, res); |
3229 | |
3230 | return success(); |
3231 | } |
3232 | |
3233 | /// TODO: use interfaces, side-effects and aliasing analysis as appropriate, |
3234 | /// when available. |
3235 | LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite( |
3236 | vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const { |
3237 | // TODO: support mask. |
3238 | if (xferOp.getMask()) |
3239 | return rewriter.notifyMatchFailure(xferOp, "unsupported mask"); |
3240 | |
3241 | // Transfer into `viewOrAlloc`. |
3242 | Value viewOrAlloc = xferOp.getBase(); |
3243 | if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() && |
3244 | !viewOrAlloc.getDefiningOp<memref::AllocOp>()) |
3245 | return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc"); |
3246 | |
3247 | // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`. |
3248 | memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc); |
3249 | if (!subViewOp) |
3250 | return rewriter.notifyMatchFailure(xferOp, "no subview found"); |
3251 | Value subView = subViewOp.getResult(); |
3252 | |
3253 | // Find the copy from `subView` without interleaved uses. |
3254 | memref::CopyOp copyOp; |
3255 | for (auto &u : subViewOp.getResult().getUses()) { |
3256 | if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) { |
3257 | if (newCopyOp.getSource() != subView) |
3258 | continue; |
3259 | if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView})) |
3260 | continue; |
3261 | copyOp = newCopyOp; |
3262 | break; |
3263 | } |
3264 | } |
3265 | if (!copyOp) |
3266 | return rewriter.notifyMatchFailure(xferOp, "no copy found"); |
3267 | |
3268 | // `out` is the buffer the subview is copied into; forward the write to it.
3269 | assert(isa<MemRefType>(copyOp.getTarget().getType())); |
3270 | Value out = copyOp.getTarget(); |
3271 | |
3272 | // Forward vector.transfer into copy.
3273 | // memref.copy + linalg.fill can be used to create a padded local buffer.
3274 | // The in-bounds guarantee is only valid on this padded buffer, so when
3275 | // forwarding to vector.transfer_write the `in_bounds` attribute must be
3276 | // reset conservatively (all dims potentially out-of-bounds).
3277 | auto vector = xferOp.getVector(); |
3278 | rewriter.create<vector::TransferWriteOp>( |
3279 | xferOp.getLoc(), vector, out, xferOp.getIndices(), |
3280 | xferOp.getPermutationMapAttr(), xferOp.getMask(), |
3281 | rewriter.getBoolArrayAttr(SmallVector<bool>( |
3282 | dyn_cast<VectorType>(vector.getType()).getRank(), false))); |
3283 | |
3284 | rewriter.eraseOp(copyOp);
3285 | rewriter.eraseOp(xferOp);
3286 | |
3287 | return success(); |
3288 | } |
3289 | |
3290 | //===----------------------------------------------------------------------===// |
3291 | // Convolution vectorization patterns |
3292 | //===----------------------------------------------------------------------===// |
3293 | |
3294 | template <int N> |
3295 | static void bindShapeDims(ShapedType shapedType) {} |
3296 | |
3297 | template <int N, typename IntTy, typename... IntTy2> |
3298 | static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) { |
3299 | val = shapedType.getShape()[N]; |
3300 | bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...); |
3301 | } |
3302 | |
3303 | /// Bind a pack of int& to the leading dimensions of shapedType.getShape(). |
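///
/// Usage sketch (illustrative):
///   int64_t n, w, c;
///   bindShapeDims(lhsShapedType, n, w, c); // n = shape[0], w = shape[1], c = shape[2]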
3304 | template <typename... IntTy> |
3305 | static void bindShapeDims(ShapedType shapedType, IntTy &...vals) { |
3306 | bindShapeDims<0>(shapedType, vals...); |
3307 | } |
3308 | |
3309 | namespace { |
3310 | /// Generate a vector implementation for either: |
3311 | /// ``` |
3312 | /// Op def: ( w, kw ) |
3313 | /// Iters: ({Par(), Red()}) |
3314 | /// Layout: {{w + kw}, {kw}, {w}} |
3315 | /// ``` |
3316 | /// kw is unrolled. |
3317 | /// |
3318 | /// or |
3319 | /// |
3320 | /// ``` |
3321 | /// Op def: ( n, w, c, kw, f ) |
3322 | /// Iters: ({Par(), Par(), Par(), Red(), Red()}) |
3323 | /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}} |
3324 | /// ``` |
3325 | /// kw is unrolled, w is unrolled iff dilationW > 1. |
3326 | /// |
3327 | /// or |
3328 | /// |
3329 | /// ``` |
3330 | /// Op def: ( n, c, w, f, kw ) |
3331 | /// Iters: ({Par(), Par(), Par(), Red(), Red()}) |
3332 | /// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}} |
3333 | /// ``` |
3334 | /// kw is unrolled, w is unrolled iff dilationW > 1. |
3335 | /// |
3336 | /// or |
3337 | /// |
3338 | /// ``` |
3339 | /// Op def: ( n, w, c, kw ) |
3340 | /// Iters: ({Par(), Par(), Par(), Red()}) |
3341 | /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}} |
3342 | /// ``` |
3343 | /// kw is unrolled, w is unrolled iff dilationW > 1. |
3344 | struct Conv1DGenerator |
3345 | : public StructuredGenerator<LinalgOp, utils::IteratorType> { |
3346 | Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp) |
3347 | : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) { |
3348 | |
3349 | lhsShaped = linalgOp.getDpsInputOperand(0)->get(); |
3350 | rhsShaped = linalgOp.getDpsInputOperand(1)->get(); |
3351 | resShaped = linalgOp.getDpsInitOperand(0)->get(); |
3352 | lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType()); |
3353 | rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType()); |
3354 | resShapedType = dyn_cast<ShapedType>(resShaped.getType()); |
3355 | |
3356 | Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0)); |
3357 | redOp = reduceOp->getName().getIdentifier(); |
3358 | |
3359 | setConvOperationKind(reduceOp); |
3360 | |
3361 | auto maybeKind = getCombinerOpKind(reduceOp); |
3362 | reductionKind = maybeKind.value(); |
3363 | |
3364 | // The ConvolutionOpInterface gives us guarantees of existence for
3365 | // strides/dilations. However, we do not need to rely on those; we simply
3366 | // use them if present and otherwise fall back to the defaults, letting the
3367 | // generic conv matcher in the ConvGenerator succeed or fail.
3368 | auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides"); |
3369 | auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations"); |
3370 | strideW = strides ? *strides.getValues<uint64_t>().begin() : 1; |
3371 | dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1; |
3372 | } |
3373 | |
3374 | /// Generate a vector implementation for: |
3375 | /// ``` |
3376 | /// Op def: ( w, kw ) |
3377 | /// Iters: ({Par(), Red()}) |
3378 | /// Layout: {{w + kw}, {kw}, {w}} |
3379 | /// ``` |
3380 | /// kw is always unrolled. |
3381 | /// |
3382 | /// or |
3383 | /// |
3384 | /// ``` |
3385 | /// Op def: ( n, w, c, kw, f ) |
3386 | /// Iters: ({Par(), Par(), Par(), Red(), Red()}) |
3387 | /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}} |
3388 | /// ``` |
3389 | /// kw is always unrolled. |
3390 | /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is |
3391 | /// > 1. |
3392 | FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) { |
3393 | int64_t nSize, wSize, cSize, kwSize, fSize; |
3394 | SmallVector<int64_t, 3> lhsShape, rhsShape, resShape; |
3395 | bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W); |
3396 | switch (conv1DOpOrder) { |
3397 | case Conv1DOpOrder::W: |
3398 | // Initialize unused dimensions |
3399 | nSize = fSize = cSize = 0; |
3400 | // out{W} |
3401 | bindShapeDims(resShapedType, wSize); |
3402 | // kernel{kw} |
3403 | bindShapeDims(rhsShapedType, kwSize); |
3404 | lhsShape = {// iw = ow + kw - 1 |
3405 | // (i.e. 16 convolved with 3 -> 14) |
3406 | (wSize + kwSize - 1)}; |
3407 | rhsShape = {kwSize}; |
3408 | resShape = {wSize}; |
3409 | break; |
3410 | case Conv1DOpOrder::Nwc: |
3411 | // out{n, w, f} |
3412 | bindShapeDims(resShapedType, nSize, wSize, fSize); |
3413 | switch (oper) { |
3414 | case ConvOperationKind::Conv: |
3415 | // kernel{kw, c, f} |
3416 | bindShapeDims(rhsShapedType, kwSize, cSize); |
3417 | break; |
3418 | case ConvOperationKind::Pool: |
3419 | // kernel{kw} |
3420 | bindShapeDims(rhsShapedType, kwSize); |
3421 | cSize = fSize; |
3422 | break; |
3423 | } |
3424 | lhsShape = {nSize, |
3425 | // iw = ow * sw + kw * dw - 1 |
3426 | // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14) |
3427 | // Perform the proper inclusive -> exclusive -> inclusive. |
3428 | ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - |
3429 | 1, |
3430 | cSize}; |
3431 | switch (oper) { |
3432 | case ConvOperationKind::Conv: |
3433 | rhsShape = {kwSize, cSize, fSize}; |
3434 | break; |
3435 | case ConvOperationKind::Pool: |
3436 | rhsShape = {kwSize}; |
3437 | break; |
3438 | } |
3439 | resShape = {nSize, wSize, fSize}; |
3440 | break; |
3441 | case Conv1DOpOrder::Ncw: |
3442 | // out{n, f, w} |
3443 | bindShapeDims(resShapedType, nSize, fSize, wSize); |
3444 | switch (oper) { |
3445 | case ConvOperationKind::Conv: |
3446 | // kernel{f, c, kw} |
3447 | bindShapeDims(rhsShapedType, fSize, cSize, kwSize); |
3448 | break; |
3449 | case ConvOperationKind::Pool: |
3450 | // kernel{kw} |
3451 | bindShapeDims(rhsShapedType, kwSize); |
3452 | cSize = fSize; |
3453 | break; |
3454 | } |
3455 | lhsShape = {nSize, cSize, |
3456 | // iw = ow * sw + kw * dw - 1 |
3457 | // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14) |
3458 | // Perform the proper inclusive -> exclusive -> inclusive. |
3459 | ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - |
3460 | 1}; |
3461 | switch (oper) { |
3462 | case ConvOperationKind::Conv: |
3463 | rhsShape = {fSize, cSize, kwSize}; |
3464 | break; |
3465 | case ConvOperationKind::Pool: |
3466 | rhsShape = {kwSize}; |
3467 | break; |
3468 | } |
3469 | resShape = {nSize, fSize, wSize}; |
3470 | break; |
3471 | } |
3472 | |
3473 | vector::TransferWriteOp write; |
3474 | Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
3475 | |
3476 | // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1. |
3477 | // When strideW == 1, we can batch the contiguous loads and avoid |
3478 | // unrolling |
3479 | int64_t wSizeStep = strideW == 1 ? wSize : 1; |
3480 | |
3481 | Type lhsEltType = lhsShapedType.getElementType(); |
3482 | Type rhsEltType = rhsShapedType.getElementType(); |
3483 | Type resEltType = resShapedType.getElementType(); |
3484 | auto lhsType = VectorType::get(lhsShape, lhsEltType); |
3485 | auto rhsType = VectorType::get(rhsShape, rhsEltType); |
3486 | auto resType = VectorType::get(resShape, resEltType); |
3487 | // Zero padding with the corresponding dimensions for lhs, rhs and res. |
3488 | SmallVector<Value> lhsPadding(lhsShape.size(), zero); |
3489 | SmallVector<Value> rhsPadding(rhsShape.size(), zero); |
3490 | SmallVector<Value> resPadding(resShape.size(), zero); |
3491 | |
3492 | // Read the whole lhs, rhs and res in one shot (with zero padding). |
3493 | Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped, |
3494 | lhsPadding); |
3495 | // This is needed only for Conv. |
3496 | Value rhs = nullptr; |
3497 | if (oper == ConvOperationKind::Conv) |
3498 | rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped, |
3499 | rhsPadding); |
3500 | Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped, |
3501 | resPadding); |
3502 | |
3503 | // The base vectorization case for channeled convolution is input:
3504 | // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
3505 | // vectorization case, we pre-transpose the input, weight, and output.
3506 | switch (conv1DOpOrder) { |
3507 | case Conv1DOpOrder::W: |
3508 | case Conv1DOpOrder::Nwc: |
3509 | // Base case, so no transposes necessary. |
3510 | break; |
3511 | case Conv1DOpOrder::Ncw: { |
3512 | // To match base vectorization case, we pre-transpose current case. |
3513 | // ncw -> nwc |
3514 | static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1}; |
3515 | lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs); |
3516 | // fcw -> wcf |
3517 | static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0}; |
3518 | |
3519 | // This is needed only for Conv. |
3520 | if (oper == ConvOperationKind::Conv) |
3521 | rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs); |
3522 | // nfw -> nwf |
3523 | static constexpr std::array<int64_t, 3> permRes = {0, 2, 1}; |
3524 | res = rewriter.create<vector::TransposeOp>(loc, res, permRes); |
3525 | break; |
3526 | } |
3527 | } |
3528 | |
3529 | //===------------------------------------------------------------------===// |
3530 | // Begin vector-only rewrite part |
3531 | //===------------------------------------------------------------------===// |
3532 | // Unroll along kw and read slices of lhs and rhs. |
3533 | SmallVector<Value> lhsVals, rhsVals, resVals; |
3534 | lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize, |
3535 | kwSize, strideW, dilationW, wSizeStep, |
3536 | isSingleChanneled); |
3537 | // Do not do for pooling. |
3538 | if (oper == ConvOperationKind::Conv) |
3539 | rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize); |
3540 | resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize, |
3541 | wSizeStep, isSingleChanneled); |
3542 | |
3543 | auto linearIndex = [&](int64_t kw, int64_t w) { |
3544 | return kw * (wSize / wSizeStep) + w; |
3545 | }; |
3546 | |
3547 | // Compute the contraction O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f},
3548 | // or an outer product for non-channeled convolution, or a simple
3549 | // arith operation for pooling.
3550 | for (int64_t kw = 0; kw < kwSize; ++kw) { |
3551 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
3552 | switch (oper) { |
3553 | case ConvOperationKind::Conv: |
3554 | if (isSingleChanneled) { |
3555 | resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc, |
3556 | lhsVals[linearIndex(kw, w)], |
3557 | rhsVals[kw], resVals[w]); |
3558 | } else { |
3559 | resVals[w] = conv1dSliceAsContraction(rewriter, loc, |
3560 | lhsVals[linearIndex(kw, w)], |
3561 | rhsVals[kw], resVals[w]); |
3562 | } |
3563 | break; |
3564 | case ConvOperationKind::Pool: |
3565 | resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)], |
3566 | resVals[w]); |
3567 | break; |
3568 | } |
3569 | } |
3570 | } |
3571 | |
3572 | res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals, |
3573 | isSingleChanneled); |
3574 | //===------------------------------------------------------------------===// |
3575 | // End vector-only rewrite part |
3576 | //===------------------------------------------------------------------===// |
3577 | |
3578 | // The base vectorization case for channeled convolution is output:
3579 | // {n,w,f}. To reuse the result from the base pattern vectorization case,
3580 | // we post-transpose the base case result.
3581 | switch (conv1DOpOrder) { |
3582 | case Conv1DOpOrder::W: |
3583 | case Conv1DOpOrder::Nwc: |
3584 | // Base case, so no transposes necessary. |
3585 | break; |
3586 | case Conv1DOpOrder::Ncw: { |
3587 | // nwf -> nfw |
3588 | static constexpr std::array<int64_t, 3> perm = {0, 2, 1}; |
3589 | res = rewriter.create<vector::TransposeOp>(loc, res, perm); |
3590 | break; |
3591 | } |
3592 | } |
3593 | |
3594 | return rewriter |
3595 | .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding) |
3596 | .getOperation(); |
3597 | } |
3598 | |
3599 | // Take a value and widen it to have the same element type as `ty`.
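// For example (illustrative), promoting a vector<4xi8> value to match an f32
// result widens it with arith.sitofp to vector<4xf32>; i8 -> i32 would use
// arith.extsi, and f16 -> f32 would use arith.extf.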
3600 | Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) { |
3601 | const Type srcElementType = getElementTypeOrSelf(val.getType());
3602 | const Type dstElementType = getElementTypeOrSelf(ty);
3603 | assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType)); |
3604 | if (srcElementType == dstElementType) |
3605 | return val; |
3606 | |
3607 | const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth(); |
3608 | const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth(); |
3609 | const Type dstType = |
3610 | cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType); |
3611 | |
3612 | if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3613 | return rewriter.create<arith::SIToFPOp>(loc, dstType, val); |
3614 | } |
3615 | |
3616 | if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) && |
3617 | srcWidth < dstWidth) |
3618 | return rewriter.create<arith::ExtFOp>(loc, dstType, val); |
3619 | |
3620 | if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) && |
3621 | srcWidth < dstWidth) |
3622 | return rewriter.create<arith::ExtSIOp>(loc, dstType, val); |
3623 | |
3624 | assert(false && "unhandled promotion case"); |
3625 | return nullptr; |
3626 | } |
3627 | |
3628 | // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f} |
3629 | Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc, |
3630 | Value lhs, Value rhs, Value res) { |
3631 | vector::IteratorType par = vector::IteratorType::parallel; |
3632 | vector::IteratorType red = vector::IteratorType::reduction; |
3633 | AffineExpr n, w, f, c; |
3634 | bindDims(ctx, n, w, f, c); |
3635 | lhs = promote(rewriter, loc, lhs, res.getType());
3636 | rhs = promote(rewriter, loc, rhs, res.getType());
3637 | auto contractionOp = rewriter.create<vector::ContractionOp>(
3638 | loc, lhs, rhs, res,
3639 | /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3640 | /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3641 | contractionOp.setKind(reductionKind);
3642 | return contractionOp;
3643 | } |
3644 | |
3645 | // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel |
3646 | // convolution. |
3647 | Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc, |
3648 | Value lhs, Value rhs, Value res) { |
3649 | return rewriter.create<vector::OuterProductOp>( |
3650 | loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD); |
3651 | } |
3652 | |
3653 | // Create a reduction: lhs{n, w, c} -> res{n, w, c} |
3654 | Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs, |
3655 | Value res) { |
3656 | if (isPoolExt) |
3657 | lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0); |
3658 | return rewriter |
3659 | .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType()) |
3660 | ->getResult(0); |
3661 | } |
3662 | |
3663 | /// Generate a vector implementation for: |
3664 | /// ``` |
3665 | /// Op def: ( n, w, c, kw) |
3666 | /// Iters: ({Par(), Par(), Par(), Red()}) |
3667 | /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}} |
3668 | /// ``` |
3669 | /// kw is always unrolled. |
3670 | /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is |
3671 | /// > 1. |
  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
                                       bool channelDimScalableFlag,
                                       bool flatten) {
    bool scalableChDim = false;
    bool useMasking = false;
    int64_t nSize, wSize, cSize, kwSize;
    // kernel{kw, c}
    bindShapeDims(rhsShapedType, kwSize, cSize);
    if (ShapedType::isDynamic(cSize)) {
      assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
      cSize = channelDimVecSize;
      // Scalable vectors are only used when both conditions are met:
      //  1. channel dim is dynamic
      //  2. channelDimScalableFlag is set
      scalableChDim = channelDimScalableFlag;
      useMasking = true;
    }

    assert(!(useMasking && flatten) &&
           "Unsupported flattened conv with dynamic shapes");

    // out{n, w, c}
    bindShapeDims(resShapedType, nSize, wSize);

    vector::TransferWriteOp write;
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
    // When strideW == 1, we can batch the contiguous loads and avoid
    // unrolling.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    VectorType lhsType = VectorType::get(
        {nSize,
         // iw = (ow - 1) * sw + (kw - 1) * dw + 1
         // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
         ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
         cSize},
        lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
    VectorType rhsType =
        VectorType::get({kwSize, cSize}, rhsEltType,
                        /*scalableDims=*/{false, scalableChDim});
    VectorType resType =
        VectorType::get({nSize, wSize, cSize}, resEltType,
                        /*scalableDims=*/{false, false, scalableChDim});

    // Masks the input xfer Op along the channel dim, iff the channel dim is
    // dynamic (i.e. `useMasking` is set).
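    // For example (illustrative: scalable channel dim of base size 4), the
    // masked read of the result tensor looks roughly like:
    //
    //   %mask = vector.create_mask %c1, %c4, %ch : vector<1x4x[4]xi1>
    //   %res = vector.mask %mask {
    //     vector.transfer_read %resShaped[%c0, %c0, %c0], %pad
    //       : tensor<1x4x?xf32>, vector<1x4x[4]xf32>
    //   } : vector<1x4x[4]xi1> -> vector<1x4x[4]xf32>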
    auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
                               ArrayRef<bool> scalableDims,
                               Operation *opToMask) {
      if (!useMasking)
        return opToMask;
      auto maskType =
          VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);

      SmallVector<bool> inBounds(maskShape.size(), true);
      auto xferOp = cast<VectorTransferOpInterface>(opToMask);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));

      SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
          cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);

      Value maskOp =
          rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);

      return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
    };

    // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, 0].
    Value lhs = rewriter.create<vector::TransferReadOp>(
        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
    auto maybeMaskedLhs = maybeMaskXferOp(
        lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());

    // Read rhs slice of size {kw, c} @ [0, 0].
    Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                        ValueRange{zero, zero});
    auto maybeMaskedRhs = maybeMaskXferOp(
        rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());

    // Read res slice of size {n, w, c} @ [0, 0, 0].
    Value res = rewriter.create<vector::TransferReadOp>(
        loc, resType, resShaped, ValueRange{zero, zero, zero});
    auto maybeMaskedRes = maybeMaskXferOp(
        resType.getShape(), resType.getScalableDims(), res.getDefiningOp());

    //===------------------------------------------------------------------===//
    // Begin vector-only rewrite part
    //===------------------------------------------------------------------===//
    // Unroll along kw and read slices of lhs and rhs.
    SmallVector<Value> lhsVals, rhsVals, resVals;
    SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
    SmallVector<int64_t> inOutStrides = {1, 1, 1};

    // Extract lhs slice of size {n, wSizeStep, c}
    // @ [0, sw * w + dw * kw, 0].
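    // For example (illustrative: stride 1, dilation 2, nSize = 1, wSize = 4,
    // cSize = 8, kwSize = 2, kw = 1), one such extraction is roughly:
    //
    //   %slice = vector.extract_strided_slice %lhs
    //     {offsets = [0, 2, 0], sizes = [1, 4, 8], strides = [1, 1, 1]}
    //     : vector<1x6x8xf32> to vector<1x4x8xf32>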
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, maybeMaskedLhs->getResult(0),
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            inOutSliceSizes, inOutStrides));
      }
    }
    // Extract rhs slice of size {c} @ [kw].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      rhsVals.push_back(rewriter.create<vector::ExtractOp>(
          loc, maybeMaskedRhs->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{kw}));
    }
    // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
          inOutStrides));
    }

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Note - the scalable flags are ignored as flattening combined with
    // scalable vectorization is not supported.
    SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
    auto lhsTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, lhsEltType);
    auto resTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, resEltType);

    // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        Value lhsVal = lhsVals[linearIndex(kw, w)];
        Value resVal = resVals[w];
        if (flatten) {
          // Flatten the input and output vectors (collapse the channel
          // dimension).
          lhsVal = rewriter.create<vector::ShapeCastOp>(
              loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
          resVal = rewriter.create<vector::ShapeCastOp>(
              loc, resTypeAfterFlattening, resVals[w]);
        }
        resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
                                                  rhsVals[kw], resVal, flatten);
        if (flatten) {
          // Un-flatten the output vector (restore the channel dimension).
          resVals[w] = rewriter.create<vector::ShapeCastOp>(
              loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
        }
      }
    }

    // It's possible we failed to create the FMA.
    if (!llvm::all_of(resVals, [](Value v) { return v; })) {
      // Manually revert (in reverse order) to avoid leaving a bad IR state.
      for (auto &collection :
           {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
        for (Value v : collection)
          rewriter.eraseOp(v.getDefiningOp());
      return rewriter.notifyMatchFailure(op, "failed to create FMA");
    }

    // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
    // This does not depend on kw.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
    }
    //===------------------------------------------------------------------===//
    // End vector-only rewrite part
    //===------------------------------------------------------------------===//

    // Write back res slice of size {n, w, c} @ [0, 0, 0].
    Operation *resOut = rewriter.create<vector::TransferWriteOp>(
        loc, maybeMaskedRes->getResult(0), resShaped,
        ValueRange{zero, zero, zero});
    return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
                           resOut);
  }

  /// Lower:
  ///   * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
  ///   * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
  /// to MulAcc.
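  ///
  /// For the non-flattened case with (illustrative) vector<1x4x8xf32> slices
  /// and an f32 filter slice of vector<8xf32>, the lowering is roughly:
  ///
  ///   %b = vector.broadcast %rhs : vector<8xf32> to vector<1x4x8xf32>
  ///   %0 = vector.fma %lhs, %b, %res : vector<1x4x8xf32>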
  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
                                     Value lhs, Value rhs, Value res,
                                     bool flatten) {
    auto rhsTy = cast<ShapedType>(rhs.getType());
    auto resTy = cast<ShapedType>(res.getType());

    // TODO(suderman): Change this to use a vector.ima intrinsic.
    lhs = promote(rewriter, loc, lhs, resTy);

    if (flatten) {
      // NOTE: The following logic won't work for scalable vectors. For this
      // reason, "flattening" is not supported when shapes are dynamic (this
      // should be captured by one of the pre-conditions).

      // There are two options for handling the filter:
      //  * shape_cast(broadcast(filter))
      //  * broadcast(shuffle(filter))
      // Opt for the option without shape_cast to simplify the codegen.
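      // For example (illustrative: c = 4, w * c = 12), the shuffle repeats the
      // filter along the flattened w * c dim:
      //   indices = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]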
      auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
      auto resSize = cast<VectorType>(res.getType()).getShape()[1];

      SmallVector<int64_t, 16> indices;
      for (int i = 0; i < resSize / rhsSize; ++i) {
        for (int j = 0; j < rhsSize; ++j)
          indices.push_back(j);
      }

      rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
    }
    // Broadcast the filter to match the output vector.
    rhs = rewriter.create<vector::BroadcastOp>(
        loc, resTy.clone(rhsTy.getElementType()), rhs);

    rhs = promote(rewriter, loc, rhs, resTy);

    if (!lhs || !rhs)
      return nullptr;

    if (isa<FloatType>(resTy.getElementType()))
      return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);

    auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
    return rewriter.create<arith::AddIOp>(loc, mul, res);
  }

  /// Entry point for non-channeled convolution:
  ///   {{w + kw}, {kw}, {w}}
  FailureOr<Operation *> generateNonChanneledConv() {
    AffineExpr w, kw;
    bindDims(ctx, w, kw);
    if (!iters({Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match conv::W 1-par 1-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {w + kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {w}}))
      return conv(Conv1DOpOrder::W);

    return rewriter.notifyMatchFailure(op, "not a conv::W layout");
  }

  /// Entry point that transposes into the common form:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
  FailureOr<Operation *> generateNwcConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, w, f, kw, c);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Nwc 3-par 2-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c, f},
                /*resIndex*/ {n, w, f}}))
      return conv(Conv1DOpOrder::Nwc);

    return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
  }

  /// Entry point that transposes into the common form:
  ///   {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
  FailureOr<Operation *> generateNcwConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, f, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Ncw 3-par 2-red");

    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {f, c, kw},
                /*resIndex*/ {n, f, w}}))
      return conv(Conv1DOpOrder::Ncw);

    return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
  }

  /// Entry point that transposes into the common form:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
  FailureOr<Operation *> generateNwcPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, w, c}}))
      return conv(Conv1DOpOrder::Nwc);

    return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
  }

  /// Entry point that transposes into the common form:
  ///   {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
  FailureOr<Operation *> generateNcwPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, c, w, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, c, w}}))
      return conv(Conv1DOpOrder::Ncw);

    return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
  }

  /// Entry point that transposes into the common form:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
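  ///
  /// For example, this matches (illustrative shapes):
  ///   linalg.depthwise_conv_1d_nwc_wc {dilations = dense<2> : tensor<1xi64>,
  ///                                    strides = dense<1> : tensor<1xi64>}
  ///     ins(%input, %filter : tensor<1x6x8xf32>, tensor<2x8xf32>)
  ///     outs(%output : tensor<1x4x8xf32>) -> tensor<1x4x8xf32>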
  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
                                             bool vecChDimScalableFlag = false,
                                             bool flatten = false) {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match depthwise::Nwc conv 3-par 1-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c},
                /*resIndex*/ {n, w, c}}))
      return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);

    return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
  }

private:
  ConvOperationKind oper = ConvOperationKind::Conv;
  StringAttr redOp;
  StringAttr poolExtOp;
  bool isPoolExt = false;
  int strideW, dilationW;
  Value lhsShaped, rhsShaped, resShaped;
  ShapedType lhsShapedType, rhsShapedType, resShapedType;
  vector::CombiningKind reductionKind;

  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
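  // For example (illustrative reduction bodies):
  //   * conv:               %m = arith.mulf %in, %flt
  //                         %r = arith.addf %m, %out
  //   * pooling (with ext): %e = arith.extsi %in : i8 to i32
  //                         %r = arith.addi %e, %out
  //   * pooling (plain):    %r = arith.maximumf %in, %out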
  void setConvOperationKind(Operation *reduceOp) {
    int numBlockArguments =
        llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
    if (numBlockArguments == 1) {
      // Will be convolution if feeder is a MulOp.
      // A strength reduced version of MulOp for i1 type is AndOp which is also
      // supported. Otherwise, it can be pooling. This strength reduction logic
      // is in the `buildBinaryFn` helper in the Linalg dialect.
      auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                         llvm::IsaPred<BlockArgument>);
      Operation *feedOp = (*feedValIt).getDefiningOp();
      if (isCastOfBlockArgument(feedOp)) {
        oper = ConvOperationKind::Pool;
        isPoolExt = true;
        poolExtOp = feedOp->getName().getIdentifier();
        return;
      }
      oper = ConvOperationKind::Conv;
      return;
    }
    // numBlockArguments == 2 and this is a pooling op.
    oper = ConvOperationKind::Pool;
    isPoolExt = false;
  }
};
} // namespace

/// Helper function to vectorize a LinalgOp with convolution semantics.
// TODO: extend the generic vectorization to support windows and drop this.
static FailureOr<Operation *> vectorizeConvolution(
    RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
    ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
  Conv1DGenerator conv1dGen(rewriter, op);
  auto res = conv1dGen.generateNonChanneledConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNwcConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNcwConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNwcPooling();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNcwPooling();
  if (succeeded(res))
    return res;

  // Only depthwise 1D NWC convs are left - these can be vectorized using masks
  // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
  // masked/scalable) is the channel dim (i.e. the trailing dim).
  uint64_t vecChDimSize = ShapedType::kDynamic;
  bool vecChDimScalableFlag = false;
  if (!inputVecSizes.empty()) {
    // Only use the input vector size corresponding to the channel dim. Other
    // vector dims will be inferred from the Ops.
    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
           "Not a 1D depthwise conv!");
    size_t chDimIdx =
        TypeSwitch<Operation *, size_t>(op)
            .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
            .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });

    vecChDimSize = inputVecSizes[chDimIdx];
    vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
  }
  return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
                                       flatten1DDepthwiseConv);
}

struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
    if (failed(resultOrFail))
      return failure();
    Operation *newOp = *resultOrFail;
    if (newOp->getNumResults() == 0) {
      rewriter.eraseOp(op.getOperation());
      return success();
    }
    assert(newOp->getNumResults() == 1 && "expected single result");
    rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
    return success();
  }
};

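/// Populate `patterns` with the convolution vectorization pattern above.
/// A minimal usage sketch (assuming a surrounding pass that holds a `funcOp`
/// and applies the greedy pattern driver):
///
///   RewritePatternSet patterns(funcOp.getContext());
///   linalg::populateConvolutionVectorizationPatterns(patterns);
///   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
///     signalPassFailure();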
void mlir::linalg::populateConvolutionVectorizationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
}