//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Vectorization transformations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/Utils.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

#define DEBUG_TYPE "linalg-vectorization"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

/// Try to vectorize `convOp` as a convolution.
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
                     ArrayRef<int64_t> inputVecSizes = {},
                     ArrayRef<bool> inputVecScalableFlags = {},
                     bool flatten1DDepthwiseConv = false);

/// Vectorize tensor::InsertSliceOp with:
/// * vector::TransferReadOp + vector::TransferWriteOp
/// The vector sizes are either:
/// * user-provided in `inputVectorSizes`, or
/// * inferred from the static dims in the input and output tensors.
/// Bails out if:
/// * vector sizes are not user-provided, and
/// * at least one dim is dynamic (in both the input and output tensors).
///
/// Before:
///     !t_in_type = tensor<1x2x3xf32>
///     !t_out_type = tensor<9x8x7x1x2x3xf32>
///     !v_type = vector<1x2x3xf32>
///     %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
///     into !t_out_type
/// After:
///     %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
///     %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
static LogicalResult
vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
                         ArrayRef<int64_t> inputVectorSizes,
                         SmallVectorImpl<Value> &newResults);

/// Returns the effective Pad value for the input op, provided it's a scalar.
///
/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
/// this Op performs padding, retrieve the padding value provided that it's
/// a scalar and static/fixed for all the padded values. Returns an empty value
/// otherwise.
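///
/// For example (an illustrative sketch), given:
///     %cst = arith.constant 0.0 : f32
///     %padded = tensor.pad %src low[0, 0] high[1, 2] {
///     ^bb0(%i: index, %j: index):
///       tensor.yield %cst : f32
///     } : tensor<3x6xf32> to tensor<4x8xf32>
/// the effective pad value of the `tensor.pad` is `%cst`.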
static Value getStaticPadVal(Operation *op);

/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none exist or more than one instance exists.
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType res;
  block.walk([&](OpType op) {
    if (res) {
      res = nullptr;
      return WalkResult::interrupt();
    }
    res = op;
    return WalkResult::advance();
  });
  return res;
}

/// Helper function to extract the input slices after filter is unrolled along
/// kw.
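///
/// An illustrative example (channeled case), assuming strideW = 1,
/// dilationW = 1, wSize = 4, wSizeStep = 2 and kwSize = 2: the slices
/// extracted from `input` have shape {n, 2, c} and are taken at offsets
///     [0, 0, 0], [0, 2, 0]   (kw = 0)
///     [0, 1, 0], [0, 3, 0]   (kw = 1)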
static SmallVector<Value>
extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
                       int64_t nSize, int64_t wSize, int64_t cSize,
                       int64_t kwSize, int strideW, int dilationW,
                       int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
    // convolution.
    SmallVector<int64_t> sizes = {wSizeStep};
    SmallVector<int64_t> strides = {1};
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
      }
    }
  } else {
    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
    // for channeled convolution.
    SmallVector<int64_t> sizes = {nSize, wSizeStep, cSize};
    SmallVector<int64_t> strides = {1, 1, 1};
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, input,
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            sizes, strides));
      }
    }
  }
  return result;
}

/// Helper function to extract the filter slices after filter is unrolled along
/// kw.
static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
                                                  Location loc, Value filter,
                                                  int64_t kwSize) {
  SmallVector<Value> result;
  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
  // non-channeled convolution] @ [kw].
  for (int64_t kw = 0; kw < kwSize; ++kw) {
    result.push_back(rewriter.create<vector::ExtractOp>(
        loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
  }
  return result;
}

/// Helper function to extract the result slices after filter is unrolled along
/// kw.
static SmallVector<Value>
extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
                        int64_t nSize, int64_t wSize, int64_t fSize,
                        int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
    SmallVector<int64_t> sizes = {wSizeStep};
    SmallVector<int64_t> strides = {1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
    }
  } else {
    // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
    // convolution.
    SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize};
    SmallVector<int64_t> strides = {1, 1, 1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
    }
  }
  return result;
}

/// Helper function to insert the computed result slices.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
                                    Value res, int64_t wSize, int64_t wSizeStep,
                                    SmallVectorImpl<Value> &resVals,
                                    bool isSingleChanneled) {

  if (isSingleChanneled) {
    // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
    // This does not depend on kw.
    SmallVector<int64_t> strides = {1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
    }
  } else {
    // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
    // convolution. This does not depend on kw.
    SmallVector<int64_t> strides = {1, 1, 1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          strides);
    }
  }
  return res;
}

/// Contains the vectorization state and related methods used across the
/// vectorization process of a given operation.
struct VectorizationState {
  VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}

  /// Initializes the vectorization state, including the computation of the
  /// canonical vector shape for vectorization.
  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          ArrayRef<bool> inputScalableVecDims);

  /// Returns the canonical vector shape used to vectorize the iteration space.
  ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }

  /// Returns the vector dimensions that are scalable in the canonical vector
  /// shape.
  ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }

  /// Returns a vector type of the provided `elementType` with the canonical
  /// vector shape and the corresponding fixed/scalable dimensions bit. If
  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
  /// accordingly.
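  ///
  /// For example (illustrative), with a canonical vector shape of (1, 16, 8)
  /// and `dimPermutation` = (d0, d1, d2) -> (d2, d0), the result is a vector
  /// type of shape 8x1 with the provided `elementType`.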
  VectorType getCanonicalVecType(
      Type elementType,
      std::optional<AffineMap> dimPermutation = std::nullopt) const {
    SmallVector<int64_t> vectorShape;
    SmallVector<bool> scalableDims;
    if (dimPermutation.has_value()) {
      vectorShape =
          applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
      scalableDims =
          applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
    } else {
      vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
      scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
    }

    return VectorType::get(vectorShape, elementType, scalableDims);
  }

  /// Masks an operation with the canonical vector mask if the operation needs
  /// masking. Returns the masked operation or the original operation if masking
  /// is not needed. If provided, the canonical mask for this operation is
  /// permuted using `maybeIndexingMap`.
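  ///
  /// For example (an illustrative sketch with made-up names), masking a
  /// transfer_write yields:
  ///     %res = vector.mask %mask {
  ///       vector.transfer_write %vec, %dest[...] ... : vector<4x8xf32>, tensor<?x?xf32>
  ///     } : vector<4x8xi1> -> tensor<?x?xf32>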
  Operation *
  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                std::optional<AffineMap> maybeIndexingMap = std::nullopt);

private:
  /// Initializes the iteration space static sizes using the Linalg op
  /// information. This may become more complicated in the future.
  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
    iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
  }

  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
  /// all the static and dynamic dimensions of the iteration space to be
  /// vectorized and stores them in `iterSpaceValueSizes`.
  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                              LinalgOp linalgOp);

  /// Create or retrieve an existing mask value to mask `opToMask` in the
  /// canonical vector iteration space. If `maybeMaskingMap` is provided, the
  /// mask is permuted using that permutation map. If a new mask is created, it
  /// will be cached for future users.
  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                           LinalgOp linalgOp,
                           std::optional<AffineMap> maybeMaskingMap);

  /// Check whether this permutation map can be used for masking. At the
  /// moment we only make sure that there are no broadcast dimensions, but this
  /// might change if indexing maps evolve.
  bool isValidMaskingMap(AffineMap maskingMap) {
    return maskingMap.getBroadcastDims().size() == 0;
  }

  /// Turn the input indexing map into a valid masking map.
  ///
  /// The input indexing map may contain "zero" results, e.g.:
  ///    (d0, d1, d2, d3) -> (d2, d1, d0, 0)
  /// Applying such maps to canonical vector shapes like this one:
  ///    (1, 16, 16, 4)
  /// would yield an invalid vector shape like this:
  ///    (16, 16, 1, 0)
  /// Instead, drop the broadcasting dims that make no sense for masking perm.
  /// maps:
  ///    (d0, d1, d2, d3) -> (d2, d1, d0)
  /// This way, the corresponding vector/mask type will be:
  ///    vector<16x16x1xty>
  /// rather than this invalid Vector type:
  ///    vector<16x16x1x0xty>
  AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
    return indexingMap.dropZeroResults();
  }

  // Holds the compile-time static sizes of the iteration space to vectorize.
  // Dynamic dimensions are represented using ShapedType::kDynamic.
  SmallVector<int64_t> iterSpaceStaticSizes;

  /// Holds the value sizes of the iteration space to vectorize. Static
  /// dimensions are represented by 'arith.constant' and dynamic
  /// dimensions by 'tensor/memref.dim'.
  SmallVector<Value> iterSpaceValueSizes;

  /// Holds the canonical vector shape used to vectorize the iteration space.
  SmallVector<int64_t> canonicalVecShape;

  /// Holds the vector dimensions that are scalable in the canonical vector
  /// shape.
  SmallVector<bool> scalableVecDims;

  /// Holds the active masks for permutations of the canonical vector iteration
  /// space.
  DenseMap<AffineMap, Value> activeMaskCache;

  /// Global vectorization guard for the incoming rewriter. It's initialized
  /// when the vectorization state is initialized.
  OpBuilder::InsertionGuard rewriterGuard;
};

LogicalResult
VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                                  LinalgOp linalgOp) {
  // TODO: Support 0-d vectors.
  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
    if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
      // Create constant index op for static dimensions.
      iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
          linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
      continue;
    }

    // Find an operand defined on this dimension of the iteration space to
    // extract the runtime dimension size.
    Value operand;
    unsigned operandDimPos;
    if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
                                                         operandDimPos)))
      return failure();

    Value dynamicDim = linalgOp.hasPureTensorSemantics()
                           ? (Value)rewriter.create<tensor::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos)
                           : (Value)rewriter.create<memref::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos);
    iterSpaceValueSizes.push_back(dynamicDim);
  }

  return success();
}

/// Initializes the vectorization state, including the computation of the
/// canonical vector shape for vectorization.
// TODO: Move this to the constructor when we can remove the failure cases.
LogicalResult
VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes,
                              ArrayRef<bool> inputScalableVecDims) {
  // Initialize the insertion point.
  rewriter.setInsertionPoint(linalgOp);

  if (!inputVectorSizes.empty()) {
    // Get the canonical vector shape from the input vector sizes provided. This
    // path should be taken to vectorize code with dynamic shapes and when using
    // vector sizes greater than the iteration space sizes.
    canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    scalableVecDims.append(inputScalableVecDims.begin(),
                           inputScalableVecDims.end());
  } else {
    // Compute the canonical vector shape from the operation shape. If there are
    // dynamic shapes, the operation won't be vectorized. We assume all the
    // vector dimensions are fixed.
    canonicalVecShape = linalgOp.getStaticLoopRanges();
    scalableVecDims.append(linalgOp.getNumLoops(), false);
  }

  LDBG("Canonical vector shape: ");
  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  LDBG("Scalable vector dims: ");
  LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (ShapedType::isDynamicShape(canonicalVecShape))
    return failure();

  // Initialize iteration space static sizes.
  initIterSpaceStaticSizes(linalgOp);

  // Generate 'arith.constant' and 'tensor/memref.dim' operations for
  // all the static and dynamic dimensions of the iteration space, needed to
  // compute a mask during vectorization.
  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
    return failure();

  return success();
}

/// Create or retrieve an existing mask value to mask `opToMask` in the
/// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask
/// is permuted using that permutation map. If a new mask is created, it will be
/// cached for future users.
Value VectorizationState::getOrCreateMaskFor(
    RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
    std::optional<AffineMap> maybeMaskingMap) {

  assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
         "Ill-formed masking map.");

  // No mask is needed if the operation is not maskable.
  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
  if (!maskableOp)
    return Value();

  assert(!maskableOp.isMasked() &&
         "Masking an operation that is already masked");

  // If no masking map was provided, use an identity map with the loop dims.
  assert((!maybeMaskingMap || *maybeMaskingMap) &&
         "Unexpected null mask permutation map");
  AffineMap maskingMap =
      maybeMaskingMap ? *maybeMaskingMap
                      : AffineMap::getMultiDimIdentityMap(
                            linalgOp.getNumLoops(), rewriter.getContext());

  LDBG("Masking map: " << maskingMap << "\n");

  // Return the active mask for the masking map of this operation if it was
  // already created.
  auto activeMaskIt = activeMaskCache.find(maskingMap);
  if (activeMaskIt != activeMaskCache.end()) {
    Value mask = activeMaskIt->second;
    LDBG("Reusing mask: " << mask << "\n");
    return mask;
  }

  // Compute permuted projection of the iteration space to be masked and the
  // corresponding mask shape. If the resulting iteration space dimensions are
  // static and identical to the mask shape, masking is not needed for this
  // operation.
  // TODO: Improve this check. Only projected permutation indexing maps are
  // supported.
  SmallVector<int64_t> permutedStaticSizes =
      applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
  auto maskShape = maskType.getShape();

  LDBG("Mask shape: ");
  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (permutedStaticSizes == maskShape) {
    LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
    activeMaskCache[maskingMap] = Value();
    return Value();
  }

  // Permute the iteration space value sizes to compute the mask upper bounds.
  SmallVector<Value> upperBounds =
      applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
  assert(!maskShape.empty() && !upperBounds.empty() &&
         "Masked 0-d vectors are not supported yet");

  // Create the mask based on the dimension values.
  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
                                                     maskType, upperBounds);
  LDBG("Creating new mask: " << mask << "\n");
  activeMaskCache[maskingMap] = mask;
  return mask;
}

Operation *
VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
                                  LinalgOp linalgOp,
                                  std::optional<AffineMap> maybeIndexingMap) {
  LDBG("Trying to mask: " << *opToMask << "\n");

  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
  if (maybeIndexingMap)
    maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);

  // Create or retrieve mask for this operation.
  Value mask =
      getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);

  if (!mask) {
    LDBG("No mask required\n");
    return opToMask;
  }

  // Wrap the operation with a new `vector.mask` and update D-U chain.
  assert(opToMask && "Expected a valid operation to mask");
  auto maskOp = cast<vector::MaskOp>(
      mlir::vector::maskOperation(rewriter, opToMask, mask));
  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();

  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
    rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
                                  maskOpTerminator);

  LDBG("Masked operation: " << *maskOp << "\n");
  return maskOp;
}

/// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
/// projectedPermutation, compress the unused dimensions to serve as a
/// permutation_map for a vector transfer operation.
/// For example, given a linalg op such as:
///
/// ```
///   %0 = linalg.generic {
///        indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
///        indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
///      }
///      ins(%0 : tensor<2x3x4xf32>)
///      outs(%1 : tensor<5x6xf32>)
/// ```
///
/// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
/// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
/// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
static AffineMap reindexIndexingMap(AffineMap map) {
  assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
         "expected projected permutation");
  auto res = compressUnusedDims(map);
  assert(res.getNumDims() ==
             (res.getNumResults() - res.getNumOfZeroResults()) &&
         "expected reindexed map with same number of dims and results");
  return res;
}

/// Helper enum to represent conv1d input traversal order.
enum class Conv1DOpOrder {
  W,   // Corresponds to non-channeled 1D convolution operation.
  Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
  Nwc  // Corresponds to operation that traverses the input in (n, w, c) order.
};

/// Helper data structure to represent the result of vectorization for a single
/// operation. In certain specific cases, like terminators, we do not want to
/// propagate.
enum VectorizationHookStatus {
  /// Op failed to vectorize.
  Failure = 0,
  /// Op vectorized and custom function took care of replacement logic.
  NoReplace,
  /// Op vectorized into a new Op whose results will replace original Op's
  /// results.
  NewOp
  // TODO: support values if Op vectorized to Many-Ops whose results we need to
  // aggregate for replacement.
};
/// VectorizationHookResult contains the vectorized op returned from a
/// CustomVectorizationHook. This is an internal implementation detail of
/// linalg vectorization, not to be confused with VectorizationResult.
struct VectorizationHookResult {
  /// Return status from vectorizing the current op.
  enum VectorizationHookStatus status = VectorizationHookStatus::Failure;
  /// New vectorized operation to replace the current op.
  /// Replacement behavior is specified by `status`.
  Operation *newOp;
};

std::optional<vector::CombiningKind>
mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
  using ::mlir::vector::CombiningKind;

  if (!combinerOp)
    return std::nullopt;
  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
      .Case<arith::AddIOp, arith::AddFOp>(
          [&](auto op) { return CombiningKind::ADD; })
      .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
      .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
      .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
      .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
      .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
      .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
      .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
      .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
      .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
      .Case<arith::MulIOp, arith::MulFOp>(
          [&](auto op) { return CombiningKind::MUL; })
      .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
      .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
      .Default([&](auto op) { return std::nullopt; });
}

/// Check whether `outputOperand` is a reduction with a single combiner
/// operation. Return the combiner operation of the reduction. Return
/// nullptr otherwise. Multiple reduction operations would impose an
/// ordering between reduction dimensions and are currently unsupported in
/// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
/// max(min(X))
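///
/// For example (illustrative), for a `linalg.generic` region such as:
///     ^bb0(%in: f32, %out: f32):
///       %sum = arith.addf %in, %out : f32
///       linalg.yield %sum : f32
/// the combiner operation returned for the output operand is the `arith.addf`.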
// TODO: use in LinalgOp verification, there is a circular dependency atm.
static Operation *matchLinalgReduction(OpOperand *outputOperand) {
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  unsigned outputPos =
      outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
  // Only single combiner operations are supported for now.
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
      combinerOps.size() != 1)
    return nullptr;

  // Return the combiner operation.
  return combinerOps[0];
}

/// Broadcast `value` to a vector of `shape` if possible. Return value
/// otherwise.
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
  auto dstVecType = dyn_cast<VectorType>(dstType);
  // If no shape to broadcast to, just return `value`.
  if (dstVecType.getRank() == 0)
    return value;
  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
      vector::BroadcastableToResult::Success)
    return value;
  Location loc = b.getInsertionPoint()->getLoc();
  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
}

/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
/// assumes that `reductionOp` has two operands and one of them is the reduction
/// initial value.
// Note: this is a true builder that notifies the OpBuilder listener.
// TODO: Consider moving as a static helper on the ReduceOp.
static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
                                      Value valueToReduce, Value acc,
                                      ArrayRef<bool> dimsToMask) {
  auto maybeKind = getCombinerOpKind(reduceOp);
  assert(maybeKind && "Failed precondition: could not get reduction kind");
  return b.create<vector::MultiDimReductionOp>(
      reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
}

static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
  return llvm::to_vector(
      llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
}

/// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
/// reduction iterator.
static bool hasReductionIterator(LinalgOp &op) {
  return isa<linalg::ReduceOp>(op) ||
         (isa<linalg::GenericOp>(op) &&
          llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
}

/// Build a vector.transfer_write of `value` into `outputOperand` at indices set
/// to all `0`; where `outputOperand` is an output operand of the LinalgOp
/// currently being vectorized. If `dest` has null rank, build a memref.store.
/// Return the produced value or null if no value is produced.
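///
/// For example (an illustrative sketch with made-up names), for a rank-2
/// output operand this builds something like:
///     %c0 = arith.constant 0 : index
///     %w = vector.transfer_write %value, %out[%c0, %c0]
///            : vector<4x8xf32>, tensor<4x8xf32>
/// possibly wrapped in a `vector.mask` when masking is required.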
// Note: this is a true builder that notifies the OpBuilder listener.
// TODO: Consider moving as a static helper on the ReduceOp.
static Value buildVectorWrite(RewriterBase &rewriter, Value value,
                              OpOperand *outputOperand,
                              VectorizationState &state) {
  Location loc = value.getLoc();
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);

  // Compute the vector type of the value to store. This type should be an
  // identity or projection of the canonical vector type without any permutation
  // applied, given that any permutation in a transfer write happens as part of
  // the write itself.
  AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
      opOperandMap.getContext(), opOperandMap.getNumInputs(),
      [&](AffineDimExpr dimExpr) -> bool {
        return llvm::is_contained(opOperandMap.getResults(), dimExpr);
      });
  auto vectorType = state.getCanonicalVecType(
      getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);

  Operation *write;
  if (vectorType.getRank() > 0) {
    AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
    SmallVector<Value> indices(linalgOp.getRank(outputOperand),
                               rewriter.create<arith::ConstantIndexOp>(loc, 0));
    value = broadcastIfNeeded(rewriter, value, vectorType);
    assert(value.getType() == vectorType && "Incorrect type");
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), indices, writeMap);
  } else {
    // 0-d case is still special: do not invert the reindexing writeMap.
    if (!isa<VectorType>(value.getType()))
      value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
    assert(value.getType() == vectorType && "Incorrect type");
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), ValueRange{});
  }

  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);

  // If masked, set in-bounds to true. Masking guarantees that the access will
  // be in-bounds.
  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
    auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
    SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
    maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
  }

  LDBG("vectorized op: " << *write << "\n");
  if (!write->getResults().empty())
    return write->getResult(0);
  return Value();
}

// Custom vectorization precondition function type. This is intended to be used
// with CustomVectorizationHook. Returns success if the corresponding custom
// hook can vectorize the op.
using CustomVectorizationPrecondition =
    std::function<LogicalResult(Operation *, bool)>;

// Custom vectorization function type. Produce a vector form of Operation*
// assuming all its vectorized operands are already in the IRMapping.
// Return nullptr if the Operation cannot be vectorized.
using CustomVectorizationHook =
    std::function<VectorizationHookResult(Operation *, const IRMapping &)>;

/// Helper function to vectorize the terminator of a `linalgOp`. New result
/// vector values are appended to `newResults`. Return
/// VectorizationHookStatus::NoReplace to signal the vectorization algorithm
/// that it should not try to map produced operations and instead return the
/// results using the `newResults` vector making them available to the
/// vectorization algorithm for RAUW. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationHookResult
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
                     const IRMapping &bvm, VectorizationState &state,
                     LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  if (!yieldOp)
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
    // TODO: Scan for an opportunity for reuse.
    // TODO: use a map.
    Value vectorValue = bvm.lookup(output.value());
    Value newResult =
        buildVectorWrite(rewriter, vectorValue,
                         linalgOp.getDpsInitOperand(output.index()), state);
    if (newResult)
      newResults.push_back(newResult);
  }

  return VectorizationHookResult{VectorizationHookStatus::NoReplace, nullptr};
}

/// Helper function to vectorize the index operations of a `linalgOp`. Return
/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
/// CustomVectorizationHook.
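///
/// An illustrative sketch, for a canonical vector shape of 4x8 and
/// `linalg.index 0` (i.e. a non-trailing dim), the produced sequence is
/// roughly:
///     %step = vector.step : vector<4xindex>
///     %bcast = vector.broadcast %step : vector<4xindex> to vector<8x4xindex>
///     %res = vector.transpose %bcast, [1, 0]
///              : vector<8x4xindex> to vector<4x8xindex>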
static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
                                                    VectorizationState &state,
                                                    Operation *op,
                                                    LinalgOp linalgOp) {
  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
  if (!indexOp)
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
  auto loc = indexOp.getLoc();
  // Compute the static loop sizes of the index op.
  ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
  auto dim = indexOp.getDim();
  // Compute a one-dimensional index vector for the index op dimension.
  auto indexVectorType =
      VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
                      state.getScalableVecDims()[dim]);
  auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
  // Return the one-dimensional index vector if it lives in the trailing
  // dimension of the iteration space since the vectorization algorithm in this
  // case can handle the broadcast.
  if (dim == targetShape.size() - 1)
    return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps};
  // Otherwise permute the targetShape to move the index dimension last,
  // broadcast the one-dimensional index vector to the permuted shape, and
  // finally transpose the broadcasted index vector to undo the permutation.
  auto permPattern =
      llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
  std::swap(permPattern[dim], permPattern.back());
  auto permMap =
      AffineMap::getPermutationMap(permPattern, linalgOp.getContext());

  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
      loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
      indexSteps);
  SmallVector<int64_t> transposition =
      llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
  std::swap(transposition.back(), transposition[dim]);
  auto transposeOp =
      rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
  return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
}

/// Helper function to check if the tensor.extract can be vectorized by the
/// custom hook vectorizeTensorExtract.
static LogicalResult
tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return failure();

  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
    return failure();

  // Check the index type, but only for non 0-d tensors (for which we do need
  // access indices).
  if (not extractOp.getIndices().empty()) {
    if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
      return failure();
  }

  if (!llvm::all_of(extractOp->getResultTypes(),
                    VectorType::isValidElementType)) {
    return failure();
  }

  return success();
}

/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
/// generated from `tensor.extract`. The offset is calculated as follows
/// (example using scalar values):
///
///    offset = extractOp.indices[0]
///    for (i = 1; i < numIndices; i++)
///      offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
///
/// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
///    offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
static Value calculateGatherOffset(RewriterBase &rewriter,
                                   VectorizationState &state,
                                   tensor::ExtractOp extractOp,
                                   const IRMapping &bvm) {
  // The vector of indices for GatherOp should be shaped as the output vector.
  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
  auto loc = extractOp.getLoc();

  Value offset = broadcastIfNeeded(
      rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);

  const size_t numIndices = extractOp.getIndices().size();
  for (size_t i = 1; i < numIndices; i++) {
    Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);

    auto dimSize = broadcastIfNeeded(
        rewriter,
        rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
        indexVecType);

    offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);

    auto extractOpIndex = broadcastIfNeeded(
        rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);

    offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
  }

  return offset;
}

enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };

/// Find the index of the trailing non-unit dim in linalgOp. This hook is used
/// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
/// represents a contiguous load operation.
///
/// Note that when calling this hook, it is assumed that the output vector is
/// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
/// labelled as a gather load before entering this method.
///
/// Following on from the above, it is assumed that:
///  * for statically shaped loops, when no masks are used, only one dim is !=
///    1 (that's what the shape of the output vector is based on).
///  * for dynamically shaped loops, there might be more non-unit dims
///    as the output vector type is user-specified.
///
/// TODO: Statically shaped loops + vector masking
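///
/// For example (illustrative), for static loop ranges (1, 16, 1) this returns
/// 1, i.e. the index of the only non-unit loop dim.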
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
  assert(
      (linalgOp.hasDynamicShape() ||
       llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
      "For statically shaped Linalg Ops, only one "
      "non-unit loop dim is expected");
  assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse.");

  size_t idx = loopRanges.size() - 1;
  for (; idx != 0; idx--)
    if (loopRanges[idx] != 1)
      break;

  return idx;
}

/// Checks whether `val` can be used for calculating a loop invariant index.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
                               VectorType resType) {

  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  // However, analysing block arguments for _this_ linalg.generic Op is a bit
  // tricky. Just bail out in the latter case.
  // TODO: We could try analysing the corresponding affine map here.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return !llvm::is_contained(block->getArguments(), val);

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  // IndexOp is loop invariant as long as its result remains constant across
  // iterations. Note that for dynamic shapes, the corresponding dim will also
  // be conservatively treated as != 1.
  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  // Values defined outside `linalgOp` are loop invariant.
  if (!ancestor)
    return true;

  // Values defined inside `linalgOp`, which are constant, are loop invariant.
  if (isa<arith::ConstantOp>(ancestor))
    return true;

  bool result = true;
  for (auto op : ancestor->getOperands())
    result &= isLoopInvariantIdx(linalgOp, op, resType);

  return result;
}

/// Check whether `val` could be used for calculating the trailing index for a
/// contiguous load operation.
///
/// There are currently 3 types of values that are allowed here:
///   1. loop-invariant values,
///   2. values that increment by 1 with every loop iteration,
///   3. results of basic arithmetic operations (linear and continuous)
///      involving 1., 2. and 3.
/// This method returns True if indeed only such values are used in calculating
/// `val.`
///
/// Additionally, the trailing index for a contiguous load operation should
/// increment by 1 with every loop iteration, i.e. be based on:
///   * `linalg.index <dim>`,
/// where <dim> is the trailing non-unit dim of the iteration space (this way,
/// `linalg.index <dim>` increments by 1 with every loop iteration).
/// `foundIndexOp` is updated to `true` when such an Op is found.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
                                bool &foundIndexOp, VectorType resType) {

  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  // However, analysing block arguments for _this_ linalg.generic Op is a bit
  // tricky. Just bail out in the latter case.
  // TODO: We could try analysing the corresponding affine map here.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return !llvm::is_contained(block->getArguments(), val);

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);

    foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
    return true;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  if (!ancestor)
    return false;

  // Conservatively reject Ops that could lead to indices with stride other
  // than 1.
  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
    return false;

  bool result = false;
  for (auto op : ancestor->getOperands())
    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);

  return result;
}

/// Infer the memory access pattern for the input ExtractOp.
///
/// Based on the ExtractOp result shape and the access indices, decides whether
/// this Op corresponds to a contiguous load (including a broadcast of a scalar)
/// or a gather load. When analysing the ExtractOp indices (to identify
/// contiguous loads), this method looks for "loop" invariant indices (e.g.
/// block arguments) and indices that change linearly (e.g. via `linalg.index`
/// Op).
///
/// Note that it is always safe to use gather load operations for contiguous
/// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
/// that `extractOp` is a gather load.
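///
/// An illustrative sketch, for an output vector of type vector<1x4xi32>:
///   * `tensor.extract %src[%c0, %idx]`, where `%idx` is based on
///     `linalg.index 1` (the trailing non-unit dim), is classified as a
///     contiguous load,
///   * `tensor.extract %src[%i, %j]`, where both indices vary in a
///     data-dependent way, is classified as a gather load.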
static VectorMemoryAccessKind
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
                                    LinalgOp &linalgOp, VectorType resType) {

  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());

  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
  if (inputShape.getShape().empty())
    return VectorMemoryAccessKind::ScalarBroadcast;

  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
  // otherwise.
  bool isOutput1DVector =
      (llvm::count_if(resType.getShape(),
                      [](int64_t dimSize) { return dimSize > 1; }) == 1);
  // 1. Assume that it's a gather load when reading non-1D vector.
  if (!isOutput1DVector)
    return VectorMemoryAccessKind::Gather;

  bool leadingIdxsLoopInvariant = true;

  // 2. Analyze the leading indices of `extractOp`.
  // Look at the way each index is calculated and decide whether it is suitable
  // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
  // gather load.
  auto indices = extractOp.getIndices();
  auto leadIndices = indices.drop_back(1);

  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
    if (inputShape.getShape()[i] == 1)
      continue;

    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
  }

  if (!leadingIdxsLoopInvariant) {
    LDBG("Found gather load: " << extractOp);
    return VectorMemoryAccessKind::Gather;
  }

  // 3. Analyze the trailing index for `extractOp`.
  // At this point we know that the leading indices are loop invariant. This
  // means that it is potentially a scalar or a contiguous load. We can decide
  // based on the trailing idx.
  auto extractOpTrailingIdx = indices.back();

  // 3a. Scalar broadcast load
  // If the trailing index is loop invariant then this is a scalar load.
  if (leadingIdxsLoopInvariant &&
      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
    LDBG("Found scalar broadcast load: " << extractOp);

    return VectorMemoryAccessKind::ScalarBroadcast;
  }

  // 3b. Contiguous loads
  // The trailing `extractOp` index should increment with every loop iteration.
  // This effectively means that it must be based on the trailing loop index.
  // This is what the following bool captures.
  bool foundIndexOp = false;
  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
                                              foundIndexOp, resType);
  // TODO: Support generating contiguous loads for column vectors - that will
  // require adding a permutation map to transfer_read Ops.
  bool isRowVector = resType.getShape().back() != 1;
  isContiguousLoad &= (foundIndexOp && isRowVector);

  if (isContiguousLoad) {
    LDBG("Found contiguous load: " << extractOp);
    return VectorMemoryAccessKind::Contiguous;
  }

  // 4. Fallback case - gather load.
  LDBG("Found gather load: " << extractOp);
  return VectorMemoryAccessKind::Gather;
}

/// Helper function to vectorize the tensor.extract operations. Returns
/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationHookResult
vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                       Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
  auto loc = extractOp.getLoc();

  // Compute the static loop sizes of the extract op.
  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
                                /*value=*/true));
  auto passThruConstantOp =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));

  // Base indices are currently set to 0. We will need to re-visit if more
  // generic scenarios are to be supported.
  SmallVector<Value> baseIndices(
      extractOp.getIndices().size(),
      rewriter.create<arith::ConstantIndexOp>(loc, 0));

  VectorMemoryAccessKind memAccessKind =
      getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);

  // 1. Handle gather access
  if (memAccessKind == VectorMemoryAccessKind::Gather) {
    Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);

    // Generate the gather load
    Operation *gatherOp = rewriter.create<vector::GatherOp>(
        loc, resultType, extractOp.getTensor(), baseIndices, offset,
        maskConstantOp, passThruConstantOp);
    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);

    LDBG("Vectorised as gather load: " << extractOp << "\n");
    return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
  }

  // 2. Handle:
  //  a. scalar loads + broadcast,
  //  b. contiguous loads.
  // Both cases use vector.transfer_read.

  // Collect indices for `vector.transfer_read`. At this point, the indices will
  // either be scalars or would have been broadcast to vectors matching the
  // result type. For indices that are vectors, there are two options:
  //   * for non-trailing indices, all elements are identical (contiguous
  //     loads are identified by looking for non-trailing indices that are
  //     invariant with respect to the corresponding linalg.generic), or
  //   * for trailing indices, the index vector will contain values with stride
  //     one, but for `vector.transfer_read` only the first (i.e. 0th) index is
  //     needed.
  // This means that
  //   * for scalar indices - just re-use it,
  //   * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
  //     (0th) element and use that.
  SmallVector<Value> transferReadIdxs;
  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
    Value idx = bvm.lookup(extractOp.getIndices()[i]);
    if (idx.getType().isIndex()) {
      transferReadIdxs.push_back(idx);
      continue;
    }

    auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
        loc,
        VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
                        resultType.getScalableDims().back()),
        idx);
    transferReadIdxs.push_back(
        rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
  }

  // `tensor.extract_element` is always in-bounds, hence the following holds.
  auto dstRank = resultType.getRank();
  auto srcRank = extractOp.getTensor().getType().getRank();
  SmallVector<bool> inBounds(dstRank, true);

  // 2a. Handle scalar broadcast access.
  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
    MLIRContext *ctx = rewriter.getContext();
    SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
    auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);

    auto transferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, resultType, extractOp.getTensor(), transferReadIdxs,
        /*padding=*/std::nullopt, permutationMap, inBounds);

    // Mask this broadcasting xfer_read here rather than relying on the generic
    // path (the generic path assumes identity masking map, which wouldn't be
    // valid here).
    SmallVector<int64_t> readMaskShape = {1};
    auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
    auto allTrue = rewriter.create<vector::ConstantMaskOp>(
        loc, readMaskType, vector::ConstantMaskKind::AllTrue);
    auto *maskedReadOp =
        mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);

    LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
    return VectorizationHookResult{VectorizationHookStatus::NewOp,
                                   maskedReadOp};
  }

  // 2b. Handle contiguous access.
  auto permutationMap = AffineMap::getMinorIdentityMap(
      srcRank, std::min(dstRank, srcRank), rewriter.getContext());

  int32_t rankDiff = dstRank - srcRank;
  // When dstRank > srcRank, broadcast the source tensor to the unitary leading
  // dims so that the ranks match. This is done by extending the map with 0s.
  // For example, for dstRank = 3, srcRank = 2, the following map created
  // above:
  //    (d0, d1) --> (d0, d1)
  // is extended as:
  //    (d0, d1) --> (0, d0, d1)
  while (rankDiff > 0) {
    permutationMap = permutationMap.insertResult(
        mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
    rankDiff--;
  }

  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
      loc, resultType, extractOp.getTensor(), transferReadIdxs,
      /*padding=*/std::nullopt, permutationMap, inBounds);

  LDBG("Vectorised as contiguous load: " << extractOp);
  return VectorizationHookResult{VectorizationHookStatus::NewOp,
                                 transferReadOp};
}

1235/// Emit reduction operations if the shapes of the value to reduce is different
1236/// that the result shape.
1237// Note: this is a true builder that notifies the OpBuilder listener.
1238// TODO: Consider moving as a static helper on the ReduceOp.
1239static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1240 Value reduceValue, Value initialValue,
1241 const IRMapping &bvm) {
1242 Value reduceVec = bvm.lookup(from: reduceValue);
1243 Value outputVec = bvm.lookup(from: initialValue);
1244 auto reduceType = dyn_cast<VectorType>(Val: reduceVec.getType());
1245 auto outputType = dyn_cast<VectorType>(Val: outputVec.getType());
1246 // Reduce only if needed as the value may already have been reduce for
1247 // contraction vectorization.
1248 if (!reduceType ||
1249 (outputType && reduceType.getShape() == outputType.getShape()))
1250 return nullptr;
1251 SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1252 return buildMultiDimReduce(b, reduceOp: op, valueToReduce: reduceVec, acc: outputVec, dimsToMask);
1253}
1254
/// Generic vectorization for a single operation `op`, given already vectorized
/// operands carried by `bvm`. Vectorization occurs as follows:
///   1. Try to apply any of the `customVectorizationHooks` and return its
///      result on success.
///   2. Clone any constant in the current scope without vectorization: each
///      consumer of the constant will later determine the shape to which the
///      constant needs to be broadcast.
///   3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
///      of the `customVectorizationHooks` to cover such cases.
///   4. Clone `op` in vector form to a vector of shape prescribed by the first
///      operand of maximal rank. Other operands have smaller rank and are
///      broadcast accordingly. It is assumed this broadcast is always legal,
///      otherwise, it means one of the `customVectorizationHooks` is incorrect.
///
/// This function assumes all operands of `op` have been vectorized and are in
/// the `bvm` mapping. As a consequence, this function is meant to be called on
/// a topologically-sorted list of ops.
/// This function does not update `bvm` but returns a VectorizationHookStatus
/// that instructs the caller what `bvm` update needs to occur.
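///
/// For illustration (a rough sketch, not output of any particular test), step
/// 4 applied to a scalar `arith.addf %a, %b : f32` whose operands were
/// vectorized to `vector<4x8xf32>` and `f32` respectively would produce:
///
///   %bcast = vector.broadcast %b : f32 to vector<4x8xf32>
///   %add = arith.addf %a_vec, %bcast : vector<4x8xf32>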
static VectorizationHookResult
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
               LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
               ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
  LDBG("vectorize op " << *op << "\n");

  // 1. Try to apply any CustomVectorizationHook.
  if (!customVectorizationHooks.empty()) {
    for (auto &customFunc : customVectorizationHooks) {
      VectorizationHookResult result = customFunc(op, bvm);
      if (result.status == VectorizationHookStatus::Failure)
        continue;
      return result;
    }
  }

  // 2. Constant ops don't get vectorized but rather broadcasted at their
  // users. Clone so that the constant is not confined to the linalgOp block.
  if (isa<arith::ConstantOp, func::ConstantOp>(op))
    return VectorizationHookResult{VectorizationHookStatus::NewOp,
                                   rewriter.clone(*op)};

  // 3. Only ElementwiseMappable ops are allowed in the generic vectorization.
  if (!OpTrait::hasElementwiseMappableTraits(op))
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};

  // 4. Check if the operation is a reduction.
  SmallVector<std::pair<Value, Value>> reductionOperands;
  for (Value operand : op->getOperands()) {
    auto blockArg = dyn_cast<BlockArgument>(operand);
    if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
        blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
      continue;
    SmallVector<Operation *> reductionOps;
    Value reduceValue = matchReduction(
        linalgOp.getRegionOutputArgs(),
        blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
    if (!reduceValue)
      continue;
    reductionOperands.push_back(std::make_pair(reduceValue, operand));
  }
  if (!reductionOperands.empty()) {
    assert(reductionOperands.size() == 1);
    Operation *reduceOp =
        reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
                       reductionOperands[0].second, bvm);
    if (reduceOp)
      return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp};
  }

  // 5. Generic vectorization path for ElementwiseMappable ops.
  //   a. Get the first max ranked shape.
  VectorType firstMaxRankedType;
  for (Value operand : op->getOperands()) {
    auto vecOperand = bvm.lookup(operand);
    assert(vecOperand && "Vector operand couldn't be found");

    auto vecType = dyn_cast<VectorType>(vecOperand.getType());
    if (vecType && (!firstMaxRankedType ||
                    firstMaxRankedType.getRank() < vecType.getRank()))
      firstMaxRankedType = vecType;
  }
  //   b. Broadcast each operand if needed.
  SmallVector<Value> vecOperands;
  for (Value scalarOperand : op->getOperands()) {
    Value vecOperand = bvm.lookup(scalarOperand);
    assert(vecOperand && "Vector operand couldn't be found");

    if (firstMaxRankedType) {
      auto vecType =
          VectorType::get(firstMaxRankedType.getShape(),
                          getElementTypeOrSelf(vecOperand.getType()),
                          firstMaxRankedType.getScalableDims());
      vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
    } else {
      vecOperands.push_back(vecOperand);
    }
  }
  //   c. For elementwise ops, the result is a vector with the first max
  //      ranked shape.
  SmallVector<Type> resultTypes;
  for (Type resultType : op->getResultTypes()) {
    resultTypes.push_back(
        firstMaxRankedType
            ? VectorType::get(firstMaxRankedType.getShape(), resultType,
                              firstMaxRankedType.getScalableDims())
            : resultType);
  }
  //   d. Build and return the new op.
  return VectorizationHookResult{
      VectorizationHookStatus::NewOp,
      rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
                      resultTypes, op->getAttrs())};
}

/// Generic vectorization function that rewrites the body of a `linalgOp` into
/// vector form. Generic vectorization proceeds as follows:
///   1. Verify the `linalgOp` has one non-empty region.
///   2. Values defined above the region are mapped to themselves and will be
///      broadcasted on a per-need basis by their consumers.
///   3. Each region argument is vectorized into a vector.transfer_read (or 0-d
///      load).
///      TODO: Reuse opportunities for RAR dependencies.
///   4a. Register CustomVectorizationHook for YieldOp to capture the results.
///   4b. Register CustomVectorizationHook for IndexOp to access the
///       iteration indices.
///   5. Iteratively call vectorizeOneOp on the region operations.
///
/// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
/// performed to the maximal common vector size implied by the `linalgOp`
/// iteration space. This eager broadcasting is introduced in the
/// permutation_map of the vector.transfer_read operations. The eager
/// broadcasting makes it trivial to determine where broadcasts, transposes and
/// reductions should occur, without any bookkeeping. The tradeoff is that, in
/// the absence of good canonicalizations, the amount of work increases.
/// This is not deemed a problem as we expect canonicalizations and foldings to
/// aggressively clean up the useless work.
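///
/// For illustration (a rough sketch, not verbatim output), vectorizing a 1-D
/// sum-reduction `linalg.generic` with iterator_types = ["reduction"], an
/// input of tensor<8xf32> and an init of tensor<f32> yields roughly:
///
///   %lhs = vector.transfer_read %in[%c0], %pad : tensor<8xf32>, vector<8xf32>
///   %acc = vector.transfer_read %init[], %pad : tensor<f32>, vector<f32>
///   %acc_scalar = vector.extract %acc[] : f32 from vector<f32>
///   %red = vector.multi_reduction <add>, %lhs, %acc_scalar [0]
///            : vector<8xf32> to f32
///   // ... followed by a vector.transfer_write of %red into the init tensor.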
static LogicalResult
vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
                         LinalgOp linalgOp,
                         SmallVectorImpl<Value> &newResults) {
  LDBG("Vectorizing operation as linalg generic\n");
  Block *block = linalgOp.getBlock();

  // 2. Values defined above the region can only be broadcast for now. Make
  // them map to themselves.
  IRMapping bvm;
  SetVector<Value> valuesSet;
  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());

  if (linalgOp.getNumDpsInits() == 0)
    return failure();

  // 3. Turn all BBArgs into vector.transfer_read / load.
  Location loc = linalgOp.getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
    if (linalgOp.isScalar(opOperand)) {
      bvm.map(bbarg, opOperand->get());
      continue;
    }

    // 3.a. Convert the indexing map for this input/output to a transfer read
    // permutation map and masking map.
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);

    AffineMap readMap;
    VectorType readType;
    Type elemType = getElementTypeOrSelf(opOperand->get());
    if (linalgOp.isDpsInput(opOperand)) {
      // 3.a.i. For input reads we use the canonical vector shape.
      readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
      readType = state.getCanonicalVecType(elemType);
    } else {
      // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
      // reductions), the vector shape is computed by mapping the canonical
      // vector shape to the output domain and back to the canonical domain.
      readMap = inversePermutation(reindexIndexingMap(indexingMap));
      readType =
          state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
    }

    SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);

    Operation *read = rewriter.create<vector::TransferReadOp>(
        loc, readType, opOperand->get(), indices,
        /*padding=*/std::nullopt, readMap);
    read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
    Value readValue = read->getResult(0);

    // 3.b. If masked, set in-bounds to true. Masking guarantees that the
    // access will be in-bounds.
    if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
      SmallVector<bool> inBounds(readType.getRank(), true);
      cast<vector::TransferReadOp>(maskOp.getMaskableOp())
          .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
    }

    // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
    // TODO: remove this.
    if (readType.getRank() == 0)
      readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
                                                     ArrayRef<int64_t>());

    LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
                                 << "\n");
    bvm.map(bbarg, readValue);
    bvm.map(opOperand->get(), readValue);
  }

  SmallVector<CustomVectorizationHook> hooks;
  // 4a. Register CustomVectorizationHook for yieldOp.
  CustomVectorizationHook vectorizeYield =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
    return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
  };
  hooks.push_back(vectorizeYield);

  // 4b. Register CustomVectorizationHook for indexOp.
  CustomVectorizationHook vectorizeIndex =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
    return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
  };
  hooks.push_back(vectorizeIndex);

  // 4c. Register CustomVectorizationHook for extractOp.
  CustomVectorizationHook vectorizeExtract =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
    return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
  };
  hooks.push_back(vectorizeExtract);

  // 5. Iteratively call `vectorizeOneOp` on each op in the slice.
  for (Operation &op : block->getOperations()) {
    VectorizationHookResult result =
        vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
    if (result.status == VectorizationHookStatus::Failure) {
      LDBG("failed to vectorize: " << op << "\n");
      return failure();
    }
    if (result.status == VectorizationHookStatus::NewOp) {
      Operation *maybeMaskedOp =
          state.maskOperation(rewriter, result.newOp, linalgOp);
      LDBG("New vector op: " << *maybeMaskedOp << "\n");
      bvm.map(op.getResults(), maybeMaskedOp->getResults());
    }
  }

  return success();
}

/// Given a linalg::PackOp, return the `dest` shape before any packing
/// permutations.
static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
                                              ArrayRef<int64_t> destShape) {
  return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
}

/// Determines whether a mask for xfer_write is trivially "all true".
///
/// Given all the inputs required to generate a mask (mask sizes and shapes),
/// and an xfer_write operation (write indices and the destination tensor
/// shape), determines whether the corresponding mask would be trivially
/// foldable (i.e., trivially "all true").
///
/// Use this method to avoid generating spurious masks and relying on
/// vectorization post-processing to remove them.
///
/// Pre-conditions for a mask to be trivially foldable:
///   * All involved shapes (mask + destination tensor) are static.
///   * All write indices are constant.
///   * All mask sizes are constant (including `arith.constant`).
///
/// If the pre-conditions are met, the method checks for each destination
/// dimension `d`:
///   (1) maskShape[d] <= destDimSize[rankDiff + d]
///   (2) writeIndex[d] + maskSize[d] <= destDimSize[rankDiff + d]
///
/// rankDiff = rank(dest) - rank(mask).
///
/// This method takes a conservative view: it may return false even if the mask
/// is technically foldable.
///
/// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the shape
/// of the dest tensor):
///   %c0 = arith.constant 0 : index
///   %mask = vector.create_mask 5, 1
///   vector.mask %mask {
///     vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
///       {in_bounds = [true, true]}
///       : vector<5x1xi32>, tensor<5x1xi32>
///   }
///
/// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape,
/// mask is required to avoid out-of-bounds write):
///   %c0 = arith.constant 0 : index
///   %mask = vector.create_mask 5, 1
///   vector.mask %mask {
///     vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
///       {in_bounds = [true, true]}
///       : vector<8x1xi32>, tensor<5x1xi32>
///   }
///
/// TODO: Re-use in createReadOrMaskedRead
static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
                                    SmallVector<Value> &writeIdxs,
                                    ArrayRef<int64_t> destShape,
                                    ArrayRef<int64_t> maskShape) {
  // Masking is unavoidable in the case of dynamic tensors.
  if (ShapedType::isDynamicShape(destShape))
    return false;

  // Collect all constant mask sizes.
  SmallVector<int64_t, 4> cstMaskSizes;
  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
    if (auto intSize = getConstantIntValue(dimSize)) {
      cstMaskSizes.push_back(*intSize);
    }
  }

  // If any of the mask sizes is non-constant, bail out.
  if (cstMaskSizes.size() != maskShape.size())
    return false;

  // Collect all constant write indices.
  SmallVector<int64_t, 4> cstWriteIdxs;
  for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
    APSInt intVal;
    if (matchPattern(idx, m_ConstantInt(&intVal))) {
      cstWriteIdxs.push_back(intVal.getSExtValue());
    }
  }

  // If any of the write indices is non-constant, bail out.
  if (cstWriteIdxs.size() != destShape.size())
    return false;

  // Go over all destination dims and check (1) and (2). Take into account
  // that:
  //  * The number of mask sizes will match the rank of the vector to store.
  //    This could be lower than the rank of the destination tensor.
  //  * Mask sizes could be larger than the corresponding mask shape (hence
  //    `clamp`).
  // TODO: The 2nd item should be rejected by the verifier.
  int64_t rankDiff = destShape.size() - cstMaskSizes.size();
  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
    if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
        /*(2)*/ destShape[rankDiff + i] <
                    (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
                     cstWriteIdxs[i]))
      return false;
  }

  return true;
}

/// Creates an optionally masked TransferWriteOp.
///
/// Generates the following operation:
///   %res = vector.transfer_write %vecToStore into %dest
///
/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
///
///   %mask = vector.create_mask(%destShape) : %vecToStoreShape
///   %res = vector.mask %mask {
///     vector.transfer_write %vecToStore into %dest
///   }
///
/// The mask shape is identical to `vecToStore` (with the element type ==
/// i1), and the mask values are based on the shape of the `dest` tensor.
///
/// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds`
/// attribute is used instead of masking:
///
///   in_bounds_flags = (...)
///   %res = vector.transfer_write %vecToStore into %dest
///       {in_bounds = in_bounds_flags}
///
/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
/// are set to 0.
static Operation *
createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
                         Value dest, SmallVector<Value> writeIndices = {},
                         bool useInBoundsInsteadOfMasking = false) {

  ShapedType destType = cast<ShapedType>(dest.getType());
  int64_t destRank = destType.getRank();
  auto destShape = destType.getShape();

  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
  int64_t vecToStoreRank = vecToStoreType.getRank();
  auto vecToStoreShape = vecToStoreType.getShape();

  // Compute the in_bounds attribute.
  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
  if (useInBoundsInsteadOfMasking) {
    // Update the inBounds attribute.
    // FIXME: This computation is too weak - it ignores the write indices.
    for (unsigned i = 0; i < vecToStoreRank; i++)
      inBoundsVal[i] =
          (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
          ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
  }

  // If missing, initialize the write indices to 0.
  assert((writeIndices.empty() ||
          writeIndices.size() == static_cast<size_t>(destRank)) &&
         "Invalid number of write indices!");
  if (writeIndices.empty()) {
    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
    writeIndices.assign(destRank, zero);
  }

  // Generate the xfer_write op.
  Operation *write =
      builder.create<vector::TransferWriteOp>(loc,
                                              /*vector=*/vecToStore,
                                              /*source=*/dest,
                                              /*indices=*/writeIndices,
                                              /*inBounds=*/inBoundsVal);

  // If masking is disabled, exit.
  if (useInBoundsInsteadOfMasking)
    return write;

  // Check if masking is needed. If not, exit.
  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
    return write;

  // Compute the mask and mask the write op.
  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());

  SmallVector<OpFoldResult> destSizes =
      tensor::getMixedSizes(builder, loc, dest);
  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
                                      destSizes.end());

  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
                              vecToStoreShape))
    return write;

  Value maskForWrite =
      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
  return mlir::vector::maskOperation(builder, write, maskForWrite);
}

/// Vectorize linalg::PackOp with (1) static inner_tiles, (2) constant
/// padding value and (3) input vector sizes into:
///
///   masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
///
/// As in the following example:
///   %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
///       into %dst : tensor<32x7x16xf32> -> tensor<32x4x1x16x2xf32>
///
/// This pack would be vectorized to:
///
///   %load = vector.mask %mask {
///       vector.transfer_read %src[%c0, %c0, %c0], %cst
///           {in_bounds = [true, true, true]} :
///           tensor<32x7x16xf32>, vector<32x8x16xf32>
///   } : vector<32x8x16xi1> -> vector<32x8x16xf32>
///   %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
///                                           to vector<32x4x2x1x16xf32>
///   %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
///       : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
///   %write = vector.transfer_write %transpose,
///       %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
///       {in_bounds = [true, true, true, true, true]}
///       : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
///
/// If the (3) input vector sizes are not provided, the vector sizes are
/// determined by the result tensor shape and the `in_bounds`
/// attribute is used instead of masking to mark out-of-bounds accesses.
///
/// NOTE: The input vector sizes specify the dimensions corresponding to the
/// outer dimensions of the output tensor. The remaining dimensions are
/// computed based on, e.g., the static inner tiles.
/// Supporting dynamic inner tiles will require the user to specify the
/// missing vector sizes. This is left as a TODO.
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
                        ArrayRef<int64_t> inputVectorSizes,
                        SmallVectorImpl<Value> &newResults) {
  // TODO: Introduce a parent class that will handle the insertion point
  // update.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(packOp);

  Location loc = packOp.getLoc();
  auto padValue = packOp.getPaddingValue();
  if (!padValue) {
    padValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
  }
  ReifiedRankedShapedTypeDims reifiedReturnShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
          .reifyResultShapes(rewriter, reifiedReturnShapes);
  (void)status; // prevent unused variable warning on non-assert builds.
  assert(succeeded(status) && "failed to reify result shapes");

  // If the input vector sizes are not provided, then the vector sizes are
  // determined by the result tensor shape. In that case, we update the
  // inBounds attribute instead of masking.
  bool useInBoundsInsteadOfMasking = false;
  if (inputVectorSizes.empty()) {
    ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
    useInBoundsInsteadOfMasking = true;
  }

  // Create masked TransferReadOp.
  SmallVector<int64_t> inputShape(inputVectorSizes);
  auto innerTiles = packOp.getStaticInnerTiles();
  auto innerDimsPos = packOp.getInnerDimsPos();
  auto outerDimsPerm = packOp.getOuterDimsPerm();
  if (!outerDimsPerm.empty())
    applyPermutationToVector(inputShape,
                             invertPermutationVector(outerDimsPerm));
  for (auto [idx, size] : enumerate(innerTiles))
    inputShape[innerDimsPos[idx]] *= size;
  auto maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, packOp.getSource(), inputShape, padValue,
      useInBoundsInsteadOfMasking);

  // Create ShapeCastOp.
  SmallVector<int64_t> destShape(inputVectorSizes);
  destShape.append(innerTiles.begin(), innerTiles.end());
  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
                                       packOp.getDestType().getElementType());
  auto shapeCastOp =
      rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);

  // Create TransposeOp.
  auto destPermutation =
      invertPermutationVector(getPackInverseDestPerm(packOp));
  auto transposeOp = rewriter.create<vector::TransposeOp>(
      loc, shapeCastOp.getResult(), destPermutation);

  // Create TransferWriteOp.
  Value dest = rewriter.create<tensor::EmptyOp>(
      loc, reifiedReturnShapes[0],
      transposeOp.getResult().getType().getElementType());
  Operation *write =
      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest);
  newResults.push_back(write->getResult(0));
  return success();
}

/// Vectorize a `linalg::UnPackOp` into these 4 ops:
///   vector::TransferReadOp  - reads a vector from the source tensor
///   vector::TransposeOp     - transposes the read vector
///   vector::ShapeCastOp     - reshapes the data based on the target
///   vector::TransferWriteOp - writes the result vector to the destination
///                             tensor
/// If the vector sizes are not provided:
///   * the vector sizes are determined by the input operand and attributes,
///   * the inBounds attribute is updated instead of masking.
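///
/// For illustration (a rough sketch; the exact permutation is computed by
/// `getUnPackInverseSrcPerm` below), an unpack such as:
///
///   %unpack = linalg.unpack %src inner_dims_pos = [0, 1] inner_tiles = [2, 8]
///       into %dst : tensor<4x1x2x8xf32> -> tensor<8x8xf32>
///
/// would roughly vectorize to:
///
///   %read = vector.transfer_read %src ...
///       : tensor<4x1x2x8xf32>, vector<4x1x2x8xf32>
///   %tr = vector.transpose %read, [0, 2, 1, 3]
///       : vector<4x1x2x8xf32> to vector<4x2x1x8xf32>
///   %sc = vector.shape_cast %tr : vector<4x2x1x8xf32> to vector<8x8xf32>
///   %write = vector.transfer_write %sc, %empty ...
///       : vector<8x8xf32>, tensor<8x8xf32>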
static LogicalResult
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          SmallVectorImpl<Value> &newResults) {

  // TODO: Introduce a parent class that will handle the insertion point
  // update.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(unpackOp);

  RankedTensorType unpackTensorType = unpackOp.getSourceType();

  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
  bool useInBoundsInsteadOfMasking = false;
  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();

  auto destSize = unpackOp.getDestRank();

  if (!inputVectorSizes.empty())
    assert(inputVectorSizes.size() == destSize &&
           "Incorrect number of input vector sizes");

  // vectorSizes is the shape of the vector that will be used for the final
  // write on the destination tensor. It is set like this: let's say the
  // source tensor has rank 'M' and the dest tensor has rank 'N', where N <= M.
  // Thus:
  // 1. vectorSizes = sourceShape.take_front(N)
  // 2. if outer_dims_perm is present: apply that permutation to vectorSizes.
  // 3. multiply all the locations in vectorSizes pointed to by innerDimPos
  //    by the innerTiles attribute value.
  SmallVector<int64_t> vectorSizes(inputVectorSizes);
  if (vectorSizes.empty()) {
    llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
    if (!outerDimsPerm.empty())
      applyPermutationToVector(vectorSizes, outerDimsPerm);
    for (auto [i, pos] : llvm::enumerate(innerDimPos))
      vectorSizes[pos] *= innerTiles[i];

    useInBoundsInsteadOfMasking = true;
  }

  // readVectorSizes is the size of the vector used to read and apply the
  // mask. It is set like this: let's say the vectorSizes (VS) array has size
  // 'N' and the sourceShape (SS) has size 'M', where M >= N, and the
  // InnerTileSizes (IT) have size M-N.
  // Thus:
  // - initially: readVectorSizes = vectorSizes
  // - divide all the readVectorSizes locations pointed to by innerDimPos
  //   by the innerTiles attribute value.
  // - if outer_dims_perm is present: apply that permutation to
  //   readVectorSizes.
  // - append the remaining shape from SS.
  // E.g. let's say unpackTensorType.getShape() = <8x8x32x16>,
  // inner_dims_pos = [0, 1], inner_tiles = [32, 16], vector_sizes are
  // [512, 128] and outer_dims_perm is [1, 0]. Then the read shape is:
  //   readVectorSizes(initial): [512, 128]
  //   after innerDim adjustment: [512/32, 128/16]
  //                            = [16, 8]
  //   after applying outer_dims_perm: [8, 16]
  //   after appending the rest of the sourceShape: [8, 16, 32, 16]

  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());

  for (auto [index, size] : enumerate(innerTiles)) {
    readVectorSizes[innerDimPos[index]] =
        llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
  }
  if (!outerDimsPerm.empty()) {
    applyPermutationToVector(readVectorSizes, outerDimsPerm);
  }
  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
                         sourceShape.end());

  ReifiedRankedShapedTypeDims reifiedRetShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
          .reifyResultShapes(rewriter, reifiedRetShapes);
  if (status.failed()) {
    LDBG("Unable to reify result shapes of " << unpackOp);
    return failure();
  }
  Location loc = unpackOp->getLoc();

  auto padValue = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));

  // Read the result, masking if necessary. If the transferReadOp shape is not
  // equal to the shape of the source, then a mask is necessary.
  Value readResult = vector::createReadOrMaskedRead(
      rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
      /*useInBoundsInsteadOfMasking=*/false);

  PackingMetadata packMetadata;
  SmallVector<int64_t> lastDimToInsertPosPerm =
      getUnPackInverseSrcPerm(unpackOp, packMetadata);
  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
  RankedTensorType stripMineTensorType =
      RankedTensorType::get(stripMineShape, stripMineElemType);
  // Transpose the appropriate rows to match the output.
  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
      loc, readResult, lastDimToInsertPosPerm);

  // Collapse the vector to the size required by the result.
  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMineTensorType, packMetadata.reassociations);
  mlir::VectorType vecCollapsedType =
      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
      loc, vecCollapsedType, transposeOp->getResult(0));

  // writeVectorSizes has to match the shape_cast shape for dynamic sizes;
  // otherwise the verifier complains that the mask size is invalid.
  SmallVector<int64_t> writeVectorSizes(
      unpackOp.getDestType().hasStaticShape()
          ? vectorSizes
          : shapeCastOp.getResultVectorType().getShape());
  Value dest = rewriter.create<tensor::EmptyOp>(
      loc, reifiedRetShapes[0],
      shapeCastOp.getResult().getType().getElementType());
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, shapeCastOp.getResult(), dest,
      /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
  newResults.push_back(write->getResult(0));
  return success();
}

/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
/// and (3) all-zero lowPad to
///   `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
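///
/// For illustration (a rough sketch, assuming input vector sizes of [8, 16]),
/// a pad such as:
///
///   %pad = tensor.pad %src low[0, 0] high[1, 2] {
///     ^bb0(%i: index, %j: index):
///       tensor.yield %cst : f32
///   } : tensor<7x14xf32> to tensor<8x16xf32>
///
/// would roughly vectorize to:
///
///   %read = vector.mask %m {
///     vector.transfer_read %src[%c0, %c0], %cst
///   } : ... -> vector<8x16xf32>
///   %write = vector.transfer_write %read, %empty[%c0, %c0]
///       {in_bounds = [true, true]} : vector<8x16xf32>, tensor<8x16xf32>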
static LogicalResult
vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
                       ArrayRef<int64_t> inputVectorSizes,
                       SmallVectorImpl<Value> &newResults) {
  auto padValue = padOp.getConstantPaddingValue();
  Location loc = padOp.getLoc();

  // TODO: Introduce a parent class that will handle the insertion point
  // update.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(padOp);

  ReifiedRankedShapedTypeDims reifiedReturnShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
          .reifyResultShapes(rewriter, reifiedReturnShapes);
  (void)status; // prevent unused variable warning on non-assert builds
  assert(succeeded(status) && "failed to reify result shapes");
  auto maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
      /*useInBoundsInsteadOfMasking=*/false);

  // Create the xfer_write op.
  Value dest = rewriter.create<tensor::EmptyOp>(
      loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
  Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
  newResults.push_back(write->getResult(0));
  return success();
}

// TODO: probably need some extra checks for reduction followed by consumer
// ops that may not commute (e.g. linear reduction + non-linear instructions).
static LogicalResult reductionPreconditions(LinalgOp op) {
  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
    LDBG("reduction precondition failed: no reduction iterator\n");
    return failure();
  }
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
    if (indexingMap.isPermutation())
      continue;

    Operation *reduceOp = matchLinalgReduction(&opOperand);
    if (!reduceOp || !getCombinerOpKind(reduceOp)) {
      LDBG("reduction precondition failed: reduction detection failed\n");
      return failure();
    }
  }
  return success();
}

static LogicalResult
vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
                                   bool flatten1DDepthwiseConv) {
  if (flatten1DDepthwiseConv) {
    LDBG("Vectorization of flattened convs with dynamic shapes is not "
         "supported\n");
    return failure();
  }

  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
    LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
    return failure();
  }

  // Support dynamic shapes in 1D depthwise convolution, but only in the
  // _channel_ dimension.
  Value lhs = conv.getDpsInputOperand(0)->get();
  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
  auto shapeWithoutCh = lhsShape.drop_back(1);
  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
    LDBG("Dynamically-shaped op vectorization precondition failed: only "
         "channel dim can be dynamic\n");
    return failure();
  }

  return success();
}

static LogicalResult
vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
                                     bool flatten1DDepthwiseConv) {
  if (isa<ConvolutionOpInterface>(op.getOperation()))
    return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);

  if (hasReductionIterator(op))
    return reductionPreconditions(op);

  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
  // linalg.copy ops and ops that implement ContractionOpInterface for now.
  if (!isElementwise(op) &&
      !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
          op.getOperation()))
    return failure();

  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
  return success();
}

/// The inner tiles must be static/constant for the unpack op to be
/// vectorized.
static LogicalResult
vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
                              ArrayRef<int64_t> inputVectorSizes) {

  if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
        return !getConstantIntValue(res).has_value();
      })) {
    LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
    return failure();
  }
  ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
  bool satisfyEmptyCond = inputVectorSizes.empty() &&
                          unpackOp.getDestType().hasStaticShape() &&
                          unpackOp.getSourceType().hasStaticShape();
  if (!satisfyEmptyCond &&
      failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
    return failure();

  return success();
}

static LogicalResult
vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
                                   ArrayRef<int64_t> inputVectorSizes) {

  TypedValue<RankedTensorType> source = sliceOp.getSource();
  auto sourceType = source.getType();
  if (!VectorType::isValidElementType(sourceType.getElementType()))
    return failure();

  // Get the pad value.
  // TransferReadOp (which is used to vectorize InsertSliceOp) requires a
  // scalar padding value. Note that, for in-bounds accesses, the value is
  // actually irrelevant. There are 2 cases in which xfer.read accesses are
  // known to be in-bounds:
  //   1. The source shape is static (output vector sizes would be based on
  //      the source shape and hence all memory accesses would be in-bounds),
  //   2. Masking is used, i.e. the output vector sizes are user-provided. In
  //      this case it is safe to assume that all memory accesses are
  //      in-bounds.
  //
  // When the value is not known and not needed, use 0. Otherwise, bail out.
  Value padValue = getStaticPadVal(sliceOp);
  bool isOutOfBoundsRead =
      !sourceType.hasStaticShape() && inputVectorSizes.empty();

  if (!padValue && isOutOfBoundsRead) {
    LDBG("Failed to get a pad value for out-of-bounds read access\n");
    return failure();
  }
  return success();
}

namespace {
enum class ConvOperationKind { Conv, Pool };
} // namespace

static bool isCastOfBlockArgument(Operation *op) {
  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
         isa<BlockArgument>(op->getOperand(0));
}

// Returns the ConvOperationKind of the op using the reduceOp of the generic
// payload. If it is neither a convolution nor a pooling, returns std::nullopt.
//
// If the region has 2 ops (reduction + yield) or 3 ops (extension + reduction
// + yield) and rhs is not used, then it is the body of a pooling.
// If conv, check for a single `mul` predecessor. The `mul` operands must be
// block arguments or extensions of block arguments.
// Otherwise, check for one or zero `ext` predecessors. The `ext` operands
// must be block arguments or extensions of block arguments.
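//
// For illustration (hypothetical generic bodies, not tied to a specific named
// op), a convolution body multiplies the input with the filter before
// reducing:
//   ^bb0(%in: f32, %filter: f32, %out: f32):
//     %mul = arith.mulf %in, %filter : f32
//     %add = arith.addf %out, %mul : f32
//     linalg.yield %add : f32
// whereas a pooling body reduces a (possibly extended) block argument
// directly:
//   ^bb0(%in: f32, %out: f32):
//     %max = arith.maximumf %out, %in : f32
//     linalg.yield %max : f32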
static std::optional<ConvOperationKind>
getConvOperationKind(Operation *reduceOp) {
  int numBlockArguments =
      llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);

  switch (numBlockArguments) {
  case 1: {
    // Will be convolution if feeder is a MulOp.
    // A strength reduced version of MulOp for i1 type is AndOp which is also
    // supported. Otherwise, it can be pooling. This strength reduction logic
    // is in `buildBinaryFn` helper in the Linalg dialect.
    auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                       llvm::IsaPred<BlockArgument>);
    assert(feedValIt != reduceOp->operand_end() &&
           "Expected a non-block argument operand");
    Operation *feedOp = (*feedValIt).getDefiningOp();
    if (isCastOfBlockArgument(feedOp)) {
      return ConvOperationKind::Pool;
    }

    if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
           (isa<arith::AndIOp>(feedOp) &&
            feedOp->getResultTypes()[0].isInteger(1))) &&
          llvm::all_of(feedOp->getOperands(), [](Value v) {
            if (isa<BlockArgument>(v))
              return true;
            if (Operation *op = v.getDefiningOp())
              return isCastOfBlockArgument(op);
            return false;
          }))) {
      return std::nullopt;
    }

    return ConvOperationKind::Conv;
  }
  case 2:
    // Must be pooling.
    return ConvOperationKind::Pool;
  default:
    return std::nullopt;
  }
}

static bool isSupportedPoolKind(vector::CombiningKind kind) {
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::MAXNUMF:
  case vector::CombiningKind::MAXIMUMF:
  case vector::CombiningKind::MAXSI:
  case vector::CombiningKind::MAXUI:
  case vector::CombiningKind::MINNUMF:
  case vector::CombiningKind::MINIMUMF:
  case vector::CombiningKind::MINSI:
  case vector::CombiningKind::MINUI:
    return true;
  default:
    return false;
  }
}

static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
  auto getOperandType = [&](auto operand) {
    return dyn_cast<ShapedType>((operand->get()).getType());
  };
  ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
  ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
  ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
  // (non-channeled convolution -> LHS and RHS both have single dimensions).
  // Note that this also ensures 2D and 3D convolutions are rejected.
  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
      (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
    return failure();

  Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
  if (!reduceOp)
    return failure();

  auto maybeOper = getConvOperationKind(reduceOp);
  if (!maybeOper.has_value())
    return failure();

  auto maybeKind = getCombinerOpKind(reduceOp);
  // Typically convolution will have an `Add` CombiningKind but for i1 type it
  // can get strength reduced to `OR` which is also supported. This strength
  // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
                      *maybeKind != vector::CombiningKind::OR) &&
                     (*maybeOper != ConvOperationKind::Pool ||
                      !isSupportedPoolKind(*maybeKind)))) {
    return failure();
  }

  auto rhsRank = rhsShapedType.getRank();
  if (*maybeOper == ConvOperationKind::Pool) {
    if (rhsRank != 1)
      return failure();
  } else {
    if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
      return failure();
  }

  return success();
}

static LogicalResult vectorizeLinalgOpPrecondition(
    LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
    bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
  // Tensors with a dimension of 0 cannot be vectorized.
  if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
        return llvm::is_contained(linalgOp.getShape(&operand), 0);
      }))
    return failure();
  // Check API contract for input vector sizes.
  if (!inputVectorSizes.empty() &&
      failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
                                              inputVectorSizes)))
    return failure();

  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
                                        linalgOp, flatten1DDepthwiseConv))) {
    LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
    return failure();
  }

  SmallVector<CustomVectorizationPrecondition> customPreconditions;

  // Register CustomVectorizationPrecondition for extractOp.
  customPreconditions.push_back(tensorExtractVectorizationPrecondition);

  // All types in the body should be a supported element type for VectorType.
  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
    // Check if any custom hook can vectorize the inner op.
    if (llvm::any_of(
            customPreconditions,
            [&](const CustomVectorizationPrecondition &customPrecondition) {
              return succeeded(
                  customPrecondition(&innerOp, vectorizeNDExtract));
            })) {
      continue;
    }
    if (!llvm::all_of(innerOp.getOperandTypes(),
                      VectorType::isValidElementType)) {
      return failure();
    }
    if (!llvm::all_of(innerOp.getResultTypes(),
                      VectorType::isValidElementType)) {
      return failure();
    }
  }
  if (isElementwise(linalgOp))
    return success();

  // TODO: isaConvolutionOpInterface that can also infer from generic
  // features. But we will still need stride/dilation attributes that will be
  // annoying to reverse-engineer...
  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
    return vectorizeConvOpPrecondition(linalgOp);

  // TODO: the common vector shape is equal to the static loop sizes only when
  // all indexing maps are projected permutations. For convs and stencils the
  // logic will need to evolve.
  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
    LDBG("precondition failed: not projected permutations\n");
    return failure();
  }
  if (failed(reductionPreconditions(linalgOp))) {
    LDBG("precondition failed: reduction preconditions\n");
    return failure();
  }
  return success();
}

static LogicalResult
vectorizePackOpPrecondition(linalg::PackOp packOp,
                            ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = packOp.getPaddingValue();
  Attribute cstAttr;
  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
    LDBG("pad value is not constant: " << packOp << "\n");
    return failure();
  }
  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
  bool satisfyEmptyCond = true;
  if (inputVectorSizes.empty()) {
    if (!packOp.getDestType().hasStaticShape() ||
        !packOp.getSourceType().hasStaticShape())
      satisfyEmptyCond = false;
  }

  if (!satisfyEmptyCond &&
      failed(vector::isValidMaskedInputVector(
          resultTensorShape.take_front(packOp.getSourceRank()),
          inputVectorSizes)))
    return failure();

  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
        return !getConstantIntValue(v).has_value();
      })) {
    LDBG("inner_tiles must be constant: " << packOp << "\n");
    return failure();
  }

  return success();
}

static LogicalResult
vectorizePadOpPrecondition(tensor::PadOp padOp,
                           ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = padOp.getConstantPaddingValue();
  if (!padValue) {
    LDBG("pad value is not constant: " << padOp << "\n");
    return failure();
  }

  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
                                              inputVectorSizes)))
    return failure();

  // Padding with non-zero low pad values is not supported, unless the
  // corresponding result dim is 1, as that would require shifting the results
  // to the right for the low padded dims by the required amount of low
  // padding. However, we do support low padding if the dims being low padded
  // have result sizes of 1: when there is a low pad on a unit result dim, the
  // input size of that dimension will be dynamically zero (as the sum of the
  // low pad and input dim size has to be one), and hence we will create a
  // zero mask, as the lowering logic just makes the mask one for the input
  // dim size - which is zero here. Hence we will load the pad value, which is
  // what we want in this case. If the low pad is dynamically zero, the
  // lowering is correct as well, as no shifts are necessary.
  if (llvm::any_of(llvm::enumerate(padOp.getLow()), [&](const auto &en) {
        Value padValue = en.value();
        unsigned pos = en.index();
        std::optional<int64_t> pad = getConstantIntValue(padValue);
        return (!pad.has_value() || pad.value() != 0) &&
               resultTensorShape[pos] != 1;
      })) {
    LDBG("low pad must all be zero for all non unit dims: " << padOp << "\n");
    return failure();
  }

  return success();
}

/// Preconditions for scalable vectors. This is quite restrictive - it models
/// the fact that in practice we would only make selected dimensions scalable.
static LogicalResult
vectorizeScalableVectorPrecondition(Operation *op,
                                    ArrayRef<int64_t> inputVectorSizes,
                                    ArrayRef<bool> inputScalableVecDims) {
  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
         "Number of input vector sizes and scalable dims doesn't match");

  size_t numOfScalableDims =
      llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });

  if (numOfScalableDims == 0)
    return success();

  auto linalgOp = dyn_cast<LinalgOp>(op);

  // Cond 1: There's been no need for scalable vectorisation of
  // non-linalg Ops so far.
  if (!linalgOp)
    return failure();

  // Cond 2: There's been no need for more than 2 scalable dims so far.
  if (numOfScalableDims > 2)
    return failure();

  // Cond 3: Look at the configuration in `inputScalableVecDims` and verify
  // that it matches one of the supported cases:
  //  1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
  //     (*).
  //  2. Exactly 2 dims are scalable and those are the _last two adjacent_
  //     parallel dims.
  //  3. Exactly 1 reduction dim is scalable and that's the last (innermost)
  //     dim.
  // The 2nd restriction above means that only Matmul-like Ops are supported
  // when 2 dims are scalable, e.g. :
  //    * iterators = [parallel, parallel, reduction]
  //    * scalable flags = [true, true, false]
  //
  // (*) Unit dims get folded away in practice, hence only the last _non-unit_
  //     parallel dim matters.
  // TODO: Relax these conditions as good motivating examples are identified.

  // Find the first scalable flag.
  bool seenNonUnitParallel = false;
  auto iterators = linalgOp.getIteratorTypesArray();
  SmallVector<bool> scalableFlags(inputScalableVecDims);
  int64_t idx = scalableFlags.size() - 1;
  while (!scalableFlags[idx]) {
    bool isNonUnitDim = (inputVectorSizes[idx] != 1);
    seenNonUnitParallel |=
        (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);

    iterators.pop_back();
    scalableFlags.pop_back();
    --idx;
  }

  // Analyze the iterator corresponding to the first scalable dim.
  switch (iterators.back()) {
  case utils::IteratorType::reduction: {
    // Check that 3. above is met.
    if (iterators.size() != inputVectorSizes.size()) {
      LDBG("Non-trailing reduction dim requested for scalable "
           "vectorization\n");
      return failure();
    }
    if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
      LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
           "is not supported\n");
      return failure();
    }
    break;
  }
  case utils::IteratorType::parallel: {
    // Check that 1. and 2. above are met.
    if (seenNonUnitParallel) {
      LDBG("Inner parallel dim not requested for scalable "
           "vectorization\n");
      return failure();
    }
    break;
  }
  }

  // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
  // supported, for which we expect the following config:
  //    * iterators = [parallel, parallel, reduction]
  //    * scalable flags = [true, true, false]
  if (numOfScalableDims == 2) {
    // Disallow the case below, which breaks 3. above:
    //    * iterators = [..., parallel, reduction]
    //    * scalable flags = [..., true, true]
    if (iterators.back() == utils::IteratorType::reduction) {
      LDBG("Higher dim than the trailing reduction dim requested for scalable "
           "vectorization\n");
      return failure();
    }
    scalableFlags.pop_back();
    iterators.pop_back();

    if (!scalableFlags.back() ||
        (iterators.back() != utils::IteratorType::parallel))
      return failure();
  }

  // Bail out on matmuls with extended semantics (user-defined indexing maps),
  // which are not supported by this transform.
  if (linalgOp.hasUserDefinedMaps())
    return failure();

  // Cond 4: Only the following ops are supported in the
  // presence of scalable vectors.
  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                 isa<linalg::MatmulTransposeAOp>(op) ||
                 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
                 isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
}

LogicalResult mlir::linalg::vectorizeOpPrecondition(
    Operation *op, ArrayRef<int64_t> inputVectorSizes,
    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
    bool flatten1DDepthwiseConv) {

  if (!hasVectorizationImpl(op))
    return failure();

  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
                                                 inputScalableVecDims)))
    return failure();

  return TypeSwitch<Operation *, LogicalResult>(op)
      .Case<linalg::LinalgOp>([&](auto linalgOp) {
        return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
                                             vectorizeNDExtract,
                                             flatten1DDepthwiseConv);
      })
      .Case<tensor::PadOp>([&](auto padOp) {
        return vectorizePadOpPrecondition(padOp, inputVectorSizes);
      })
      .Case<linalg::PackOp>([&](auto packOp) {
        return vectorizePackOpPrecondition(packOp, inputVectorSizes);
      })
      .Case<linalg::UnPackOp>([&](auto unpackOp) {
        return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
      })
      .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
        return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
      })
      .Default([](auto) { return failure(); });
}

/// Converts affine.apply Ops to arithmetic operations.
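///
/// For illustration (a sketch; the exact expansion is produced by
/// `affine::expandAffineExpr`), an op such as:
///
///   %0 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%i)[%j]
///
/// is rewritten into plain arithmetic:
///
///   %0 = arith.addi %i, %j : index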
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
  OpBuilder::InsertionGuard g(rewriter);
  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();

  for (auto op : make_early_inc_range(toReplace)) {
    rewriter.setInsertionPoint(op);
    auto expanded = affine::expandAffineExpr(
        rewriter, op->getLoc(), op.getAffineMap().getResult(0),
        op.getOperands().take_front(op.getAffineMap().getNumDims()),
        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
    rewriter.replaceOp(op, expanded);
  }
}

bool mlir::linalg::hasVectorizationImpl(Operation *op) {
  return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
             tensor::InsertSliceOp>(op);
}

FailureOr<VectorizationResult>
mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                        ArrayRef<int64_t> inputVectorSizes,
                        ArrayRef<bool> inputScalableVecDims,
                        bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
  LDBG("Attempting to vectorize:\n" << *op << "\n");
  LDBG("Input vector sizes: ");
  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  LDBG("Input scalable vector dims: ");
  LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
                                     vectorizeNDExtract,
                                     flatten1DDepthwiseConv))) {
    LDBG("Vectorization pre-conditions failed\n");
    return failure();
  }

  // Initialize vectorization state.
  VectorizationState state(rewriter);
  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
    if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
                               inputScalableVecDims))) {
      LDBG("Vectorization state couldn't be initialized\n");
      return failure();
    }
  }

  SmallVector<Value> results;
  auto vectorizeResult =
      TypeSwitch<Operation *, LogicalResult>(op)
          .Case<linalg::LinalgOp>([&](auto linalgOp) {
            // TODO: isaConvolutionOpInterface that can also infer from
            // generic features. Will require stride/dilation attributes
            // inference.
            if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
              FailureOr<Operation *> convOr = vectorizeConvolution(
                  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
                  flatten1DDepthwiseConv);
              if (succeeded(convOr)) {
                llvm::append_range(results, (*convOr)->getResults());
                return success();
              }

              LDBG("Unsupported convolution can't be vectorized.\n");
              return failure();
            }

            LDBG("Vectorize generic by broadcasting to the canonical vector "
                 "shape\n");

            // Pre-process before proceeding.
            convertAffineApply(rewriter, linalgOp);

            // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
            // to 'OpBuilder' when it is passed over to some methods like
            // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
            // erase an op within these methods, the actual rewriter won't be
            // notified and we will end up with read-after-free issues!
            return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
          })
          .Case<tensor::PadOp>([&](auto padOp) {
            return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
                                          results);
          })
          .Case<linalg::PackOp>([&](auto packOp) {
            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
                                           results);
          })
          .Case<linalg::UnPackOp>([&](auto unpackOp) {
            return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
                                             inputVectorSizes, results);
          })
          .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
            return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
                                            results);
          })
          .Default([](auto) { return failure(); });

  if (failed(vectorizeResult)) {
    LDBG("Vectorization failed\n");
    return failure();
  }

  return VectorizationResult{results};
}
2619
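/// Vectorize a memref.copy with statically shaped source and target by
/// rewriting it as a full-size vector.transfer_read followed by a
/// vector.transfer_write. Roughly (illustrative sketch assuming a 4x8xf32
/// copy):
/// ```
///   memref.copy %src, %dst : memref<4x8xf32> to memref<4x8xf32>
/// ```
/// becomes
/// ```
///   %v = vector.transfer_read %src[%c0, %c0], %pad
///     : memref<4x8xf32>, vector<4x8xf32>
///   vector.transfer_write %v, %dst[%c0, %c0]
///     : vector<4x8xf32>, memref<4x8xf32>
/// ```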
2620LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2621 memref::CopyOp copyOp) {
2622 auto srcType = cast<MemRefType>(Val: copyOp.getSource().getType());
2623 auto dstType = cast<MemRefType>(Val: copyOp.getTarget().getType());
2624 if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2625 return failure();
2626
2627 auto srcElementType = getElementTypeOrSelf(type: srcType);
2628 auto dstElementType = getElementTypeOrSelf(type: dstType);
2629 if (!VectorType::isValidElementType(t: srcElementType) ||
2630 !VectorType::isValidElementType(t: dstElementType))
2631 return failure();
2632
2633 auto readType = VectorType::get(shape: srcType.getShape(), elementType: srcElementType);
2634 auto writeType = VectorType::get(shape: dstType.getShape(), elementType: dstElementType);
2635
2636 Location loc = copyOp->getLoc();
2637 Value zero = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
2638 SmallVector<Value> indices(srcType.getRank(), zero);
2639
2640 Value readValue = rewriter.create<vector::TransferReadOp>(
2641 location: loc, args&: readType, args: copyOp.getSource(), args&: indices,
2642 /*padding=*/args: std::nullopt,
2643 args: rewriter.getMultiDimIdentityMap(rank: srcType.getRank()));
2644 if (cast<VectorType>(Val: readValue.getType()).getRank() == 0) {
2645 readValue =
2646 rewriter.create<vector::ExtractOp>(location: loc, args&: readValue, args: ArrayRef<int64_t>());
2647 readValue = rewriter.create<vector::BroadcastOp>(location: loc, args&: writeType, args&: readValue);
2648 }
2649 Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
2650 location: loc, args&: readValue, args: copyOp.getTarget(), args&: indices,
2651 args: rewriter.getMultiDimIdentityMap(rank: srcType.getRank()));
2652 rewriter.replaceOp(op: copyOp, newValues: writeValue->getResults());
2653 return success();
2654}
2655
2656//----------------------------------------------------------------------------//
2657// Misc. vectorization patterns.
2658//----------------------------------------------------------------------------//
2659/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2660/// given operation type OpTy.
2661template <typename OpTy>
2662struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2663 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2664
2665 LogicalResult matchAndRewrite(tensor::PadOp padOp,
2666 PatternRewriter &rewriter) const final {
2667 bool changed = false;
2668 // Insert users in vector, because some users may be replaced/removed.
2669 for (auto *user : llvm::to_vector<4>(Range: padOp->getUsers()))
2670 if (auto op = dyn_cast<OpTy>(user))
2671 changed |= rewriteUser(rewriter, padOp, op).succeeded();
2672 return success(IsSuccess: changed);
2673 }
2674
2675protected:
2676 virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2677 tensor::PadOp padOp, OpTy op) const = 0;
2678};
2679
2680/// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2681/// ```
2682/// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2683/// %r = vector.transfer_read %0[%c0, %c0], %cst
2684/// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2685/// ```
2686/// is rewritten to:
2687/// ```
2688/// %r = vector.transfer_read %src[%c0, %c0], %padding
2689/// {in_bounds = [true, true]}
2690/// : tensor<?x?xf32>, vector<17x5xf32>
2691/// ```
2692/// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2693/// sure that the original padding value %cst was never used.
2694///
2695/// This rewrite is possible if:
2696/// - `xferOp` has no out-of-bounds dims or mask.
2697/// - Low padding is static 0.
2698/// - Single, scalar padding value.
2699struct PadOpVectorizationWithTransferReadPattern
2700 : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2701 using VectorizePadOpUserPattern<
2702 vector::TransferReadOp>::VectorizePadOpUserPattern;
2703
2704 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2705 vector::TransferReadOp xferOp) const override {
2706 // Low padding must be static 0.
2707 if (!padOp.hasZeroLowPad())
2708 return failure();
2709 // Pad value must be a constant.
2710 auto padValue = padOp.getConstantPaddingValue();
2711 if (!padValue)
2712 return failure();
2713 // Padding value of existing `xferOp` is unused.
2714 if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2715 return failure();
2716
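    // Rewrite in place: read directly from the unpadded source. Since the
    // source may be smaller than the padded result, conservatively reset
    // `in_bounds` to all-false and reuse the pad value as the read padding.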
2717 rewriter.modifyOpInPlace(root: xferOp, callable: [&]() {
2718 SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2719 xferOp->setAttr(name: xferOp.getInBoundsAttrName(),
2720 value: rewriter.getBoolArrayAttr(values: inBounds));
2721 xferOp.getBaseMutable().assign(value: padOp.getSource());
2722 xferOp.getPaddingMutable().assign(value: padValue);
2723 });
2724
2725 return success();
2726 }
2727};
2728
2729/// Rewrite use of tensor::PadOp result in TransferWriteOp.
2730/// This pattern rewrites TransferWriteOps that write to a padded tensor
2731/// value, where the same amount of padding is immediately removed again after
2732/// the write. In such cases, the TransferWriteOp can write to the non-padded
2733/// tensor value and apply out-of-bounds masking. E.g.:
2734/// ```
2735/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2736/// : tensor<...> to tensor<?x?xf32>
2737/// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2738/// %2 = vector.transfer_write %vec, %1[...]
2739/// : vector<17x5xf32>, tensor<17x5xf32>
2740/// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2741/// : tensor<17x5xf32> to tensor<?x?xf32>
2742/// ```
2743/// is rewritten to:
2744/// ```
2745/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2746/// : tensor<...> to tensor<?x?xf32>
2747/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2748/// tensor<?x?xf32>
2749/// ```
2750/// Note: It is important that the ExtractSliceOp %r resizes the result of the
2751/// TransferWriteOp to the same size as the input of the TensorPadOp (or an
2752/// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2753/// from %r's old dimensions.
2754///
2755/// This rewrite is possible if:
2756/// - Low padding is static 0.
2757/// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2758/// ExtractSliceOp trims the same amount of padding that was added
2759/// beforehand.
2760/// - Single, scalar padding value.
2761struct PadOpVectorizationWithTransferWritePattern
2762 : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2763 using VectorizePadOpUserPattern<
2764 vector::TransferWriteOp>::VectorizePadOpUserPattern;
2765
2766 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2767 vector::TransferWriteOp xferOp) const override {
2768 // TODO: support 0-d corner case.
2769 if (xferOp.getTransferRank() == 0)
2770 return failure();
2771
2772 // Low padding must be static 0.
2773 if (!padOp.hasZeroLowPad())
2774 return failure();
2775 // Pad value must be a constant.
2776 auto padValue = padOp.getConstantPaddingValue();
2777 if (!padValue)
2778 return failure();
2779 // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2780 if (!xferOp->hasOneUse())
2781 return failure();
2782 auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(Val: *xferOp->user_begin());
2783 if (!trimPadding)
2784 return failure();
2785 // Only static zero offsets supported when trimming padding.
2786 if (!trimPadding.hasZeroOffset())
2787 return failure();
2788 // trimPadding must remove the amount of padding that was added earlier.
2789 if (!hasSameTensorSize(beforePadding: padOp.getSource(), afterTrimming: trimPadding))
2790 return failure();
2791
2792 // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2793 rewriter.setInsertionPoint(xferOp);
2794
2795 SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2796 auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2797 op: xferOp, args: padOp.getSource().getType(), args: xferOp.getVector(),
2798 args: padOp.getSource(), args: xferOp.getIndices(), args: xferOp.getPermutationMapAttr(),
2799 args: xferOp.getMask(), args: rewriter.getBoolArrayAttr(values: inBounds));
2800 rewriter.replaceOp(op: trimPadding, newValues: newXferOp->getResult(idx: 0));
2801
2802 return success();
2803 }
2804
2805 /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2806 /// i.e., same dimensions.
2807 ///
2808  /// Dimensions may be static, dynamic, or a mix of both. In case of dynamic
2809 /// dimensions, this function tries to infer the (static) tensor size by
2810 /// looking at the defining op and utilizing op-specific knowledge.
2811 ///
2812 /// This is a conservative analysis. In case equal tensor sizes cannot be
2813 /// proven statically, this analysis returns `false` even though the tensor
2814 /// sizes may turn out to be equal at runtime.
2815 bool hasSameTensorSize(Value beforePadding,
2816 tensor::ExtractSliceOp afterTrimming) const {
2817 // If the input to tensor::PadOp is a CastOp, try with both CastOp
2818 // result and CastOp operand.
2819 if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2820 if (hasSameTensorSize(beforePadding: castOp.getSource(), afterTrimming))
2821 return true;
2822
2823 auto t1 = dyn_cast<RankedTensorType>(Val: beforePadding.getType());
2824 auto t2 = dyn_cast<RankedTensorType>(Val: afterTrimming.getType());
2825 // Only RankedTensorType supported.
2826 if (!t1 || !t2)
2827 return false;
2828 // Rank of both values must be the same.
2829 if (t1.getRank() != t2.getRank())
2830 return false;
2831
2832 // All static dimensions must be the same. Mixed cases (e.g., dimension
2833 // static in `t1` but dynamic in `t2`) are not supported.
2834 for (unsigned i = 0; i < t1.getRank(); ++i) {
2835 if (t1.isDynamicDim(idx: i) != t2.isDynamicDim(idx: i))
2836 return false;
2837 if (!t1.isDynamicDim(idx: i) && t1.getDimSize(idx: i) != t2.getDimSize(idx: i))
2838 return false;
2839 }
2840
2841 // Nothing more to check if all dimensions are static.
2842 if (t1.getNumDynamicDims() == 0)
2843 return true;
2844
2845 // All dynamic sizes must be the same. The only supported case at the
2846 // moment is when `beforePadding` is an ExtractSliceOp (or a cast
2847 // thereof).
2848
2849 // Apart from CastOp, only ExtractSliceOp is supported.
2850 auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
2851 if (!beforeSlice)
2852 return false;
2853
2854 assert(static_cast<size_t>(t1.getRank()) ==
2855 beforeSlice.getMixedSizes().size());
2856 assert(static_cast<size_t>(t2.getRank()) ==
2857 afterTrimming.getMixedSizes().size());
2858
2859 for (unsigned i = 0; i < t1.getRank(); ++i) {
2860 // Skip static dimensions.
2861 if (!t1.isDynamicDim(idx: i))
2862 continue;
2863 auto size1 = beforeSlice.getMixedSizes()[i];
2864 auto size2 = afterTrimming.getMixedSizes()[i];
2865
2866 // Case 1: Same value or same constant int.
2867 if (isEqualConstantIntOrValue(ofr1: size1, ofr2: size2))
2868 continue;
2869
2870 // Other cases: Take a deeper look at defining ops of values.
2871 auto v1 = llvm::dyn_cast_if_present<Value>(Val&: size1);
2872 auto v2 = llvm::dyn_cast_if_present<Value>(Val&: size2);
2873 if (!v1 || !v2)
2874 return false;
2875
2876 // Case 2: Both values are identical AffineMinOps. (Should not happen if
2877 // CSE is run.)
2878 auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
2879 auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
2880 if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2881 minOp1.getOperands() == minOp2.getOperands())
2882 continue;
2883
2884 // Add additional cases as needed.
2885 }
2886
2887 // All tests passed.
2888 return true;
2889 }
2890};
2891
2892/// Returns the effective Pad value for the input op, provided it's a scalar.
2893///
2894/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
2895/// this Op performs padding, retrieve the padding value provided that it's
2896/// a scalar and static/fixed for all the padded values. Returns an empty value
2897/// otherwise.
2898///
2899/// TODO: This is used twice (when checking vectorization pre-conditions and
2900/// when vectorizing). Cache results instead of re-running.
2901static Value getStaticPadVal(Operation *op) {
2902 if (!op)
2903 return {};
2904
2905  // 1. vector.broadcast (f32 -> vector<...xf32>) - return the value that's
2906  //    being broadcast, provided that it's a scalar.
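  // E.g. (illustrative) for `%v = vector.broadcast %cst : f32 to
  // vector<2x4xf32>`, this returns %cst.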
2907 if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(Val: op)) {
2908 auto source = bcast.getSource();
2909 if (llvm::dyn_cast<VectorType>(Val: source.getType()))
2910 return {};
2911
2912 return source;
2913 }
2914
2915 // 2. linalg.fill - use the scalar input value that used to fill the output
2916 // tensor.
2917 if (auto fill = llvm::dyn_cast<linalg::FillOp>(Val: op)) {
2918 return fill.getInputs()[0];
2919 }
2920
2921  // 3. tensor.generate - we can't guarantee that the value is fixed without
2922  // further analysis, so bail out.
2923 if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(Val: op)) {
2924 return {};
2925 }
2926
2927  // 4. vector.transfer_write - inspect the vector that's being written. If
2928  // it contains a single value that has been broadcast (e.g. via
2929  // vector.broadcast), extract it; fail otherwise.
2930 if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(Val: op))
2931 return getStaticPadVal(op: xferWrite.getVector().getDefiningOp());
2932
2933  // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
2934  // than the input tensor and was generated from a constant (e.g. via
2935  // linalg.fill), extract the value that was used to generate it; fail otherwise.
2936 // TODO: Clarify the semantics when the input tensor is larger than the
2937 // destination.
2938 if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(Val: op))
2939 return getStaticPadVal(op: slice.getDest().getDefiningOp());
2940
2941 return {};
2942}
2943
2944static LogicalResult
2945vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
2946 ArrayRef<int64_t> inputVectorSizes,
2947 SmallVectorImpl<Value> &newResults) {
2948 // TODO: Introduce a parent class that will handle the insertion point update.
2949 OpBuilder::InsertionGuard g(rewriter);
2950 rewriter.setInsertionPoint(sliceOp);
2951
2952 TypedValue<RankedTensorType> source = sliceOp.getSource();
2953 auto sourceType = source.getType();
2954 auto resultType = sliceOp.getResultType();
2955
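  // 1. Get the padding value. If no static pad value can be inferred for this
  //    insert_slice, fall back to zero as a neutral read padding.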
2956 Value padValue = getStaticPadVal(op: sliceOp);
2957
2958 if (!padValue) {
2959 auto elemType = sourceType.getElementType();
2960 padValue = rewriter.create<arith::ConstantOp>(
2961 location: sliceOp.getLoc(), args&: elemType, args: rewriter.getZeroAttr(type: elemType));
2962 }
2963
2964 // 2. Get the vector shape
2965 SmallVector<int64_t> vecShape;
2966 size_t rankDiff = resultType.getRank() - sourceType.getRank();
2967 for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
2968 if (!inputVectorSizes.empty()) {
2969 vecShape.push_back(Elt: inputVectorSizes[i]);
2970 } else if (!sourceType.isDynamicDim(idx: i)) {
2971 vecShape.push_back(Elt: sourceType.getDimSize(idx: i));
2972 } else if (!resultType.isDynamicDim(idx: i)) {
2973 // Source shape is not statically known, but result shape is.
2974 // Vectorize with size of result shape. This may be larger than the
2975 // source size.
2976 // FIXME: Using rankDiff implies that the source tensor is inserted at
2977 // the end of the destination tensor. However, that's not required.
2978 vecShape.push_back(Elt: resultType.getDimSize(idx: rankDiff + i));
2979 } else {
2980      // Neither the source nor the result dim of the slice op is static.
2981      // Cannot vectorize the copy.
2982 return failure();
2983 }
2984 }
2985 auto vecType = VectorType::get(shape: vecShape, elementType: sourceType.getElementType());
2986
2987 // 3. Generate TransferReadOp + TransferWriteOp
2988 auto loc = sliceOp.getLoc();
2989
2990 // Create read
2991 SmallVector<Value> readIndices(
2992 vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0));
2993 Value read = mlir::vector::createReadOrMaskedRead(
2994 builder&: rewriter, loc, source, inputVectorSizes: vecType.getShape(), padValue,
2995 /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
2996
2997 // Create write
2998 auto writeIndices =
2999 getValueOrCreateConstantIndexOp(b&: rewriter, loc, valueOrAttrVec: sliceOp.getMixedOffsets());
3000 Operation *write =
3001 createWriteOrMaskedWrite(builder&: rewriter, loc, vecToStore: read, dest: sliceOp.getDest(),
3002 writeIndices, useInBoundsInsteadOfMasking: inputVectorSizes.empty());
3003
3004 // 4. Finalize
3005 newResults.push_back(Elt: write->getResult(idx: 0));
3006
3007 return success();
3008}
3009
3010/// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
3011/// ```
3012/// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
3013/// %r = tensor.insert_slice %0
3014/// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
3015/// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
3016/// ```
3017/// is rewritten to:
3018/// ```
3019/// %0 = vector.transfer_read %src[%c0, %c0], %padding
3020/// : tensor<?x?xf32>, vector<17x5xf32>
3021/// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
3022/// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
3023/// ```
3024///
3025/// This rewrite is possible if:
3026/// - Low padding is static 0.
3027/// - `padOp` result shape is static.
3028/// - The entire padded tensor is inserted.
3029/// (Implies that sizes of `insertOp` are all static.)
3030/// - Only unit strides in `insertOp`.
3031/// - Single, scalar padding value.
3032/// - `padOp` result not used as destination.
3033struct PadOpVectorizationWithInsertSlicePattern
3034 : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
3035 using VectorizePadOpUserPattern<
3036 tensor::InsertSliceOp>::VectorizePadOpUserPattern;
3037
3038 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
3039 tensor::InsertSliceOp insertOp) const override {
3040 // Low padding must be static 0.
3041 if (!padOp.hasZeroLowPad())
3042 return failure();
3043 // Only unit stride supported.
3044 if (!insertOp.hasUnitStride())
3045 return failure();
3046 // Pad value must be a constant.
3047 auto padValue = padOp.getConstantPaddingValue();
3048 if (!padValue)
3049 return failure();
3050 // Dynamic shapes not supported.
3051 if (!cast<ShapedType>(Val: padOp.getResult().getType()).hasStaticShape())
3052 return failure();
3053 // Pad result not used as destination.
3054 if (insertOp.getDest() == padOp.getResult())
3055 return failure();
3056
3057 auto vecType = VectorType::get(shape: padOp.getType().getShape(),
3058 elementType: padOp.getType().getElementType());
3059 unsigned vecRank = vecType.getRank();
3060 unsigned tensorRank = insertOp.getType().getRank();
3061
3062 // Check if sizes match: Insert the entire tensor into most minor dims.
3063 // (No permutations allowed.)
3064 SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
3065 expectedSizes.append(in_start: vecType.getShape().begin(), in_end: vecType.getShape().end());
3066 if (!llvm::all_of(
3067 Range: llvm::zip(t: insertOp.getMixedSizes(), u&: expectedSizes), P: [](auto it) {
3068 return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
3069 }))
3070 return failure();
3071
3072 // Insert the TransferReadOp and TransferWriteOp at the position of the
3073 // InsertSliceOp.
3074 rewriter.setInsertionPoint(insertOp);
3075
3076 // Generate TransferReadOp: Read entire source tensor and add high
3077 // padding.
3078 SmallVector<Value> readIndices(
3079 vecRank, rewriter.create<arith::ConstantIndexOp>(location: padOp.getLoc(), args: 0));
3080 auto read = rewriter.create<vector::TransferReadOp>(
3081 location: padOp.getLoc(), args&: vecType, args: padOp.getSource(), args&: readIndices, args&: padValue);
3082
3083 // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
3084    // specified offsets. The write is fully in-bounds because an InsertSliceOp's
3085 // source must fit into the destination at the specified offsets.
3086 auto writeIndices = getValueOrCreateConstantIndexOp(
3087 b&: rewriter, loc: padOp.getLoc(), valueOrAttrVec: insertOp.getMixedOffsets());
3088 SmallVector<bool> inBounds(vecRank, true);
3089 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
3090 op: insertOp, args&: read, args: insertOp.getDest(), args&: writeIndices,
3091 args: ArrayRef<bool>{inBounds});
3092
3093 return success();
3094 }
3095};
3096
3097void mlir::linalg::populatePadOpVectorizationPatterns(
3098 RewritePatternSet &patterns, PatternBenefit baseBenefit) {
3099 patterns.add<PadOpVectorizationWithTransferReadPattern,
3100 PadOpVectorizationWithTransferWritePattern,
3101 PadOpVectorizationWithInsertSlicePattern>(
3102 arg: patterns.getContext(), args: baseBenefit.getBenefit() + 1);
3103}
3104
3105//----------------------------------------------------------------------------//
3106// Forwarding patterns
3107//----------------------------------------------------------------------------//
3108
3109/// Check whether there is any interleaved use of any `values` between
3110/// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
3111/// is in a different block.
3112static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
3113 ValueRange values) {
3114 if (firstOp->getBlock() != secondOp->getBlock() ||
3115 !firstOp->isBeforeInBlock(other: secondOp)) {
3116 LDBG("interleavedUses precondition failed, firstOp: "
3117 << *firstOp << ", second op: " << *secondOp << "\n");
3118 return true;
3119 }
3120 for (auto v : values) {
3121 for (auto &u : v.getUses()) {
3122 Operation *owner = u.getOwner();
3123 if (owner == firstOp || owner == secondOp)
3124 continue;
3125 // TODO: this is too conservative, use dominance info in the future.
3126 if (owner->getBlock() == firstOp->getBlock() &&
3127 (owner->isBeforeInBlock(other: firstOp) || secondOp->isBeforeInBlock(other: owner)))
3128 continue;
3129 LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
3130 << ", second op: " << *secondOp << "\n");
3131 return true;
3132 }
3133 }
3134 return false;
3135}
3136
3137/// Return the unique subview use of `v` if it is indeed unique, null
3138/// otherwise.
3139static memref::SubViewOp getSubViewUseIfUnique(Value v) {
3140 memref::SubViewOp subViewOp;
3141 for (auto &u : v.getUses()) {
3142 if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(Val: u.getOwner())) {
3143 if (subViewOp)
3144 return memref::SubViewOp();
3145 subViewOp = newSubViewOp;
3146 }
3147 }
3148 return subViewOp;
3149}
3150
3151/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3152/// when available.
3153LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
3154 vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
3155
3156 // TODO: support mask.
3157 if (xferOp.getMask())
3158 return rewriter.notifyMatchFailure(arg&: xferOp, msg: "unsupported mask");
3159
3160 // Transfer into `view`.
3161 Value viewOrAlloc = xferOp.getBase();
3162 if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3163 !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3164 return rewriter.notifyMatchFailure(arg&: xferOp, msg: "source not a view or alloc");
3165
3166 // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3167 memref::SubViewOp subViewOp = getSubViewUseIfUnique(v: viewOrAlloc);
3168 if (!subViewOp)
3169 return rewriter.notifyMatchFailure(arg&: xferOp, msg: "no subview found");
3170 Value subView = subViewOp.getResult();
3171
3172 // Find the copy into `subView` without interleaved uses.
3173 memref::CopyOp copyOp;
3174 for (auto &u : subView.getUses()) {
3175 if (auto newCopyOp = dyn_cast<memref::CopyOp>(Val: u.getOwner())) {
3176 assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
3177 if (newCopyOp.getTarget() != subView)
3178 continue;
3179 if (mayExistInterleavedUses(firstOp: newCopyOp, secondOp: xferOp, values: {viewOrAlloc, subView}))
3180 continue;
3181 copyOp = newCopyOp;
3182 break;
3183 }
3184 }
3185 if (!copyOp)
3186 return rewriter.notifyMatchFailure(arg&: xferOp, msg: "no copy found");
3187
3188 // Find the fill into `viewOrAlloc` without interleaved uses before the
3189 // copy.
3190 FillOp maybeFillOp;
3191 for (auto &u : viewOrAlloc.getUses()) {
3192 if (auto newFillOp = dyn_cast<FillOp>(Val: u.getOwner())) {
3193 assert(isa<MemRefType>(newFillOp.output().getType()));
3194 if (newFillOp.output() != viewOrAlloc)
3195 continue;
3196 if (mayExistInterleavedUses(firstOp: newFillOp, secondOp: copyOp, values: {viewOrAlloc, subView}))
3197 continue;
3198 maybeFillOp = newFillOp;
3199 break;
3200 }
3201 }
3202 // Ensure padding matches.
3203 if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
3204 return rewriter.notifyMatchFailure(arg&: xferOp,
3205 msg: "padding value does not match fill");
3206
3207 // `in` is the subview that memref.copy reads. Replace it.
3208 Value in = copyOp.getSource();
3209
3210 // memref.copy + linalg.fill can be used to create a padded local buffer.
3211  // A transfer read that is `in_bounds` is only valid on this padded buffer.
3212  // When forwarding the read to the unpadded source, the attribute must be
3213  // reset conservatively (to all-false).
3214 auto vectorType = xferOp.getVectorType();
3215 Value res = rewriter.create<vector::TransferReadOp>(
3216 location: xferOp.getLoc(), args&: vectorType, args&: in, args: xferOp.getIndices(),
3217 args: xferOp.getPermutationMapAttr(), args: xferOp.getPadding(), args: xferOp.getMask(),
3218 args: rewriter.getBoolArrayAttr(
3219 values: SmallVector<bool>(vectorType.getRank(), false)));
3220
3221 if (maybeFillOp)
3222 rewriter.eraseOp(op: maybeFillOp);
3223 rewriter.eraseOp(op: copyOp);
3224 rewriter.replaceOp(op: xferOp, newValues: res);
3225
3226 return success();
3227}
3228
3229/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3230/// when available.
3231LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
3232 vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
3233 // TODO: support mask.
3234 if (xferOp.getMask())
3235 return rewriter.notifyMatchFailure(arg&: xferOp, msg: "unsupported mask");
3236
3237 // Transfer into `viewOrAlloc`.
3238 Value viewOrAlloc = xferOp.getBase();
3239 if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3240 !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3241 return rewriter.notifyMatchFailure(arg&: xferOp, msg: "source not a view or alloc");
3242
3243 // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3244 memref::SubViewOp subViewOp = getSubViewUseIfUnique(v: viewOrAlloc);
3245 if (!subViewOp)
3246 return rewriter.notifyMatchFailure(arg&: xferOp, msg: "no subview found");
3247 Value subView = subViewOp.getResult();
3248
3249 // Find the copy from `subView` without interleaved uses.
3250 memref::CopyOp copyOp;
3251 for (auto &u : subViewOp.getResult().getUses()) {
3252 if (auto newCopyOp = dyn_cast<memref::CopyOp>(Val: u.getOwner())) {
3253 if (newCopyOp.getSource() != subView)
3254 continue;
3255 if (mayExistInterleavedUses(firstOp: xferOp, secondOp: newCopyOp, values: {viewOrAlloc, subView}))
3256 continue;
3257 copyOp = newCopyOp;
3258 break;
3259 }
3260 }
3261 if (!copyOp)
3262 return rewriter.notifyMatchFailure(arg&: xferOp, msg: "no copy found");
3263
3264 // `out` is the subview copied into that we replace.
3265 assert(isa<MemRefType>(copyOp.getTarget().getType()));
3266 Value out = copyOp.getTarget();
3267
3268 // Forward vector.transfer into copy.
3269 // memref.copy + linalg.fill can be used to create a padded local buffer.
3270  // A transfer write that is `in_bounds` is only valid on this padded buffer.
3271  // When forwarding the write to the unpadded target, the attribute must be
3272  // reset conservatively (to all-false).
3273 auto vector = xferOp.getVector();
3274 rewriter.create<vector::TransferWriteOp>(
3275 location: xferOp.getLoc(), args&: vector, args&: out, args: xferOp.getIndices(),
3276 args: xferOp.getPermutationMapAttr(), args: xferOp.getMask(),
3277 args: rewriter.getBoolArrayAttr(values: SmallVector<bool>(
3278 dyn_cast<VectorType>(Val: vector.getType()).getRank(), false)));
3279
3280 rewriter.eraseOp(op: copyOp);
3281 rewriter.eraseOp(op: xferOp);
3282
3283 return success();
3284}
3285
3286//===----------------------------------------------------------------------===//
3287// Convolution vectorization patterns
3288//===----------------------------------------------------------------------===//
3289
3290template <int N>
3291static void bindShapeDims(ShapedType shapedType) {}
3292
3293template <int N, typename IntTy, typename... IntTy2>
3294static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
3295 val = shapedType.getShape()[N];
3296 bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
3297}
3298
3299/// Bind a pack of int& to the leading dimensions of shapedType.getShape().
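/// E.g. (illustrative) `int64_t n, w; bindShapeDims(ty, n, w);` binds
/// n = ty.getShape()[0] and w = ty.getShape()[1].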
3300template <typename... IntTy>
3301static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
3302 bindShapeDims<0>(shapedType, vals...);
3303}
3304
3305namespace {
3306/// Generate a vector implementation for either:
3307/// ```
3308/// Op def: ( w, kw )
3309/// Iters: ({Par(), Red()})
3310/// Layout: {{w + kw}, {kw}, {w}}
3311/// ```
3312/// kw is unrolled.
3313///
3314/// or
3315///
3316/// ```
3317/// Op def: ( n, w, c, kw, f )
3318/// Iters: ({Par(), Par(), Par(), Red(), Red()})
3319/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3320/// ```
3321/// kw is unrolled, w is unrolled iff dilationW > 1.
3322///
3323/// or
3324///
3325/// ```
3326/// Op def: ( n, c, w, f, kw )
3327/// Iters: ({Par(), Par(), Par(), Red(), Red()})
3328/// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3329/// ```
3330/// kw is unrolled, w is unrolled iff dilationW > 1.
3331///
3332/// or
3333///
3334/// ```
3335/// Op def: ( n, w, c, kw )
3336/// Iters: ({Par(), Par(), Par(), Red()})
3337/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3338/// ```
3339/// kw is unrolled, w is unrolled iff dilationW > 1.
3340struct Conv1DGenerator
3341 : public StructuredGenerator<LinalgOp, utils::IteratorType> {
3342 Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
3343 : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
3344
3345 lhsShaped = linalgOp.getDpsInputOperand(i: 0)->get();
3346 rhsShaped = linalgOp.getDpsInputOperand(i: 1)->get();
3347 resShaped = linalgOp.getDpsInitOperand(i: 0)->get();
3348 lhsShapedType = dyn_cast<ShapedType>(Val: lhsShaped.getType());
3349 rhsShapedType = dyn_cast<ShapedType>(Val: rhsShaped.getType());
3350 resShapedType = dyn_cast<ShapedType>(Val: resShaped.getType());
3351
3352 Operation *reduceOp = matchLinalgReduction(outputOperand: linalgOp.getDpsInitOperand(i: 0));
3353 redOp = reduceOp->getName().getIdentifier();
3354
3355 setConvOperationKind(reduceOp);
3356
3357 auto maybeKind = getCombinerOpKind(combinerOp: reduceOp);
3358 reductionKind = maybeKind.value();
3359
3360 // The ConvolutionOpInterface gives us guarantees of existence for
3361    // strides/dilations. However, we do not need to rely on those; we simply
3362    // use them if present and otherwise fall back to the defaults, letting the
3363    // generic conv matcher in the ConvGenerator succeed or fail.
3364 auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>(name: "strides");
3365 auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>(name: "dilations");
3366 strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
3367 dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3368 }
3369
3370 /// Generate a vector implementation for:
3371 /// ```
3372 /// Op def: ( w, kw )
3373 /// Iters: ({Par(), Red()})
3374 /// Layout: {{w + kw}, {kw}, {w}}
3375 /// ```
3376 /// kw is always unrolled.
3377 ///
3378 /// or
3379 ///
3380 /// ```
3381 /// Op def: ( n, w, c, kw, f )
3382 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3383 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3384 /// ```
3385 /// kw is always unrolled.
3386  /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is
3387  /// > 1.
3388 FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3389 int64_t nSize, wSize, cSize, kwSize, fSize;
3390 SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3391 bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3392 switch (conv1DOpOrder) {
3393 case Conv1DOpOrder::W:
3394 // Initialize unused dimensions
3395 nSize = fSize = cSize = 0;
3396 // out{W}
3397 bindShapeDims(shapedType: resShapedType, vals&: wSize);
3398 // kernel{kw}
3399 bindShapeDims(shapedType: rhsShapedType, vals&: kwSize);
3400 lhsShape = {// iw = ow + kw - 1
3401 // (i.e. 16 convolved with 3 -> 14)
3402 (wSize + kwSize - 1)};
3403 rhsShape = {kwSize};
3404 resShape = {wSize};
3405 break;
3406 case Conv1DOpOrder::Nwc:
3407 // out{n, w, f}
3408 bindShapeDims(shapedType: resShapedType, vals&: nSize, vals&: wSize, vals&: fSize);
3409 switch (oper) {
3410 case ConvOperationKind::Conv:
3411 // kernel{kw, c, f}
3412 bindShapeDims(shapedType: rhsShapedType, vals&: kwSize, vals&: cSize);
3413 break;
3414 case ConvOperationKind::Pool:
3415 // kernel{kw}
3416 bindShapeDims(shapedType: rhsShapedType, vals&: kwSize);
3417 cSize = fSize;
3418 break;
3419 }
3420 lhsShape = {nSize,
3421 // iw = ow * sw + kw * dw - 1
3422 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3423 // Perform the proper inclusive -> exclusive -> inclusive.
3424 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3425 1,
3426 cSize};
3427 switch (oper) {
3428 case ConvOperationKind::Conv:
3429 rhsShape = {kwSize, cSize, fSize};
3430 break;
3431 case ConvOperationKind::Pool:
3432 rhsShape = {kwSize};
3433 break;
3434 }
3435 resShape = {nSize, wSize, fSize};
3436 break;
3437 case Conv1DOpOrder::Ncw:
3438 // out{n, f, w}
3439 bindShapeDims(shapedType: resShapedType, vals&: nSize, vals&: fSize, vals&: wSize);
3440 switch (oper) {
3441 case ConvOperationKind::Conv:
3442 // kernel{f, c, kw}
3443 bindShapeDims(shapedType: rhsShapedType, vals&: fSize, vals&: cSize, vals&: kwSize);
3444 break;
3445 case ConvOperationKind::Pool:
3446 // kernel{kw}
3447 bindShapeDims(shapedType: rhsShapedType, vals&: kwSize);
3448 cSize = fSize;
3449 break;
3450 }
3451 lhsShape = {nSize, cSize,
3452 // iw = ow * sw + kw * dw - 1
3453 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3454 // Perform the proper inclusive -> exclusive -> inclusive.
3455 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3456 1};
3457 switch (oper) {
3458 case ConvOperationKind::Conv:
3459 rhsShape = {fSize, cSize, kwSize};
3460 break;
3461 case ConvOperationKind::Pool:
3462 rhsShape = {kwSize};
3463 break;
3464 }
3465 resShape = {nSize, fSize, wSize};
3466 break;
3467 }
3468
3469 vector::TransferWriteOp write;
3470 Value zero = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
3471
3472 // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3473 // When strideW == 1, we can batch the contiguous loads and avoid
3474 // unrolling
3475 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3476
3477 Type lhsEltType = lhsShapedType.getElementType();
3478 Type rhsEltType = rhsShapedType.getElementType();
3479 Type resEltType = resShapedType.getElementType();
3480 auto lhsType = VectorType::get(shape: lhsShape, elementType: lhsEltType);
3481 auto rhsType = VectorType::get(shape: rhsShape, elementType: rhsEltType);
3482 auto resType = VectorType::get(shape: resShape, elementType: resEltType);
3483 // Zero padding with the corresponding dimensions for lhs, rhs and res.
3484 SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3485 SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3486 SmallVector<Value> resPadding(resShape.size(), zero);
3487
3488 // Read the whole lhs, rhs and res in one shot (with zero padding).
3489 Value lhs = rewriter.create<vector::TransferReadOp>(
3490 location: loc, args&: lhsType, args&: lhsShaped, args&: lhsPadding,
3491 /*padding=*/args: arith::getZeroConstant(builder&: rewriter, loc, type: lhsEltType));
3492 // This is needed only for Conv.
3493 Value rhs = nullptr;
3494 if (oper == ConvOperationKind::Conv)
3495 rhs = rewriter.create<vector::TransferReadOp>(
3496 location: loc, args&: rhsType, args&: rhsShaped, args&: rhsPadding,
3497 /*padding=*/args: arith::getZeroConstant(builder&: rewriter, loc, type: rhsEltType));
3498 Value res = rewriter.create<vector::TransferReadOp>(
3499 location: loc, args&: resType, args&: resShaped, args&: resPadding,
3500 /*padding=*/args: arith::getZeroConstant(builder&: rewriter, loc, type: resEltType));
3501
3502 // The base vectorization case for channeled convolution is input:
3503 // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
3504 // vectorization case, we do pre transpose on input, weight, and output.
3505 switch (conv1DOpOrder) {
3506 case Conv1DOpOrder::W:
3507 case Conv1DOpOrder::Nwc:
3508 // Base case, so no transposes necessary.
3509 break;
3510 case Conv1DOpOrder::Ncw: {
3511 // To match base vectorization case, we pre-transpose current case.
3512 // ncw -> nwc
3513 static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3514 lhs = rewriter.create<vector::TransposeOp>(location: loc, args&: lhs, args: permLhs);
3515 // fcw -> wcf
3516 static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3517
3518 // This is needed only for Conv.
3519 if (oper == ConvOperationKind::Conv)
3520 rhs = rewriter.create<vector::TransposeOp>(location: loc, args&: rhs, args: permRhs);
3521 // nfw -> nwf
3522 static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3523 res = rewriter.create<vector::TransposeOp>(location: loc, args&: res, args: permRes);
3524 break;
3525 }
3526 }
3527
3528 //===------------------------------------------------------------------===//
3529 // Begin vector-only rewrite part
3530 //===------------------------------------------------------------------===//
3531 // Unroll along kw and read slices of lhs and rhs.
3532 SmallVector<Value> lhsVals, rhsVals, resVals;
3533 lhsVals = extractConvInputSlices(rewriter, loc, input: lhs, nSize, wSize, cSize,
3534 kwSize, strideW, dilationW, wSizeStep,
3535 isSingleChanneled);
3536 // Do not do for pooling.
3537 if (oper == ConvOperationKind::Conv)
3538 rhsVals = extractConvFilterSlices(rewriter, loc, filter: rhs, kwSize);
3539 resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3540 wSizeStep, isSingleChanneled);
3541
3542 auto linearIndex = [&](int64_t kw, int64_t w) {
3543 return kw * (wSize / wSizeStep) + w;
3544 };
3545
3546 // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
3547    // or perform an outer product for non-channeled convolution, or a simple
3548    // arith operation for pooling.
3549 for (int64_t kw = 0; kw < kwSize; ++kw) {
3550 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3551 switch (oper) {
3552 case ConvOperationKind::Conv:
3553 if (isSingleChanneled) {
3554 resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3555 lhs: lhsVals[linearIndex(kw, w)],
3556 rhs: rhsVals[kw], res: resVals[w]);
3557 } else {
3558 resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3559 lhs: lhsVals[linearIndex(kw, w)],
3560 rhs: rhsVals[kw], res: resVals[w]);
3561 }
3562 break;
3563 case ConvOperationKind::Pool:
3564 resVals[w] = pool1dSlice(rewriter, loc, lhs: lhsVals[linearIndex(kw, w)],
3565 res: resVals[w]);
3566 break;
3567 }
3568 }
3569 }
3570
3571 res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3572 isSingleChanneled);
3573 //===------------------------------------------------------------------===//
3574 // End vector-only rewrite part
3575 //===------------------------------------------------------------------===//
3576
3577 // The base vectorization case for channeled convolution is output:
3578 // {n,w,f} To reuse the result from base pattern vectorization case, we
3579 // post transpose the base case result.
3580 switch (conv1DOpOrder) {
3581 case Conv1DOpOrder::W:
3582 case Conv1DOpOrder::Nwc:
3583 // Base case, so no transposes necessary.
3584 break;
3585 case Conv1DOpOrder::Ncw: {
3586 // nwf -> nfw
3587 static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3588 res = rewriter.create<vector::TransposeOp>(location: loc, args&: res, args: perm);
3589 break;
3590 }
3591 }
3592
3593 return rewriter
3594 .create<vector::TransferWriteOp>(location: loc, args&: res, args&: resShaped, args&: resPadding)
3595 .getOperation();
3596 }
3597
3598  // Take a value and widen it to have the same element type as `ty`.
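  // E.g. (illustrative) promoting a vector<8xi8> value to match an i32-typed
  // `ty` emits arith.extsi to vector<8xi32>; integer-to-float promotion uses
  // arith.sitofp instead.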
3599 Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3600 const Type srcElementType = getElementTypeOrSelf(type: val.getType());
3601 const Type dstElementType = getElementTypeOrSelf(type: ty);
3602 assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3603 if (srcElementType == dstElementType)
3604 return val;
3605
3606 const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3607 const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3608 const Type dstType =
3609 cast<ShapedType>(Val: val.getType()).cloneWith(shape: std::nullopt, elementType: dstElementType);
3610
3611 if (isa<IntegerType>(Val: srcElementType) && isa<FloatType>(Val: dstElementType)) {
3612 return rewriter.create<arith::SIToFPOp>(location: loc, args: dstType, args&: val);
3613 }
3614
3615 if (isa<FloatType>(Val: srcElementType) && isa<FloatType>(Val: dstElementType) &&
3616 srcWidth < dstWidth)
3617 return rewriter.create<arith::ExtFOp>(location: loc, args: dstType, args&: val);
3618
3619 if (isa<IntegerType>(Val: srcElementType) && isa<IntegerType>(Val: dstElementType) &&
3620 srcWidth < dstWidth)
3621 return rewriter.create<arith::ExtSIOp>(location: loc, args: dstType, args&: val);
3622
3623 assert(false && "unhandled promotion case");
3624 return nullptr;
3625 }
3626
3627 // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3628 Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3629 Value lhs, Value rhs, Value res) {
3630 vector::IteratorType par = vector::IteratorType::parallel;
3631 vector::IteratorType red = vector::IteratorType::reduction;
3632 AffineExpr n, w, f, c;
3633 bindDims(ctx, exprs&: n, exprs&: w, exprs&: f, exprs&: c);
3634 lhs = promote(rewriter, loc, val: lhs, ty: res.getType());
3635 rhs = promote(rewriter, loc, val: rhs, ty: res.getType());
3636    auto contractionOp = rewriter.create<vector::ContractionOp>(
3637        location: loc, args&: lhs, args&: rhs, args&: res,
3638        /*indexingMaps=*/args: MapList{{n, w, c}, {c, f}, {n, w, f}},
3639        /*iteratorTypes=*/args: ArrayRef<vector::IteratorType>{par, par, par, red});
3640    contractionOp.setKind(reductionKind);
3641    return contractionOp;
3642 }
3643
3644 // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3645 // convolution.
3646 Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3647 Value lhs, Value rhs, Value res) {
3648 return rewriter.create<vector::OuterProductOp>(
3649 location: loc, args: res.getType(), args&: lhs, args&: rhs, args&: res, args: vector::CombiningKind::ADD);
3650 }
3651
3652 // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3653 Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3654 Value res) {
3655 if (isPoolExt)
3656 lhs = rewriter.create(loc, opName: poolExtOp, operands: lhs, types: res.getType())->getResult(idx: 0);
3657 return rewriter
3658 .create(loc, opName: redOp, operands: ArrayRef<Value>{lhs, res}, types: res.getType())
3659 ->getResult(idx: 0);
3660 }
3661
3662 /// Generate a vector implementation for:
3663 /// ```
3664 /// Op def: ( n, w, c, kw)
3665 /// Iters: ({Par(), Par(), Par(), Red()})
3666 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3667 /// ```
3668 /// kw is always unrolled.
3669  /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is
3670  /// > 1.
3671 FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3672 bool channelDimScalableFlag,
3673 bool flatten) {
3674 bool scalableChDim = false;
3675 bool useMasking = false;
3676 int64_t nSize, wSize, cSize, kwSize;
3677 // kernel{kw, c}
3678 bindShapeDims(shapedType: rhsShapedType, vals&: kwSize, vals&: cSize);
3679 if (ShapedType::isDynamic(dValue: cSize)) {
3680 assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3681 cSize = channelDimVecSize;
3682 // Scalable vectors are only used when both conditions are met:
3683 // 1. channel dim is dynamic
3684 // 2. channelDimScalableFlag is set
3685 scalableChDim = channelDimScalableFlag;
3686 useMasking = true;
3687 }
3688
3689 assert(!(useMasking && flatten) &&
3690 "Unsupported flattened conv with dynamic shapes");
3691
3692 // out{n, w, c}
3693 bindShapeDims(shapedType: resShapedType, vals&: nSize, vals&: wSize);
3694
3695 vector::TransferWriteOp write;
3696 Value zero = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
3697
3698 // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3699 // When strideW == 1, we can batch the contiguous loads and avoid
3700 // unrolling
3701 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3702
3703 Type lhsEltType = lhsShapedType.getElementType();
3704 Type rhsEltType = rhsShapedType.getElementType();
3705 Type resEltType = resShapedType.getElementType();
3706 VectorType lhsType = VectorType::get(
3707 shape: {nSize,
3708 // iw = ow * sw + kw * dw - 1
3709 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3710 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3711 cSize},
3712 elementType: lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3713 VectorType rhsType =
3714 VectorType::get(shape: {kwSize, cSize}, elementType: rhsEltType,
3715 /*scalableDims=*/{false, scalableChDim});
3716 VectorType resType =
3717 VectorType::get(shape: {nSize, wSize, cSize}, elementType: resEltType,
3718 /*scalableDims=*/{false, false, scalableChDim});
3719
3720    // Masks the input xfer Op along the channel dim, iff masking is required
3721    // (i.e. the channel dim is dynamic).
3722 auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3723 ArrayRef<bool> scalableDims,
3724 Operation *opToMask) {
3725 if (!useMasking)
3726 return opToMask;
3727 auto maskType =
3728 VectorType::get(shape: maskShape, elementType: rewriter.getI1Type(), scalableDims);
3729
3730 SmallVector<bool> inBounds(maskShape.size(), true);
3731 auto xferOp = cast<VectorTransferOpInterface>(Val: opToMask);
3732 xferOp->setAttr(name: xferOp.getInBoundsAttrName(),
3733 value: rewriter.getBoolArrayAttr(values: inBounds));
3734
3735 SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3736 hasTensorSemantics: cast<LinalgOp>(Val: op).hasPureTensorSemantics(), xfer: opToMask, rewriter);
3737
3738 Value maskOp =
3739 rewriter.create<vector::CreateMaskOp>(location: loc, args&: maskType, args&: mixedDims);
3740
3741 return mlir::vector::maskOperation(builder&: rewriter, maskableOp: opToMask, mask: maskOp);
3742 };
3743
3744 // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3745 // 0].
3746 Value lhs = rewriter.create<vector::TransferReadOp>(
3747 location: loc, args&: lhsType, args&: lhsShaped, args: ValueRange{zero, zero, zero},
3748 /*padding=*/args: arith::getZeroConstant(builder&: rewriter, loc, type: lhsEltType));
3749 auto maybeMaskedLhs = maybeMaskXferOp(
3750 lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3751
3752 // Read rhs slice of size {kw, c} @ [0, 0].
3753 Value rhs = rewriter.create<vector::TransferReadOp>(
3754 location: loc, args&: rhsType, args&: rhsShaped, args: ValueRange{zero, zero},
3755 /*padding=*/args: arith::getZeroConstant(builder&: rewriter, loc, type: rhsEltType));
3756 auto maybeMaskedRhs = maybeMaskXferOp(
3757 rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3758
3759 // Read res slice of size {n, w, c} @ [0, 0, 0].
3760 Value res = rewriter.create<vector::TransferReadOp>(
3761 location: loc, args&: resType, args&: resShaped, args: ValueRange{zero, zero, zero},
3762 /*padding=*/args: arith::getZeroConstant(builder&: rewriter, loc, type: resEltType));
3763 auto maybeMaskedRes = maybeMaskXferOp(
3764 resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3765
3766 //===------------------------------------------------------------------===//
3767 // Begin vector-only rewrite part
3768 //===------------------------------------------------------------------===//
3769 // Unroll along kw and read slices of lhs and rhs.
3770 SmallVector<Value> lhsVals, rhsVals, resVals;
3771 SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
3772 SmallVector<int64_t> inOutStrides = {1, 1, 1};
3773
3774 // Extract lhs slice of size {n, wSizeStep, c}
3775 // @ [0, sw * w + dw * kw, 0].
3776 for (int64_t kw = 0; kw < kwSize; ++kw) {
3777 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3778 lhsVals.push_back(Elt: rewriter.create<vector::ExtractStridedSliceOp>(
3779 location: loc, args: maybeMaskedLhs->getResult(idx: 0),
3780 /*offsets=*/args: ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3781 args&: inOutSliceSizes, args&: inOutStrides));
3782 }
3783 }
3784 // Extract rhs slice of size {c} @ [kw].
3785 for (int64_t kw = 0; kw < kwSize; ++kw) {
3786 rhsVals.push_back(Elt: rewriter.create<vector::ExtractOp>(
3787 location: loc, args: maybeMaskedRhs->getResult(idx: 0),
3788 /*offsets=*/args: ArrayRef<int64_t>{kw}));
3789 }
3790 // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3791 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3792 resVals.push_back(Elt: rewriter.create<vector::ExtractStridedSliceOp>(
3793 location: loc, args: maybeMaskedRes->getResult(idx: 0),
3794 /*offsets=*/args: ArrayRef<int64_t>{0, w, 0}, args&: inOutSliceSizes,
3795 args&: inOutStrides));
3796 }
3797
3798 auto linearIndex = [&](int64_t kw, int64_t w) {
3799 return kw * (wSize / wSizeStep) + w;
3800 };
3801
3802 // Note - the scalable flags are ignored as flattening combined with
3803 // scalable vectorization is not supported.
3804 SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
3805 auto lhsTypeAfterFlattening =
3806 VectorType::get(shape: inOutFlattenSliceSizes, elementType: lhsEltType);
3807 auto resTypeAfterFlattening =
3808 VectorType::get(shape: inOutFlattenSliceSizes, elementType: resEltType);
3809
3810 // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3811 for (int64_t kw = 0; kw < kwSize; ++kw) {
3812 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3813 Value lhsVal = lhsVals[linearIndex(kw, w)];
3814 Value resVal = resVals[w];
3815 if (flatten) {
3816 // Flatten the input and output vectors (collapse the channel
3817 // dimension)
3818 lhsVal = rewriter.create<vector::ShapeCastOp>(
3819 location: loc, args&: lhsTypeAfterFlattening, args&: lhsVals[linearIndex(kw, w)]);
3820 resVal = rewriter.create<vector::ShapeCastOp>(
3821 location: loc, args&: resTypeAfterFlattening, args&: resVals[w]);
3822 }
3823 resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhs: lhsVal,
3824 rhs: rhsVals[kw], res: resVal, flatten);
3825 if (flatten) {
3826 // Un-flatten the output vector (restore the channel dimension)
3827 resVals[w] = rewriter.create<vector::ShapeCastOp>(
3828 location: loc, args: VectorType::get(shape: inOutSliceSizes, elementType: resEltType), args&: resVals[w]);
3829 }
3830 }
3831 }
3832
3833    // It's possible we failed to create the FMA.
3834 if (!llvm::all_of(Range&: resVals, P: [](Value v) { return v; })) {
3835 // Manually revert (in reverse order) to avoid leaving a bad IR state.
3836 for (auto &collection :
3837 {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3838 for (Value v : collection)
3839 rewriter.eraseOp(op: v.getDefiningOp());
3840 return rewriter.notifyMatchFailure(arg&: op, msg: "failed to create FMA");
3841 }
3842
3843 // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
3844 // This does not depend on kw.
3845 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3846 maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
3847 location: loc, args&: resVals[w], args: maybeMaskedRes->getResult(idx: 0),
3848 /*offsets=*/args: ArrayRef<int64_t>{0, w, 0},
3849 /*strides=*/args: ArrayRef<int64_t>{1, 1, 1});
3850 }
3851 //===------------------------------------------------------------------===//
3852 // End vector-only rewrite part
3853 //===------------------------------------------------------------------===//
3854
3855 // Write back res slice of size {n, w, c} @ [0, 0, 0].
3856 Operation *resOut = rewriter.create<vector::TransferWriteOp>(
3857 location: loc, args: maybeMaskedRes->getResult(idx: 0), args&: resShaped,
3858 args: ValueRange{zero, zero, zero});
3859 return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3860 resOut);
3861 }
3862
3863 /// Lower:
3864 /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
3865 /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
3866 /// to MulAcc.
3867 Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
3868 Value lhs, Value rhs, Value res,
3869 bool flatten) {
3870 auto rhsTy = cast<ShapedType>(Val: rhs.getType());
3871 auto resTy = cast<ShapedType>(Val: res.getType());
3872
3873 // TODO(suderman): Change this to use a vector.ima intrinsic.
3874 lhs = promote(rewriter, loc, val: lhs, ty: resTy);
3875
3876 if (flatten) {
3877      // NOTE: The following logic won't work for scalable vectors. For this
3878 // reason, "flattening" is not supported when shapes are dynamic (this
3879 // should be captured by one of the pre-conditions).
3880
3881 // There are two options for handling the filter:
3882 // * shape_cast(broadcast(filter))
3883 // * broadcast(shuffle(filter))
3884 // Opt for the option without shape_cast to simplify the codegen.
3885 auto rhsSize = cast<VectorType>(Val: rhs.getType()).getShape()[0];
3886 auto resSize = cast<VectorType>(Val: res.getType()).getShape()[1];
3887
3888 SmallVector<int64_t, 16> indices;
3889 for (int i = 0; i < resSize / rhsSize; ++i) {
3890 for (int j = 0; j < rhsSize; ++j)
3891 indices.push_back(Elt: j);
3892 }
3893
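      // E.g. (illustrative) a filter vector [a, b] with resSize == 4 is
      // shuffled into [a, b, a, b] so that it lines up element-wise with the
      // flattened {w * c} slices.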
3894 rhs = rewriter.create<vector::ShuffleOp>(location: loc, args&: rhs, args&: rhs, args&: indices);
3895 }
3896 // Broadcast the filter to match the output vector
3897 rhs = rewriter.create<vector::BroadcastOp>(
3898 location: loc, args: resTy.clone(elementType: rhsTy.getElementType()), args&: rhs);
3899
3900 rhs = promote(rewriter, loc, val: rhs, ty: resTy);
3901
3902 if (!lhs || !rhs)
3903 return nullptr;
3904
3905 if (isa<FloatType>(Val: resTy.getElementType()))
3906 return rewriter.create<vector::FMAOp>(location: loc, args&: lhs, args&: rhs, args&: res);
3907
3908 auto mul = rewriter.create<arith::MulIOp>(location: loc, args&: lhs, args&: rhs);
3909 return rewriter.create<arith::AddIOp>(location: loc, args&: mul, args&: res);
3910 }
3911
  /// Entry point for non-channeled convolution:
  ///   {{w + kw}, {kw}, {w}}
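  /// This matches, e.g., linalg.conv_1d (1 parallel and 1 reduction
  /// iterator, no channel dimensions).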
  FailureOr<Operation *> generateNonChanneledConv() {
    AffineExpr w, kw;
    bindDims(ctx, w, kw);
    if (!iters({Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match conv::W 1-par 1-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {w + kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {w}}))
      return conv(Conv1DOpOrder::W);

    return rewriter.notifyMatchFailure(op, "not a conv::W layout");
  }

  /// Entry point that transposes into the common form:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
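  /// This matches, e.g., linalg.conv_1d_nwc_wcf.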
  FailureOr<Operation *> generateNwcConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, w, f, kw, c);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Nwc 3-par 2-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c, f},
                /*resIndex*/ {n, w, f}}))
      return conv(Conv1DOpOrder::Nwc);

    return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
  }

  /// Entry point that transposes into the common form:
  ///   {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
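  /// This matches, e.g., linalg.conv_1d_ncw_fcw.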
  FailureOr<Operation *> generateNcwConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, f, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Ncw 3-par 2-red");

    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {f, c, kw},
                /*resIndex*/ {n, f, w}}))
      return conv(Conv1DOpOrder::Ncw);

    return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
  }

  /// Entry point that transposes into the common form:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
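  /// This matches, e.g., linalg.pooling_nwc_sum and linalg.pooling_nwc_max.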
  FailureOr<Operation *> generateNwcPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, w, c}}))
      return conv(Conv1DOpOrder::Nwc);

    return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
  }

  /// Entry point that transposes into the common form:
  ///   {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
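  /// This matches, e.g., linalg.pooling_ncw_sum and linalg.pooling_ncw_max.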
  FailureOr<Operation *> generateNcwPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, c, w, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, c, w}}))
      return conv(Conv1DOpOrder::Ncw);

    return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
  }

  /// Entry point that transposes into the common form:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
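  /// This matches 1D depthwise convolutions, e.g.,
  /// linalg.depthwise_conv_1d_nwc_wc.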
  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
                                             bool vecChDimScalableFlag = false,
                                             bool flatten = false) {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match depthwise::Nwc conv 3-par 1-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c},
                /*resIndex*/ {n, w, c}}))
      return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);

    return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
  }

private:
  ConvOperationKind oper = ConvOperationKind::Conv;
  StringAttr redOp;
  StringAttr poolExtOp;
  bool isPoolExt = false;
  int strideW, dilationW;
  Value lhsShaped, rhsShaped, resShaped;
  ShapedType lhsShapedType, rhsShapedType, resShapedType;
  vector::CombiningKind reductionKind;

  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
  void setConvOperationKind(Operation *reduceOp) {
    int numBlockArguments =
        llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
    if (numBlockArguments == 1) {
      // Will be a convolution if the feeder is a MulOp. A strength-reduced
      // version of MulOp for the i1 type is AndOp, which is also supported.
      // Otherwise, it can be pooling. This strength-reduction logic lives in
      // the `buildBinaryFn` helper in the Linalg dialect.
      auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                         llvm::IsaPred<BlockArgument>);
      Operation *feedOp = (*feedValIt).getDefiningOp();
      if (isCastOfBlockArgument(feedOp)) {
        oper = ConvOperationKind::Pool;
        isPoolExt = true;
        poolExtOp = feedOp->getName().getIdentifier();
        return;
      }
      oper = ConvOperationKind::Conv;
      return;
    }
    // numBlockArguments == 2 and this is a pooling op.
    oper = ConvOperationKind::Pool;
    isPoolExt = false;
  }
};
} // namespace

/// Helper function to vectorize a LinalgOp with convolution semantics.
// TODO: extend the generic vectorization to support windows and drop this.
static FailureOr<Operation *> vectorizeConvolution(
    RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
    ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
  Conv1DGenerator conv1dGen(rewriter, op);
  auto res = conv1dGen.generateNonChanneledConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNwcConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNcwConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNwcPooling();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNcwPooling();
  if (succeeded(res))
    return res;

  // Only depthwise 1D NWC convs are left - these can be vectorized using masks
  // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
  // masked/scalable) is the channel dim (i.e. the trailing dim).
  uint64_t vecChDimSize = ShapedType::kDynamic;
  bool vecChDimScalableFlag = false;
  if (!inputVecSizes.empty()) {
    // Only use the input vector size corresponding to the channel dim. Other
    // vector dims will be inferred from the Ops.
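    // E.g., in the NWC layout the channel is the trailing dim (index 2),
    // while in the NCW layout it is dim 1; see the TypeSwitch below.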
    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
           "Not a 1D depthwise conv!");
    size_t chDimIdx =
        TypeSwitch<Operation *, size_t>(op)
            .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
            .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });

    vecChDimSize = inputVecSizes[chDimIdx];
    vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
  }
  return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
                                       flatten1DDepthwiseConv);
}

struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
    if (failed(resultOrFail))
      return failure();
    Operation *newOp = *resultOrFail;
    if (newOp->getNumResults() == 0) {
      rewriter.eraseOp(op.getOperation());
      return success();
    }
    assert(newOp->getNumResults() == 1 && "expected single result");
    rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
    return success();
  }
};

void mlir::linalg::populateConvolutionVectorizationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
}
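
// Illustrative usage sketch (hypothetical client code, not part of this
// file): a pass wanting these patterns would typically apply them via the
// greedy pattern rewrite driver, e.g.:
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateConvolutionVectorizationPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();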