1//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the linalg dialect Vectorization transformations.
10//
11//===----------------------------------------------------------------------===//
12#include "mlir/Dialect/Affine/Utils.h"
13
14#include "mlir/Analysis/SliceAnalysis.h"
15#include "mlir/Dialect/Affine/IR/AffineOps.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/Dialect/Linalg/IR/Linalg.h"
19#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20#include "mlir/Dialect/Linalg/Utils/Utils.h"
21#include "mlir/Dialect/Tensor/IR/Tensor.h"
22#include "mlir/Dialect/Tensor/Utils/Utils.h"
23#include "mlir/Dialect/Utils/IndexingUtils.h"
24#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
25#include "mlir/Dialect/Vector/IR/VectorOps.h"
26#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
27#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
28#include "mlir/IR/AffineExpr.h"
29#include "mlir/IR/Builders.h"
30#include "mlir/IR/BuiltinTypeInterfaces.h"
31#include "mlir/IR/BuiltinTypes.h"
32#include "mlir/IR/OpDefinition.h"
33#include "mlir/IR/PatternMatch.h"
34#include "mlir/Support/LLVM.h"
35#include "mlir/Transforms/RegionUtils.h"
36#include "llvm/ADT/STLExtras.h"
37#include "llvm/ADT/Sequence.h"
38#include "llvm/ADT/SmallVector.h"
39#include "llvm/ADT/TypeSwitch.h"
40#include "llvm/ADT/iterator_range.h"
41#include "llvm/Support/Debug.h"
42#include "llvm/Support/MathExtras.h"
43#include "llvm/Support/raw_ostream.h"
44#include <optional>
45#include <type_traits>
46
47using namespace mlir;
48using namespace mlir::linalg;
49
50#define DEBUG_TYPE "linalg-vectorization"
51
52#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
53#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
54
55/// Try to vectorize `convOp` as a convolution.
56static FailureOr<Operation *>
57vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
58 ArrayRef<int64_t> inputVecSizes = {},
59 ArrayRef<bool> inputVecScalableFlags = {},
60 bool flatten1DDepthwiseConv = false);
61
62/// Vectorize tensor::InsertSliceOp with:
63/// * vector::TransferReadOp + vector::TransferWriteOp
64/// The vector sizes are either:
65/// * user-provided in `inputVectorSizes`, or
66/// * inferred from the static dims in the input and output tensors.
67/// Bails out if:
68/// * vector sizes are not user-provided, and
69/// * at least one dim is dynamic (in both the input and output tensors).
70///
71/// Before:
72/// !t_in_type = tensor<1x2x3xf32>
73/// !t_out_type = tensor<9x8x7x1x2x3xf32>
74/// !v_type = vector<1x2x3xf32>
75/// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
76/// into !t_out_type
77/// After:
78/// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
79/// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
80static LogicalResult
81vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
82 ArrayRef<int64_t> inputVectorSizes,
83 SmallVectorImpl<Value> &newResults);
84
85/// Returns the effective Pad value for the input op, provided it's a scalar.
86///
87/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
88/// this Op performs padding, retrieve the padding value provided that it's
89/// a scalar and static/fixed for all the padded values. Returns an empty value
90/// otherwise.
91static Value getStaticPadVal(Operation *op);
92
93/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if there is no instance or more than one.
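/// For illustration only (hypothetical caller, not a call site in this file),
/// the helper could be used to fetch the unique terminator of a block:
///   auto yieldOp = getSingleOpOfType<linalg::YieldOp>(block);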
95template <typename OpType>
96static OpType getSingleOpOfType(Block &block) {
97 OpType res;
98 block.walk([&](OpType op) {
99 if (res) {
100 res = nullptr;
101 return WalkResult::interrupt();
102 }
103 res = op;
104 return WalkResult::advance();
105 });
106 return res;
107}
108
109/// Helper function to extract the input slices after filter is unrolled along
110/// kw.
111static SmallVector<Value>
112extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
113 int64_t nSize, int64_t wSize, int64_t cSize,
114 int64_t kwSize, int strideW, int dilationW,
115 int64_t wSizeStep, bool isSingleChanneled) {
116 SmallVector<Value> result;
117 if (isSingleChanneled) {
118 // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
119 // convolution.
120 SmallVector<int64_t> sizes = {wSizeStep};
121 SmallVector<int64_t> strides = {1};
122 for (int64_t kw = 0; kw < kwSize; ++kw) {
123 for (int64_t w = 0; w < wSize; w += wSizeStep) {
124 result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
125 loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
126 }
127 }
128 } else {
129 // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
130 // for channeled convolution.
131 SmallVector<int64_t> sizes = {nSize, wSizeStep, cSize};
132 SmallVector<int64_t> strides = {1, 1, 1};
133 for (int64_t kw = 0; kw < kwSize; ++kw) {
134 for (int64_t w = 0; w < wSize; w += wSizeStep) {
135 result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
136 loc, input,
137 /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
138 sizes, strides));
139 }
140 }
141 }
142 return result;
143}
144
145/// Helper function to extract the filter slices after filter is unrolled along
146/// kw.
147static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
148 Location loc, Value filter,
149 int64_t kwSize) {
150 SmallVector<Value> result;
151 // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
  // non-channeled convolution] @ [kw].
153 for (int64_t kw = 0; kw < kwSize; ++kw) {
154 result.push_back(rewriter.create<vector::ExtractOp>(
155 loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
156 }
157 return result;
158}
159
160/// Helper function to extract the result slices after filter is unrolled along
161/// kw.
162static SmallVector<Value>
163extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
164 int64_t nSize, int64_t wSize, int64_t fSize,
165 int64_t wSizeStep, bool isSingleChanneled) {
166 SmallVector<Value> result;
167 if (isSingleChanneled) {
168 // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
169 SmallVector<int64_t> sizes = {wSizeStep};
170 SmallVector<int64_t> strides = {1};
171 for (int64_t w = 0; w < wSize; w += wSizeStep) {
172 result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
173 loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
174 }
175 } else {
176 // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
177 // convolution.
178 SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize};
179 SmallVector<int64_t> strides = {1, 1, 1};
180 for (int64_t w = 0; w < wSize; w += wSizeStep) {
181 result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
182 loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
183 }
184 }
185 return result;
186}
187
188/// Helper function to insert the computed result slices.
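/// For example (illustrative shapes, not taken from a specific test), with
/// wSize = 4 and wSizeStep = 2 the channeled path emits two insertions along
/// these lines:
///   %0 = vector.insert_strided_slice %v0, %res
///       {offsets = [0, 0, 0], strides = [1, 1, 1]}
///       : vector<1x2x8xf32> into vector<1x4x8xf32>
///   %1 = vector.insert_strided_slice %v2, %0
///       {offsets = [0, 2, 0], strides = [1, 1, 1]}
///       : vector<1x2x8xf32> into vector<1x4x8xf32>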
189static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
190 Value res, int64_t wSize, int64_t wSizeStep,
191 SmallVectorImpl<Value> &resVals,
192 bool isSingleChanneled) {
193
194 if (isSingleChanneled) {
195 // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
196 // This does not depend on kw.
197 SmallVector<int64_t> strides = {1};
198 for (int64_t w = 0; w < wSize; w += wSizeStep) {
199 res = rewriter.create<vector::InsertStridedSliceOp>(
200 loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
201 }
202 } else {
203 // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
204 // convolution. This does not depend on kw.
205 SmallVector<int64_t> strides = {1, 1, 1};
206 for (int64_t w = 0; w < wSize; w += wSizeStep) {
207 res = rewriter.create<vector::InsertStridedSliceOp>(
208 loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
209 strides);
210 }
211 }
212 return res;
213}
214
215/// Contains the vectorization state and related methods used across the
216/// vectorization process of a given operation.
217struct VectorizationState {
218 VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
219
220 /// Initializes the vectorization state, including the computation of the
221 /// canonical vector shape for vectorization.
222 LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
223 ArrayRef<int64_t> inputVectorSizes,
224 ArrayRef<bool> inputScalableVecDims);
225
226 /// Returns the canonical vector shape used to vectorize the iteration space.
227 ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
228
229 /// Returns the vector dimensions that are scalable in the canonical vector
230 /// shape.
231 ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
232
233 /// Returns a vector type of the provided `elementType` with the canonical
234 /// vector shape and the corresponding fixed/scalable dimensions bit. If
235 /// `dimPermutation` is provided, the canonical vector dimensions are permuted
236 /// accordingly.
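  ///
  /// For example (illustrative values), with a canonical vector shape of
  /// (4, 16, [8]) (the last dim scalable) and a permutation
  /// (d0, d1, d2) -> (d2, d0), this returns vector<[8]x4xf32> for an f32
  /// element type.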
  VectorType getCanonicalVecType(
      Type elementType,
      std::optional<AffineMap> dimPermutation = std::nullopt) const {
    SmallVector<int64_t> vectorShape;
    SmallVector<bool> scalableDims;
    if (dimPermutation.has_value()) {
      vectorShape =
          applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
      scalableDims =
          applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
    } else {
      vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
      scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
    }

    return VectorType::get(vectorShape, elementType, scalableDims);
  }
254
255 /// Masks an operation with the canonical vector mask if the operation needs
256 /// masking. Returns the masked operation or the original operation if masking
257 /// is not needed. If provided, the canonical mask for this operation is
258 /// permuted using `maybeIndexingMap`.
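  ///
  /// For illustration (made-up shapes), masking a vector.transfer_read with
  /// the canonical mask might produce IR along these lines:
  ///   %mask = vector.create_mask %dim0, %dim1 : vector<4x8xi1>
  ///   %read = vector.mask %mask {
  ///     vector.transfer_read %src[%c0, %c0], %pad
  ///         : tensor<?x?xf32>, vector<4x8xf32>
  ///   } : vector<4x8xi1> -> vector<4x8xf32>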
259 Operation *
260 maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
261 std::optional<AffineMap> maybeIndexingMap = std::nullopt);
262
263private:
264 /// Initializes the iteration space static sizes using the Linalg op
265 /// information. This may become more complicated in the future.
266 void initIterSpaceStaticSizes(LinalgOp linalgOp) {
267 iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
268 }
269
270 /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
271 /// all the static and dynamic dimensions of the iteration space to be
272 /// vectorized and store them in `iterSpaceValueSizes`.
273 LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
274 LinalgOp linalgOp);
275
  /// Create or retrieve an existing mask value to mask `opToMask` in the
  /// canonical vector iteration space. If `maybeMaskingMap` is provided, the
  /// mask is permuted using that permutation map. If a new mask is created,
  /// it will be cached for future users.
280 Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
281 LinalgOp linalgOp,
282 std::optional<AffineMap> maybeMaskingMap);
283
284 /// Check whether this permutation map can be used for masking. At the
285 /// moment we only make sure that there are no broadcast dimensions, but this
286 /// might change if indexing maps evolve.
287 bool isValidMaskingMap(AffineMap maskingMap) {
288 return maskingMap.getBroadcastDims().size() == 0;
289 }
290
291 /// Turn the input indexing map into a valid masking map.
292 ///
293 /// The input indexing map may contain "zero" results, e.g.:
294 /// (d0, d1, d2, d3) -> (d2, d1, d0, 0)
295 /// Applying such maps to canonical vector shapes like this one:
296 /// (1, 16, 16, 4)
297 /// would yield an invalid vector shape like this:
298 /// (16, 16, 1, 0)
299 /// Instead, drop the broadcasting dims that make no sense for masking perm.
300 /// maps:
301 /// (d0, d1, d2, d3) -> (d2, d1, d0)
302 /// This way, the corresponding vector/mask type will be:
303 /// vector<16x16x1xty>
304 /// rather than this invalid Vector type:
305 /// vector<16x16x1x0xty>
306 AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
307 return indexingMap.dropZeroResults();
308 }
309
310 // Holds the compile-time static sizes of the iteration space to vectorize.
311 // Dynamic dimensions are represented using ShapedType::kDynamic.
312 SmallVector<int64_t> iterSpaceStaticSizes;
313
314 /// Holds the value sizes of the iteration space to vectorize. Static
315 /// dimensions are represented by 'arith.constant' and dynamic
316 /// dimensions by 'tensor/memref.dim'.
317 SmallVector<Value> iterSpaceValueSizes;
318
319 /// Holds the canonical vector shape used to vectorize the iteration space.
320 SmallVector<int64_t> canonicalVecShape;
321
322 /// Holds the vector dimensions that are scalable in the canonical vector
323 /// shape.
324 SmallVector<bool> scalableVecDims;
325
326 /// Holds the active masks for permutations of the canonical vector iteration
327 /// space.
328 DenseMap<AffineMap, Value> activeMaskCache;
329
330 /// Global vectorization guard for the incoming rewriter. It's initialized
331 /// when the vectorization state is initialized.
332 OpBuilder::InsertionGuard rewriterGuard;
333};
334
335LogicalResult
336VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
337 LinalgOp linalgOp) {
338 // TODO: Support 0-d vectors.
339 for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
340 if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
341 // Create constant index op for static dimensions.
      iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
          linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
344 continue;
345 }
346
347 // Find an operand defined on this dimension of the iteration space to
348 // extract the runtime dimension size.
349 Value operand;
350 unsigned operandDimPos;
351 if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
352 operandDimPos)))
353 return failure();
354
355 Value dynamicDim = linalgOp.hasPureTensorSemantics()
356 ? (Value)rewriter.create<tensor::DimOp>(
357 linalgOp.getLoc(), operand, operandDimPos)
358 : (Value)rewriter.create<memref::DimOp>(
359 linalgOp.getLoc(), operand, operandDimPos);
    iterSpaceValueSizes.push_back(dynamicDim);
361 }
362
363 return success();
364}
365
366/// Initializes the vectorization state, including the computation of the
367/// canonical vector shape for vectorization.
368// TODO: Move this to the constructor when we can remove the failure cases.
369LogicalResult
370VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
371 ArrayRef<int64_t> inputVectorSizes,
372 ArrayRef<bool> inputScalableVecDims) {
373 // Initialize the insertion point.
374 rewriter.setInsertionPoint(linalgOp);
375
376 if (!inputVectorSizes.empty()) {
377 // Get the canonical vector shape from the input vector sizes provided. This
378 // path should be taken to vectorize code with dynamic shapes and when using
379 // vector sizes greater than the iteration space sizes.
    canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    scalableVecDims.append(inputScalableVecDims.begin(),
                           inputScalableVecDims.end());
383 } else {
384 // Compute the canonical vector shape from the operation shape. If there are
385 // dynamic shapes, the operation won't be vectorized. We assume all the
386 // vector dimensions are fixed.
387 canonicalVecShape = linalgOp.getStaticLoopRanges();
388 scalableVecDims.append(linalgOp.getNumLoops(), false);
389 }
390
391 LDBG("Canonical vector shape: ");
392 LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
393 LLVM_DEBUG(llvm::dbgs() << "\n");
394 LDBG("Scalable vector dims: ");
395 LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
396 LLVM_DEBUG(llvm::dbgs() << "\n");
397
398 if (ShapedType::isDynamicShape(canonicalVecShape))
399 return failure();
400
401 // Initialize iteration space static sizes.
  initIterSpaceStaticSizes(linalgOp);
403
404 // Generate 'arith.constant' and 'tensor/memref.dim' operations for
405 // all the static and dynamic dimensions of the iteration space, needed to
406 // compute a mask during vectorization.
  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
408 return failure();
409
410 return success();
411}
412
/// Create or retrieve an existing mask value to mask `opToMask` in the
/// canonical vector iteration space. If `maybeMaskingMap` is provided, the
/// mask is permuted using that permutation map. If a new mask is created, it
/// will be cached for future users.
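///
/// As a rough example (illustrative shapes), for an iteration space of (8, ?)
/// vectorized with a canonical vector shape of (8, 16), the mask created here
/// would look like:
///   %dim = tensor.dim %arg, %c1 : tensor<8x?xf32>
///   %mask = vector.create_mask %c8, %dim : vector<8x16xi1>
/// with the upper bounds taken from the precomputed iteration space sizes.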
417Value VectorizationState::getOrCreateMaskFor(
418 RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
419 std::optional<AffineMap> maybeMaskingMap) {
420
421 assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
422 "Ill-formed masking map.");
423
424 // No mask is needed if the operation is not maskable.
425 auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
426 if (!maskableOp)
427 return Value();
428
429 assert(!maskableOp.isMasked() &&
430 "Masking an operation that is already masked");
431
432 // If no masking map was provided, use an identity map with the loop dims.
433 assert((!maybeMaskingMap || *maybeMaskingMap) &&
434 "Unexpected null mask permutation map");
  AffineMap maskingMap =
      maybeMaskingMap ? *maybeMaskingMap
                      : AffineMap::getMultiDimIdentityMap(
                            linalgOp.getNumLoops(), rewriter.getContext());
439
440 LDBG("Masking map: " << maskingMap << "\n");
441
442 // Return the active mask for the masking map of this operation if it was
443 // already created.
  auto activeMaskIt = activeMaskCache.find(maskingMap);
445 if (activeMaskIt != activeMaskCache.end()) {
446 Value mask = activeMaskIt->second;
447 LDBG("Reusing mask: " << mask << "\n");
448 return mask;
449 }
450
451 // Compute permuted projection of the iteration space to be masked and the
452 // corresponding mask shape. If the resulting iteration space dimensions are
453 // static and identical to the mask shape, masking is not needed for this
454 // operation.
455 // TODO: Improve this check. Only projected permutation indexing maps are
456 // supported.
  SmallVector<int64_t> permutedStaticSizes =
      applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
459 auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
460 auto maskShape = maskType.getShape();
461
462 LDBG("Mask shape: ");
463 LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
464 LLVM_DEBUG(llvm::dbgs() << "\n");
465
466 if (permutedStaticSizes == maskShape) {
467 LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
468 activeMaskCache[maskingMap] = Value();
469 return Value();
470 }
471
472 // Permute the iteration space value sizes to compute the mask upper bounds.
  SmallVector<Value> upperBounds =
      applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
475 assert(!maskShape.empty() && !upperBounds.empty() &&
476 "Masked 0-d vectors are not supported yet");
477
478 // Create the mask based on the dimension values.
479 Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
480 maskType, upperBounds);
481 LDBG("Creating new mask: " << mask << "\n");
482 activeMaskCache[maskingMap] = mask;
483 return mask;
484}
485
486Operation *
487VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
488 LinalgOp linalgOp,
489 std::optional<AffineMap> maybeIndexingMap) {
490 LDBG("Trying to mask: " << *opToMask << "\n");
491
492 std::optional<AffineMap> maybeMaskingMap = std::nullopt;
493 if (maybeIndexingMap)
    maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);
495
496 // Create or retrieve mask for this operation.
  Value mask =
      getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
499
500 if (!mask) {
501 LDBG("No mask required\n");
502 return opToMask;
503 }
504
505 // Wrap the operation with a new `vector.mask` and update D-U chain.
506 assert(opToMask && "Expected a valid operation to mask");
507 auto maskOp = cast<vector::MaskOp>(
508 mlir::vector::maskOperation(rewriter, opToMask, mask));
509 Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
510
  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
512 rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
513 maskOpTerminator);
514
515 LDBG("Masked operation: " << *maskOp << "\n");
516 return maskOp;
517}
518
519/// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
520/// projectedPermutation, compress the unused dimensions to serve as a
521/// permutation_map for a vector transfer operation.
522/// For example, given a linalg op such as:
523///
524/// ```
525/// %0 = linalg.generic {
526/// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
527/// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
528/// }
529/// ins(%0 : tensor<2x3x4xf32>)
530/// outs(%1 : tensor<5x6xf32>)
531/// ```
532///
533/// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
534/// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
535/// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
536static AffineMap reindexIndexingMap(AffineMap map) {
537 assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
538 "expected projected permutation");
539 auto res = compressUnusedDims(map);
540 assert(res.getNumDims() ==
541 (res.getNumResults() - res.getNumOfZeroResults()) &&
542 "expected reindexed map with same number of dims and results");
543 return res;
544}
545
546/// Helper enum to represent conv1d input traversal order.
547enum class Conv1DOpOrder {
548 W, // Corresponds to non-channeled 1D convolution operation.
549 Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
550 Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
551};
552
553/// Helper data structure to represent the result of vectorization.
/// In certain specific cases, like terminators, we do not want to
/// propagate/replace the results.
555enum VectorizationStatus {
556 /// Op failed to vectorize.
557 Failure = 0,
558 /// Op vectorized and custom function took care of replacement logic
559 NoReplace,
560 /// Op vectorized into a new Op whose results will replace original Op's
561 /// results.
562 NewOp
563 // TODO: support values if Op vectorized to Many-Ops whose results we need to
564 // aggregate for replacement.
565};
566struct VectorizationResult {
567 /// Return status from vectorizing the current op.
568 enum VectorizationStatus status = VectorizationStatus::Failure;
569 /// New vectorized operation to replace the current op.
570 /// Replacement behavior is specified by `status`.
571 Operation *newOp;
572};
573
574std::optional<vector::CombiningKind>
575mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
576 using ::mlir::vector::CombiningKind;
577
578 if (!combinerOp)
579 return std::nullopt;
580 return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
581 .Case<arith::AddIOp, arith::AddFOp>(
582 [&](auto op) { return CombiningKind::ADD; })
583 .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
584 .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
585 .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
586 .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
587 .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
588 .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
589 .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
590 .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
591 .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
592 .Case<arith::MulIOp, arith::MulFOp>(
593 [&](auto op) { return CombiningKind::MUL; })
594 .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
595 .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
596 .Default([&](auto op) { return std::nullopt; });
597}
598
599/// Check whether `outputOperand` is a reduction with a single combiner
600/// operation. Return the combiner operation of the reduction. Return
601/// nullptr otherwise. Multiple reduction operations would impose an
602/// ordering between reduction dimensions and is currently unsupported in
603/// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
604/// max(min(X))
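///
/// For example (illustrative), the following region has a single arith.addf
/// combiner and would be matched:
///   ^bb0(%in: f32, %acc: f32):
///     %0 = arith.addf %in, %acc : f32
///     linalg.yield %0 : f32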
605// TODO: use in LinalgOp verification, there is a circular dependency atm.
606static Operation *matchLinalgReduction(OpOperand *outputOperand) {
607 auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
608 unsigned outputPos =
609 outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
610 // Only single combiner operations are supported for now.
611 SmallVector<Operation *, 4> combinerOps;
612 if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
613 combinerOps.size() != 1)
614 return nullptr;
615
616 // Return the combiner operation.
617 return combinerOps[0];
618}
619
/// Broadcast `value` to a vector of type `dstType` if possible. Return `value`
/// otherwise.
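///
/// For example (illustrative), broadcasting an f32 scalar to vector<4x8xf32>
/// emits:
///   %bcast = vector.broadcast %scalar : f32 to vector<4x8xf32>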
622static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
623 auto dstVecType = dyn_cast<VectorType>(dstType);
624 // If no shape to broadcast to, just return `value`.
625 if (dstVecType.getRank() == 0)
626 return value;
  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
628 vector::BroadcastableToResult::Success)
629 return value;
630 Location loc = b.getInsertionPoint()->getLoc();
631 return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
632}
633
634/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
635/// assumes that `reductionOp` has two operands and one of them is the reduction
/// initial value.
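///
/// For example (illustrative), reducing dim 1 of a vector<4x8xf32> into a
/// vector<4xf32> accumulator produces:
///   %red = vector.multi_reduction <add>, %val, %acc [1]
///       : vector<4x8xf32> to vector<4xf32>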
637// Note: this is a true builder that notifies the OpBuilder listener.
638// TODO: Consider moving as a static helper on the ReduceOp.
639static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
640 Value valueToReduce, Value acc,
641 ArrayRef<bool> dimsToMask) {
642 auto maybeKind = getCombinerOpKind(reduceOp);
643 assert(maybeKind && "Failed precondition: could not get reduction kind");
644 return b.create<vector::MultiDimReductionOp>(
645 reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
646}
647
648static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
649 return llvm::to_vector(
650 llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
651}
652
653/// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
654/// reduction iterator.
655static bool hasReductionIterator(LinalgOp &op) {
656 return isa<linalg::ReduceOp>(op) ||
657 (isa<linalg::GenericOp>(op) &&
658 llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
659}
660
661/// Build a vector.transfer_write of `value` into `outputOperand` at indices set
662/// to all `0`; where `outputOperand` is an output operand of the LinalgOp
/// currently being vectorized. If `dest` has null rank, build a memref.store.
664/// Return the produced value or null if no value is produced.
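///
/// For example (illustrative), writing a vector<8x16xf32> into a matching
/// static output tensor produces:
///   %w = vector.transfer_write %vec, %out[%c0, %c0]
///       {in_bounds = [true, true]} : vector<8x16xf32>, tensor<8x16xf32>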
665// Note: this is a true builder that notifies the OpBuilder listener.
666// TODO: Consider moving as a static helper on the ReduceOp.
667static Value buildVectorWrite(RewriterBase &rewriter, Value value,
668 OpOperand *outputOperand,
669 VectorizationState &state) {
670 Location loc = value.getLoc();
671 auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
672 AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
673
674 // Compute the vector type of the value to store. This type should be an
675 // identity or projection of the canonical vector type without any permutation
676 // applied, given that any permutation in a transfer write happens as part of
677 // the write itself.
  AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
      opOperandMap.getContext(), opOperandMap.getNumInputs(),
      [&](AffineDimExpr dimExpr) -> bool {
        return llvm::is_contained(opOperandMap.getResults(), dimExpr);
      });
  auto vectorType = state.getCanonicalVecType(
      getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
685
686 Operation *write;
687 if (vectorType.getRank() > 0) {
    AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
    SmallVector<Value> indices(linalgOp.getRank(outputOperand),
                               rewriter.create<arith::ConstantIndexOp>(loc, 0));
691 value = broadcastIfNeeded(rewriter, value, vectorType);
692 assert(value.getType() == vectorType && "Incorrect type");
693 write = rewriter.create<vector::TransferWriteOp>(
694 loc, value, outputOperand->get(), indices, writeMap);
695 } else {
696 // 0-d case is still special: do not invert the reindexing writeMap.
697 if (!isa<VectorType>(value.getType()))
698 value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
699 assert(value.getType() == vectorType && "Incorrect type");
700 write = rewriter.create<vector::TransferWriteOp>(
701 loc, value, outputOperand->get(), ValueRange{});
702 }
703
  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
705
706 // If masked, set in-bounds to true. Masking guarantees that the access will
707 // be in-bounds.
708 if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
709 auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
710 SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
711 maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
712 }
713
714 LDBG("vectorized op: " << *write << "\n");
715 if (!write->getResults().empty())
    return write->getResult(0);
717 return Value();
718}
719
// Custom vectorization precondition function type. This is intended to be used
721// with CustomVectorizationHook. Returns success if the corresponding custom
722// hook can vectorize the op.
723using CustomVectorizationPrecondition =
724 std::function<LogicalResult(Operation *, bool)>;
725
726// Custom vectorization function type. Produce a vector form of Operation*
727// assuming all its vectorized operands are already in the IRMapping.
728// Return nullptr if the Operation cannot be vectorized.
729using CustomVectorizationHook =
730 std::function<VectorizationResult(Operation *, const IRMapping &)>;
731
732/// Helper function to vectorize the terminator of a `linalgOp`. New result
733/// vector values are appended to `newResults`. Return
734/// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
735/// should not try to map produced operations and instead return the results
736/// using the `newResults` vector making them available to the vectorization
737/// algorithm for RAUW. This function is meant to be used as a
738/// CustomVectorizationHook.
739static VectorizationResult
740vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
741 const IRMapping &bvm, VectorizationState &state,
742 LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
743 auto yieldOp = dyn_cast<linalg::YieldOp>(op);
744 if (!yieldOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
746 for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
747 // TODO: Scan for an opportunity for reuse.
748 // TODO: use a map.
749 Value vectorValue = bvm.lookup(output.value());
750 Value newResult =
751 buildVectorWrite(rewriter, vectorValue,
752 linalgOp.getDpsInitOperand(output.index()), state);
753 if (newResult)
754 newResults.push_back(newResult);
755 }
756
  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
758}
759
760/// Helper function to vectorize the index operations of a `linalgOp`. Return
761/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
762/// should map the produced operations. This function is meant to be used as a
763/// CustomVectorizationHook.
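///
/// For example (illustrative), `linalg.index 0` in a 2-D iteration space
/// vectorized with a canonical vector shape of (4, 8) becomes roughly:
///   %step = vector.step : vector<4xindex>
///   %bcast = vector.broadcast %step : vector<4xindex> to vector<8x4xindex>
///   %idx = vector.transpose %bcast, [1, 0]
///       : vector<8x4xindex> to vector<4x8xindex>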
764static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
765 VectorizationState &state,
766 Operation *op,
767 LinalgOp linalgOp) {
768 IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
769 if (!indexOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
771 auto loc = indexOp.getLoc();
772 // Compute the static loop sizes of the index op.
773 ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
774 auto dim = indexOp.getDim();
775 // Compute a one-dimensional index vector for the index op dimension.
776 auto indexVectorType =
777 VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
778 state.getScalableVecDims()[dim]);
779 auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
780 // Return the one-dimensional index vector if it lives in the trailing
781 // dimension of the iteration space since the vectorization algorithm in this
782 // case can handle the broadcast.
783 if (dim == targetShape.size() - 1)
784 return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
785 // Otherwise permute the targetShape to move the index dimension last,
786 // broadcast the one-dimensional index vector to the permuted shape, and
787 // finally transpose the broadcasted index vector to undo the permutation.
  auto permPattern =
      llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
790 std::swap(permPattern[dim], permPattern.back());
791 auto permMap =
792 AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
793
794 auto broadCastOp = rewriter.create<vector::BroadcastOp>(
795 loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
796 indexSteps);
797 SmallVector<int64_t> transposition =
798 llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
799 std::swap(transposition.back(), transposition[dim]);
800 auto transposeOp =
801 rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
802 return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
803}
804
805/// Helper function to check if the tensor.extract can be vectorized by the
806/// custom hook vectorizeTensorExtract.
807static LogicalResult
808tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
809 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
810 if (!extractOp)
811 return failure();
812
813 if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
814 return failure();
815
816 // Check the index type, but only for non 0-d tensors (for which we do need
817 // access indices).
  if (!extractOp.getIndices().empty()) {
819 if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
820 return failure();
821 }
822
823 if (!llvm::all_of(extractOp->getResultTypes(),
824 VectorType::isValidElementType)) {
825 return failure();
826 }
827
828 return success();
829}
830
831/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
832/// generated from `tensor.extract`. The offset is calculated as follows
833/// (example using scalar values):
834///
835/// offset = extractOp.indices[0]
836/// for (i = 1; i < numIndices; i++)
837/// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
838///
839/// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
840/// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
841static Value calculateGatherOffset(RewriterBase &rewriter,
842 VectorizationState &state,
843 tensor::ExtractOp extractOp,
844 const IRMapping &bvm) {
845 // The vector of indices for GatherOp should be shaped as the output vector.
846 auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
847 auto loc = extractOp.getLoc();
848
849 Value offset = broadcastIfNeeded(
850 rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
851
852 const size_t numIndices = extractOp.getIndices().size();
853 for (size_t i = 1; i < numIndices; i++) {
854 Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
855
856 auto dimSize = broadcastIfNeeded(
857 rewriter,
858 rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
859 indexVecType);
860
861 offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
862
863 auto extractOpIndex = broadcastIfNeeded(
864 rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
865
866 offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
867 }
868
869 return offset;
870}
871
872enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
873
874/// Find the index of the trailing non-unit dim in linalgOp. This hook is used
875/// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
876/// represents a contiguous load operation.
877///
878/// Note that when calling this hook, it is assumed that the output vector is
879/// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
880/// labelled as a gather load before entering this method.
881///
882/// Following on from the above, it is assumed that:
883/// * for statically shaped loops, when no masks are used, only one dim is !=
884/// 1 (that's what the shape of the output vector is based on).
885/// * for dynamically shaped loops, there might be more non-unit dims
886/// as the output vector type is user-specified.
887///
888/// TODO: Statically shaped loops + vector masking
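///
/// For example, for static loop ranges (1, 8, 1, 1) this returns 1, and for
/// (1, 1, 4) it returns 2.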
889static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
890 SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
891 assert(
892 (linalgOp.hasDynamicShape() ||
893 llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
894 "For statically shaped Linalg Ops, only one "
895 "non-unit loop dim is expected");
896 assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse.");
897
898 size_t idx = loopRanges.size() - 1;
899 for (; idx != 0; idx--)
900 if (loopRanges[idx] != 1)
901 break;
902
903 return idx;
904}
905
906/// Checks whether `val` can be used for calculating a loop invariant index.
907static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
908 VectorType resType) {
909
910 assert(((llvm::count_if(resType.getShape(),
911 [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
912 "n-D vectors are not yet supported");
913
914 // Blocks outside _this_ linalg.generic are effectively loop invariant.
915 // However, analysing block arguments for _this_ linalg.generic Op is a bit
916 // tricky. Just bail out in the latter case.
917 // TODO: We could try analysing the corresponding affine map here.
918 auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
920 return llvm::all_of(block->getArguments(),
921 [&val](Value v) { return (v != val); });
922
923 Operation *defOp = val.getDefiningOp();
924 assert(defOp && "This is neither a block argument nor an operation result");
925
926 // IndexOp is loop invariant as long as its result remains constant across
927 // iterations. Note that for dynamic shapes, the corresponding dim will also
928 // be conservatively treated as != 1.
929 if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
930 return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
931 }
932
933 auto *ancestor = block->findAncestorOpInBlock(*defOp);
934
  // Values defined outside `linalgOp` are loop invariant.
936 if (!ancestor)
937 return true;
938
939 // Values defined inside `linalgOp`, which are constant, are loop invariant.
940 if (isa<arith::ConstantOp>(ancestor))
941 return true;
942
943 bool result = true;
944 for (auto op : ancestor->getOperands())
945 result &= isLoopInvariantIdx(linalgOp, op, resType);
946
947 return result;
948}
949
950/// Check whether `val` could be used for calculating the trailing index for a
951/// contiguous load operation.
952///
953/// There are currently 3 types of values that are allowed here:
954/// 1. loop-invariant values,
955/// 2. values that increment by 1 with every loop iteration,
956/// 3. results of basic arithmetic operations (linear and continuous)
957/// involving 1., 2. and 3.
958/// This method returns True if indeed only such values are used in calculating
959/// `val.`
960///
961/// Additionally, the trailing index for a contiguous load operation should
962/// increment by 1 with every loop iteration, i.e. be based on:
963/// * `linalg.index <dim>` ,
964/// where <dim> is the trailing non-unit dim of the iteration space (this way,
965/// `linalg.index <dim>` increments by 1 with every loop iteration).
966/// `foundIndexOp` is updated to `true` when such Op is found.
967static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
968 bool &foundIndexOp, VectorType resType) {
969
970 assert(((llvm::count_if(resType.getShape(),
971 [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
972 "n-D vectors are not yet supported");
973
974 // Blocks outside _this_ linalg.generic are effectively loop invariant.
975 // However, analysing block arguments for _this_ linalg.generic Op is a bit
976 // tricky. Just bail out in the latter case.
977 // TODO: We could try analysing the corresponding affine map here.
978 auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
980 return llvm::all_of(block->getArguments(),
981 [&val](Value v) { return (v != val); });
982
983 Operation *defOp = val.getDefiningOp();
984 assert(defOp && "This is neither a block argument nor an operation result");
985
986 if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
987 auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
988
989 foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
990 return true;
991 }
992
993 auto *ancestor = block->findAncestorOpInBlock(*defOp);
994
995 if (!ancestor)
996 return false;
997
998 // Conservatively reject Ops that could lead to indices with stride other
999 // than 1.
1000 if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
1001 return false;
1002
1003 bool result = false;
1004 for (auto op : ancestor->getOperands())
1005 result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
1006
1007 return result;
1008}
1009
1010/// Infer the memory access pattern for the input ExtractOp
1011///
/// Based on the ExtractOp result shape and the access indices, decides whether
/// this Op corresponds to a contiguous load (including a broadcast of a
/// scalar) or a gather load. When analysing the ExtractOp indices (to identify
/// contiguous loads), this method looks for "loop" invariant indices (e.g.
1016/// block arguments) and indices that change linearly (e.g. via `linalg.index`
1017/// Op).
1018///
1019/// Note that it is always safe to use gather load operations for contiguous
1020/// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
1021/// that `extractOp` is a gather load.
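///
/// For example (illustrative), inside a linalg.generic whose trailing loop
/// dimension has trip count 8:
///   %i = linalg.index 1 : index
///   %val = tensor.extract %src[%c0, %i] : tensor<3x8xf32>
/// is classified as a contiguous load (loop-invariant leading index, trailing
/// index derived from the trailing loop dim), whereas a data-dependent
/// trailing index would classify it as a gather load.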
1022static VectorMemoryAccessKind
1023getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
1024 LinalgOp &linalgOp, VectorType resType) {
1025
1026 auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
1027
1028 // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
1029 if (inputShape.getShape().empty())
1030 return VectorMemoryAccessKind::ScalarBroadcast;
1031
1032 // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
1033 // otherwise.
1034 bool isOutput1DVector =
1035 (llvm::count_if(resType.getShape(),
1036 [](int64_t dimSize) { return dimSize > 1; }) == 1);
1037 // 1. Assume that it's a gather load when reading non-1D vector.
1038 if (!isOutput1DVector)
1039 return VectorMemoryAccessKind::Gather;
1040
1041 bool leadingIdxsLoopInvariant = true;
1042
1043 // 2. Analyze the leading indices of `extractOp`.
1044 // Look at the way each index is calculated and decide whether it is suitable
1045 // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
1046 // gather load.
1047 auto indices = extractOp.getIndices();
1048 auto leadIndices = indices.drop_back(1);
1049
1050 for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
1051 if (inputShape.getShape()[i] == 1)
1052 continue;
1053
1054 leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
1055 }
1056
1057 if (!leadingIdxsLoopInvariant) {
1058 LDBG("Found gather load: " << extractOp);
1059 return VectorMemoryAccessKind::Gather;
1060 }
1061
1062 // 3. Analyze the trailing index for `extractOp`.
1063 // At this point we know that the leading indices are loop invariant. This
  // means that this is potentially a scalar or a contiguous load. We can decide
1065 // based on the trailing idx.
1066 auto extractOpTrailingIdx = indices.back();
1067
1068 // 3a. Scalar broadcast load
1069 // If the trailing index is loop invariant then this is a scalar load.
1070 if (leadingIdxsLoopInvariant &&
1071 isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
1072 LDBG("Found scalar broadcast load: " << extractOp);
1073
1074 return VectorMemoryAccessKind::ScalarBroadcast;
1075 }
1076
1077 // 3b. Contiguous loads
1078 // The trailing `extractOp` index should increment with every loop iteration.
1079 // This effectively means that it must be based on the trailing loop index.
1080 // This is what the following bool captures.
1081 bool foundIndexOp = false;
1082 bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
1083 foundIndexOp, resType);
1084 // TODO: Support generating contiguous loads for column vectors - that will
  // require adding a permutation map to transfer_read Ops.
1086 bool isRowVector = resType.getShape().back() != 1;
1087 isContiguousLoad &= (foundIndexOp && isRowVector);
1088
1089 if (isContiguousLoad) {
1090 LDBG("Found contigous load: " << extractOp);
1091 return VectorMemoryAccessKind::Contiguous;
1092 }
1093
1094 // 4. Fallback case - gather load.
1095 LDBG("Found gather load: " << extractOp);
1096 return VectorMemoryAccessKind::Gather;
1097}
1098
1099/// Helper function to vectorize the tensor.extract operations. Returns
1100/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
1101/// should map the produced operations. This function is meant to be used as a
1102/// CustomVectorizationHook.
1103static VectorizationResult
1104vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1105 Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1106 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1107 if (!extractOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
1109 auto loc = extractOp.getLoc();
1110
1111 // Compute the static loop sizes of the extract op.
1112 auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1113 auto maskConstantOp = rewriter.create<arith::ConstantOp>(
1114 loc,
1115 DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1116 /*value=*/true));
1117 auto passThruConstantOp =
1118 rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
1119
1120 // Base indices are currently set to 0. We will need to re-visit if more
1121 // generic scenarios are to be supported.
1122 SmallVector<Value> baseIndices(
1123 extractOp.getIndices().size(),
1124 rewriter.create<arith::ConstantIndexOp>(loc, 0));
1125
1126 VectorMemoryAccessKind memAccessKind =
1127 getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
1128
1129 // 1. Handle gather access
1130 if (memAccessKind == VectorMemoryAccessKind::Gather) {
1131 Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1132
1133 // Generate the gather load
1134 Operation *gatherOp = rewriter.create<vector::GatherOp>(
1135 loc, resultType, extractOp.getTensor(), baseIndices, offset,
1136 maskConstantOp, passThruConstantOp);
    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1138
1139 LDBG("Vectorised as gather load: " << extractOp << "\n");
    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
1141 }
1142
1143 // 2. Handle:
1144 // a. scalar loads + broadcast,
1145 // b. contiguous loads.
1146 // Both cases use vector.transfer_read.
1147
1148 // Collect indices for `vector.transfer_read`. At this point, the indices will
1149 // either be scalars or would have been broadcast to vectors matching the
1150 // result type. For indices that are vectors, there are two options:
1151 // * for non-trailing indices, all elements are identical (contiguous
1152 // loads are identified by looking for non-trailing indices that are
1153 // invariant with respect to the corresponding linalg.generic), or
1154 // * for trailing indices, the index vector will contain values with stride
1155 // one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1156 // needed.
1157 // This means that
1158 // * for scalar indices - just re-use it,
1159 // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1160 // (0th) element and use that.
1161 SmallVector<Value> transferReadIdxs;
1162 for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1163 Value idx = bvm.lookup(extractOp.getIndices()[i]);
1164 if (idx.getType().isIndex()) {
      transferReadIdxs.push_back(idx);
1166 continue;
1167 }
1168
1169 auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
1170 loc,
1171 VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
1172 resultType.getScalableDims().back()),
1173 idx);
1174 transferReadIdxs.push_back(
1175 rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
1176 }
1177
  // `tensor.extract` is always in-bounds, hence the following holds.
1179 auto dstRank = resultType.getRank();
1180 auto srcRank = extractOp.getTensor().getType().getRank();
1181 SmallVector<bool> inBounds(dstRank, true);
1182
1183 // 2a. Handle scalar broadcast access.
1184 if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1185 MLIRContext *ctx = rewriter.getContext();
    SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1187 auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1188
1189 auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1190 loc, resultType, extractOp.getTensor(), transferReadIdxs,
1191 permutationMap, inBounds);
1192
1193 // Mask this broadcasting xfer_read here rather than relying on the generic
1194 // path (the generic path assumes identity masking map, which wouldn't be
1195 // valid here).
1196 SmallVector<int64_t> readMaskShape = {1};
1197 auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
1198 auto allTrue = rewriter.create<vector::ConstantMaskOp>(
1199 loc, readMaskType, vector::ConstantMaskKind::AllTrue);
    auto *maskedReadOp =
        mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
1202
1203 LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
1204 return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
1205 }
1206
1207 // 2b. Handle contiguous access.
  auto permutationMap = AffineMap::getMinorIdentityMap(
      srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1210
1211 int32_t rankDiff = dstRank - srcRank;
1212 // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1213 // dims so that the ranks match. This is done by extending the map with 0s.
1214 // For example, for dstRank = 3, srcRank = 2, the following map created
1215 // above:
1216 // (d0, d1) --> (d0, d1)
1217 // is extended as:
1218 // (d0, d1) --> (0, d0, d1)
1219 while (rankDiff > 0) {
    permutationMap = permutationMap.insertResult(
        mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1222 rankDiff--;
1223 }
1224
1225 auto transferReadOp = rewriter.create<vector::TransferReadOp>(
1226 loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
1227 inBounds);
1228
1229 LDBG("Vectorised as contiguous load: " << extractOp);
1230 return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1231}
1232
/// Emit reduction operations if the shape of the value to reduce differs from
/// the result shape.
1235// Note: this is a true builder that notifies the OpBuilder listener.
1236// TODO: Consider moving as a static helper on the ReduceOp.
1237static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1238 Value reduceValue, Value initialValue,
1239 const IRMapping &bvm) {
  Value reduceVec = bvm.lookup(reduceValue);
  Value outputVec = bvm.lookup(initialValue);
1242 auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1243 auto outputType = dyn_cast<VectorType>(outputVec.getType());
  // Reduce only if needed as the value may already have been reduced for
  // contraction vectorization.
1246 if (!reduceType ||
1247 (outputType && reduceType.getShape() == outputType.getShape()))
1248 return nullptr;
1249 SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1251}
1252
1253/// Generic vectorization for a single operation `op`, given already vectorized
1254/// operands carried by `bvm`. Vectorization occurs as follows:
1255/// 1. Try to apply any of the `customVectorizationHooks` and return its
1256/// result on success.
1257/// 2. Clone any constant in the current scope without vectorization: each
1258/// consumer of the constant will later determine the shape to which the
1259/// constant needs to be broadcast to.
1260/// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1261/// of the `customVectorizationHooks` to cover such cases.
1262/// 4. Clone `op` in vector form to a vector of shape prescribed by the first
1263/// operand of maximal rank. Other operands have smaller rank and are
1264/// broadcast accordingly. It is assumed this broadcast is always legal,
1265/// otherwise, it means one of the `customVectorizationHooks` is incorrect.
1266///
1267/// This function assumes all operands of `op` have been vectorized and are in
1268/// the `bvm` mapping. As a consequence, this function is meant to be called on
1269/// a topologically-sorted list of ops.
1270/// This function does not update `bvm` but returns a VectorizationStatus that
1271/// instructs the caller what `bvm` update needs to occur.
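///
/// For example (illustrative), with both operands already mapped to
/// vector<4x8xf32> values in `bvm`, a scalar arith.addf in the body is
/// rewritten to:
///   %sum = arith.addf %lhs_vec, %rhs_vec : vector<4x8xf32>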
1272static VectorizationResult
1273vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1274 LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1275 ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1276 LDBG("vectorize op " << *op << "\n");
1277
1278 // 1. Try to apply any CustomVectorizationHook.
1279 if (!customVectorizationHooks.empty()) {
1280 for (auto &customFunc : customVectorizationHooks) {
1281 VectorizationResult result = customFunc(op, bvm);
1282 if (result.status == VectorizationStatus::Failure)
1283 continue;
1284 return result;
1285 }
1286 }
1287
  // 2. Constant ops don't get vectorized but rather broadcast at their users.
  // Clone so that the constant is not confined to the linalgOp block.
  if (isa<arith::ConstantOp, func::ConstantOp>(op))
    return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};
1292
1293 // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1294 if (!OpTrait::hasElementwiseMappableTraits(op))
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
1296
  // 4. Check if the operation is a reduction.
1298 SmallVector<std::pair<Value, Value>> reductionOperands;
1299 for (Value operand : op->getOperands()) {
    auto blockArg = dyn_cast<BlockArgument>(operand);
1301 if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1302 blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1303 continue;
1304 SmallVector<Operation *> reductionOps;
1305 Value reduceValue = matchReduction(
1306 linalgOp.getRegionOutputArgs(),
1307 blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1308 if (!reduceValue)
1309 continue;
1310 reductionOperands.push_back(Elt: std::make_pair(x&: reduceValue, y&: operand));
1311 }
1312 if (!reductionOperands.empty()) {
1313 assert(reductionOperands.size() == 1);
1314 Operation *reduceOp =
1315 reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1316 reductionOperands[0].second, bvm);
1317 if (reduceOp)
1318 return VectorizationResult{.status: VectorizationStatus::NewOp, .newOp: reduceOp};
1319 }
1320
1321 // 5. Generic vectorization path for ElementwiseMappable ops.
1322 // a. Get the first max ranked shape.
1323 VectorType firstMaxRankedType;
1324 for (Value operand : op->getOperands()) {
1325 auto vecOperand = bvm.lookup(from: operand);
1326 assert(vecOperand && "Vector operand couldn't be found");
1327
1328 auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1329 if (vecType && (!firstMaxRankedType ||
1330 firstMaxRankedType.getRank() < vecType.getRank()))
1331 firstMaxRankedType = vecType;
1332 }
1333 // b. Broadcast each op if needed.
1334 SmallVector<Value> vecOperands;
1335 for (Value scalarOperand : op->getOperands()) {
1336 Value vecOperand = bvm.lookup(from: scalarOperand);
1337 assert(vecOperand && "Vector operand couldn't be found");
1338
1339 if (firstMaxRankedType) {
1340 auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1341 getElementTypeOrSelf(vecOperand.getType()),
1342 firstMaxRankedType.getScalableDims());
1343 vecOperands.push_back(Elt: broadcastIfNeeded(rewriter, vecOperand, vecType));
1344 } else {
1345 vecOperands.push_back(Elt: vecOperand);
1346 }
1347 }
1348 // c. for elementwise, the result is the vector with the firstMaxRankedShape
1349 SmallVector<Type> resultTypes;
1350 for (Type resultType : op->getResultTypes()) {
1351 resultTypes.push_back(
1352 firstMaxRankedType
1353 ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1354 firstMaxRankedType.getScalableDims())
1355 : resultType);
1356 }
1357 // d. Build and return the new op.
1358 return VectorizationResult{
1359 .status: VectorizationStatus::NewOp,
1360 .newOp: rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1361 resultTypes, op->getAttrs())};
1362}
1363
1364/// Generic vectorization function that rewrites the body of a `linalgOp` into
1365/// vector form. Generic vectorization proceeds as follows:
1366/// 1. Verify the `linalgOp` has one non-empty region.
1367/// 2. Values defined above the region are mapped to themselves and will be
1368/// broadcasted on a per-need basis by their consumers.
1369/// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1370/// load).
1371/// TODO: Reuse opportunities for RAR dependencies.
1372/// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
/// 4b. Register CustomVectorizationHook for IndexOp to access the
1374/// iteration indices.
1375/// 5. Iteratively call vectorizeOneOp on the region operations.
1376///
1377/// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1378/// performed to the maximal common vector size implied by the `linalgOp`
1379/// iteration space. This eager broadcasting is introduced in the
1380/// permutation_map of the vector.transfer_read operations. The eager
/// broadcasting makes it trivial to determine where broadcasts, transposes and
1382/// reductions should occur, without any bookkeeping. The tradeoff is that, in
1383/// the absence of good canonicalizations, the amount of work increases.
1384/// This is not deemed a problem as we expect canonicalizations and foldings to
1385/// aggressively clean up the useless work.
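///
/// As a rough illustration (hypothetical IR, with indexing maps and most
/// details omitted; the exact output depends on the canonical vector shape
/// and masking), an elementwise `linalg.generic` over tensor<8xf32> operands
/// whose body computes `arith.addf %in0, %in1` is rewritten approximately to:
///
///   %lhs = vector.transfer_read %in0[%c0], %pad : tensor<8xf32>, vector<8xf32>
///   %rhs = vector.transfer_read %in1[%c0], %pad : tensor<8xf32>, vector<8xf32>
///   %add = arith.addf %lhs, %rhs : vector<8xf32>
///   %res = vector.transfer_write %add, %init[%c0] : vector<8xf32>, tensor<8xf32>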
1386static LogicalResult
1387vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1388 LinalgOp linalgOp,
1389 SmallVectorImpl<Value> &newResults) {
1390 LDBG("Vectorizing operation as linalg generic\n");
1391 Block *block = linalgOp.getBlock();
1392
1393 // 2. Values defined above the region can only be broadcast for now. Make them
1394 // map to themselves.
1395 IRMapping bvm;
1396 SetVector<Value> valuesSet;
1397 mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1399
1400 if (linalgOp.getNumDpsInits() == 0)
1401 return failure();
1402
1403 // 3. Turn all BBArgs into vector.transfer_read / load.
1404 Location loc = linalgOp.getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1406 for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1407 BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1408 if (linalgOp.isScalar(opOperand)) {
1409 bvm.map(bbarg, opOperand->get());
1410 continue;
1411 }
1412
1413 // 3.a. Convert the indexing map for this input/output to a transfer read
1414 // permutation map and masking map.
1415 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1416
1417 AffineMap readMap;
1418 VectorType readType;
1419 Type elemType = getElementTypeOrSelf(opOperand->get());
1420 if (linalgOp.isDpsInput(opOperand)) {
1421 // 3.a.i. For input reads we use the canonical vector shape.
1422 readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1423 readType = state.getCanonicalVecType(elemType);
1424 } else {
1425 // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1426 // reductions), the vector shape is computed by mapping the canonical
1427 // vector shape to the output domain and back to the canonical domain.
1428 readMap = inversePermutation(reindexIndexingMap(indexingMap));
1429 readType =
1430 state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1431 }
1432
1433 SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1434
1435 Operation *read = rewriter.create<vector::TransferReadOp>(
1436 loc, readType, opOperand->get(), indices, readMap);
1437 read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1438 Value readValue = read->getResult(0);
1439
1440 // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1441 // will be in-bounds.
1442 if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1443 SmallVector<bool> inBounds(readType.getRank(), true);
1444 cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1445 .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1446 }
1447
1448 // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1449 // TODO: remove this.
1450 if (readType.getRank() == 0)
1451 readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
1452 ArrayRef<int64_t>());
1453
1454 LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
1455 << "\n");
1456 bvm.map(bbarg, readValue);
1457 bvm.map(opOperand->get(), readValue);
1458 }
1459
1460 SmallVector<CustomVectorizationHook> hooks;
1461 // 4a. Register CustomVectorizationHook for yieldOp.
1462 CustomVectorizationHook vectorizeYield =
1463 [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1464 return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1465 };
  hooks.push_back(vectorizeYield);
1467
1468 // 4b. Register CustomVectorizationHook for indexOp.
1469 CustomVectorizationHook vectorizeIndex =
1470 [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1471 return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1472 };
  hooks.push_back(vectorizeIndex);
1474
1475 // 4c. Register CustomVectorizationHook for extractOp.
1476 CustomVectorizationHook vectorizeExtract =
1477 [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1478 return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1479 };
  hooks.push_back(vectorizeExtract);
1481
  // 5. Iteratively call `vectorizeOneOp` on each op in the block.
1483 for (Operation &op : block->getOperations()) {
1484 VectorizationResult result =
1485 vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1486 if (result.status == VectorizationStatus::Failure) {
1487 LDBG("failed to vectorize: " << op << "\n");
1488 return failure();
1489 }
1490 if (result.status == VectorizationStatus::NewOp) {
1491 Operation *maybeMaskedOp =
1492 state.maskOperation(rewriter, result.newOp, linalgOp);
1493 LDBG("New vector op: " << *maybeMaskedOp << "\n");
1494 bvm.map(op.getResults(), maybeMaskedOp->getResults());
1495 }
1496 }
1497
1498 return success();
1499}
1500
1501/// Given a linalg::PackOp, return the `dest` shape before any packing
1502/// permutations.
1503static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
1504 ArrayRef<int64_t> destShape) {
1505 return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
1506}
1507
1508/// Determines whether a mask for xfer_write is trivially "all true"
1509///
1510/// Given all the inputs required to generate a mask (mask sizes and shapes),
1511/// and an xfer_write operation (write indices and the destination tensor
1512/// shape), determines whether the corresponding mask would be trivially
1513/// foldable (i.e., trivially "all true").
1514///
/// Use this method to avoid generating spurious masks and relying on
1516/// vectorization post-processing to remove them.
1517///
1518/// Pre-conditions for a mask to be trivially foldable:
1519/// * All involved shapes (mask + destination tensor) are static.
1520/// * All write indices are constant.
1521/// * All mask sizes are constant (including `arith.constant`).
1522///
1523/// If the pre-conditions are met, the method checks for each destination
1524/// dimension `d`:
1525/// (1) destDimSize[rankDiff + d] <= maskShape[d]
1526/// (2) destDimSize[rankDiff + d] <= writeIndex[d] + maskSize[d]
1527///
1528/// rankDiff = rank(dest) - rank(mask).
1529///
1530/// This method takes a conservative view: it may return false even if the mask
1531/// is technically foldable.
1532///
1533/// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the shape
1534/// of the dest tensor):
1535/// %c0 = arith.constant 0 : index
1536/// %mask = vector.create_mask 5, 1
1537/// vector.mask %mask {
///      vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
1539/// {in_bounds = [true, true]}
1540/// : vector<5x1xi32>, tensor<5x1xi32>
1541/// }
1542///
1543/// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape,
1544/// mask is required to avoid out-of-bounds write):
1545/// %c0 = arith.constant 0 : index
1546/// %mask = vector.create_mask 5, 1
1547/// vector.mask %mask {
1548/// vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
1549/// {in_bounds = [true, true]}
1550/// : vector<8x1xi32>, tensor<5x1xi32>
1551/// }
1552///
1553/// TODO: Re-use in createReadOrMaskedRead
1554static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
1555 SmallVector<Value> &writeIdxs,
1556 ArrayRef<int64_t> destShape,
1557 ArrayRef<int64_t> maskShape) {
1558 // Masking is unavoidable in the case of dynamic tensors.
1559 if (ShapedType::isDynamicShape(destShape))
1560 return false;
1561
1562 // Collect all constant mask sizes.
1563 SmallVector<int64_t, 4> cstMaskSizes;
  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
    if (auto intSize = getConstantIntValue(dimSize)) {
      cstMaskSizes.push_back(*intSize);
    }
  }
1569
1570 // If any of the mask sizes is non-constant, bail out.
1571 if (cstMaskSizes.size() != maskShape.size())
1572 return false;
1573
1574 // Collect all constant write indices.
1575 SmallVector<int64_t, 4> cstWriteIdxs;
  for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
    APSInt intVal;
    if (matchPattern(idx, m_ConstantInt(&intVal))) {
      cstWriteIdxs.push_back(intVal.getSExtValue());
    }
  }
1582
1583 // If any of the write indices is non-constant, bail out.
1584 if (cstWriteIdxs.size() != destShape.size())
1585 return false;
1586
1587 // Go over all destination dims and check (1) and (2). Take into account that:
1588 // * The number of mask sizes will match the rank of the vector to store.
1589 // This could be lower than the rank of the destination tensor.
1590 // * Mask sizes could be larger than the corresponding mask shape (hence
1591 // `clamp`).
1592 // TODO: The 2nd item should be rejected by the verifier.
1593 int64_t rankDiff = destShape.size() - cstMaskSizes.size();
  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
    if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
        /*(2)*/ destShape[rankDiff + i] <
                    (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
                     cstWriteIdxs[i]))
      return false;
  }
1601
1602 return true;
1603}
1604
1605/// Creates an optionally masked TransferWriteOp
1606///
1607/// Generates the following operation:
1608/// %res = vector.transfer_write %vecToStore into %dest
1609///
1610/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
1611///
1612/// %mask = vector.create_mask(%destShape) : %vecToStoreShape
1613/// %res = vector.mask %mask {
1614/// vector.transfer_write %vecToStore into %dest
1615/// }
1616///
1617/// The mask shape is identical to `vecToStore` (with the element type ==
1618/// i1), and the mask values are based on the shape of the `dest` tensor.
1619///
1620/// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
1621/// is used instead of masking:
1622///
///    in_bounds_flags = (...)
///    %res = vector.transfer_write %vecToStore into %dest
///        {in_bounds = in_bounds_flags}
1627///
1628/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
1629/// are set to 0.
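///
/// A minimal usage sketch (illustrative; `vec` and `dest` are hypothetical
/// values created elsewhere):
///
///   Value vec = ...;   // e.g. vector<4x8xf32>
///   Value dest = ...;  // e.g. tensor<?x8xf32>
///   Operation *write = createWriteOrMaskedWrite(builder, loc, vec, dest);
///   // With a dynamic leading dim in `dest`, the transfer_write is wrapped
///   // in a vector.mask whose sizes are taken from the destination tensor.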
1630static Operation *
1631createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
1632 Value dest, SmallVector<Value> writeIndices = {},
1633 bool useInBoundsInsteadOfMasking = false) {
1634
1635 ShapedType destType = cast<ShapedType>(dest.getType());
1636 int64_t destRank = destType.getRank();
1637 auto destShape = destType.getShape();
1638
1639 VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
1640 int64_t vecToStoreRank = vecToStoreType.getRank();
1641 auto vecToStoreShape = vecToStoreType.getShape();
1642
1643 // Compute the in_bounds attribute
1644 SmallVector<bool> inBoundsVal(vecToStoreRank, true);
1645 if (useInBoundsInsteadOfMasking) {
1646 // Update the inBounds attribute.
1647 // FIXME: This computation is too weak - it ignores the write indices.
1648 for (unsigned i = 0; i < vecToStoreRank; i++)
1649 inBoundsVal[i] =
1650 (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
1651 !ShapedType::isDynamic(destShape[destRank - vecToStoreRank + i]);
1652 }
1653
1654 // If missing, initialize the write indices to 0.
  assert((writeIndices.empty() ||
          writeIndices.size() == static_cast<size_t>(destRank)) &&
         "Invalid number of write indices!");
  if (writeIndices.empty()) {
    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
    writeIndices.assign(destRank, zero);
  }
1662
1663 // Generate the xfer_write Op
1664 Operation *write =
1665 builder.create<vector::TransferWriteOp>(loc,
1666 /*vector=*/vecToStore,
1667 /*source=*/dest,
1668 /*indices=*/writeIndices,
1669 /*inBounds=*/inBoundsVal);
1670
1671 // If masking is disabled, exit.
1672 if (useInBoundsInsteadOfMasking)
1673 return write;
1674
1675 // Check if masking is needed. If not, exit.
1676 if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
1677 return write;
1678
1679 // Compute the mask and mask the write Op.
1680 auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
1681
  SmallVector<OpFoldResult> destSizes =
      tensor::getMixedSizes(builder, loc, dest);
1684 SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
1685 destSizes.end());
1686
1687 if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
1688 vecToStoreShape))
1689 return write;
1690
1691 Value maskForWrite =
1692 builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
  return mlir::vector::maskOperation(builder, write, maskForWrite);
1694}
1695
1696/// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
1697/// padding value and (3) input vector sizes into:
1698///
1699/// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1700///
1701/// As in the following example:
1702/// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
///     into %dst : tensor<32x7x16xf32> -> tensor<32x4x1x16x2xf32>
1704///
1705/// This pack would be vectorized to:
1706///
1707/// %load = vector.mask %mask {
1708/// vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1709/// {in_bounds = [true, true, true]} :
1710/// tensor<32x7x16xf32>, vector<32x8x16xf32>
1711/// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1712/// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1713/// to vector<32x4x2x1x16xf32>
1714/// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1715/// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1716/// %write = vector.transfer_write %transpose,
1717/// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1718/// {in_bounds = [true, true, true, true, true]}
1719/// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1720///
1721/// If the (3) input vector sizes are not provided, the vector sizes are
1722/// determined by the result tensor shape and the `in_bounds`
1723/// attribute is used instead of masking to mark out-of-bounds accesses.
1724///
1725/// NOTE: The input vector sizes specify the dimensions corresponding to the
1726/// outer dimensions of the output tensor. The remaining dimensions are
1727/// computed based on, e.g., the static inner tiles.
1728/// Supporting dynamic inner tiles will require the user to specify the
1729/// missing vector sizes. This is left as a TODO.
1730static LogicalResult
1731vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
1732 ArrayRef<int64_t> inputVectorSizes,
1733 SmallVectorImpl<Value> &newResults) {
1734 // TODO: Introduce a parent class that will handle the insertion point update.
1735 OpBuilder::InsertionGuard g(rewriter);
1736 rewriter.setInsertionPoint(packOp);
1737
1738 Location loc = packOp.getLoc();
1739 auto padValue = packOp.getPaddingValue();
1740 if (!padValue) {
1741 padValue = rewriter.create<arith::ConstantOp>(
1742 loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
1743 }
1744 ReifiedRankedShapedTypeDims reifiedReturnShapes;
1745 LogicalResult status =
1746 cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1747 .reifyResultShapes(rewriter, reifiedReturnShapes);
1748 (void)status; // prevent unused variable warning on non-assert builds.
1749 assert(succeeded(status) && "failed to reify result shapes");
1750
1751 // If the input vector sizes are not provided, then the vector sizes are
1752 // determined by the result tensor shape. In case the vector sizes aren't
1753 // provided, we update the inBounds attribute instead of masking.
1754 bool useInBoundsInsteadOfMasking = false;
1755 if (inputVectorSizes.empty()) {
1756 ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1758 useInBoundsInsteadOfMasking = true;
1759 }
1760
1761 // Create masked TransferReadOp.
1762 SmallVector<int64_t> inputShape(inputVectorSizes);
1763 auto innerTiles = packOp.getStaticInnerTiles();
1764 auto innerDimsPos = packOp.getInnerDimsPos();
1765 auto outerDimsPerm = packOp.getOuterDimsPerm();
1766 if (!outerDimsPerm.empty())
1767 applyPermutationToVector(inputShape,
1768 invertPermutationVector(outerDimsPerm));
1769 for (auto [idx, size] : enumerate(innerTiles))
1770 inputShape[innerDimsPos[idx]] *= size;
  auto maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, packOp.getSource(), inputShape, padValue,
      useInBoundsInsteadOfMasking);
1774
1775 // Create ShapeCastOp.
1776 SmallVector<int64_t> destShape(inputVectorSizes);
1777 destShape.append(innerTiles.begin(), innerTiles.end());
1778 auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
1779 packOp.getDestType().getElementType());
1780 auto shapeCastOp =
1781 rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
1782
1783 // Create TransposeOp.
1784 auto destPermutation =
1785 invertPermutationVector(getPackInverseDestPerm(packOp));
1786 auto transposeOp = rewriter.create<vector::TransposeOp>(
1787 loc, shapeCastOp.getResult(), destPermutation);
1788
1789 // Create TransferWriteOp.
1790 Value dest = rewriter.create<tensor::EmptyOp>(
1791 loc, reifiedReturnShapes[0],
1792 transposeOp.getResult().getType().getElementType());
1793 Operation *write =
1794 createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest);
  newResults.push_back(write->getResult(0));
1796 return success();
1797}
1798
/// Vectorize a `linalg::UnPackOp` to these 4 Ops:
///   vector::TransferReadOp - Reads a vector from the source tensor
///   vector::TransposeOp - Transposes the read vector
///   vector::ShapeCastOp - Reshapes the data based on the target
///   vector::TransferWriteOp - Writes the result vector back to the
///                             destination tensor.
/// If the vector sizes are not provided:
///   * the vector sizes are determined by the input operand and attributes,
///   * update the inBounds attribute instead of masking.
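///
/// For example (an illustrative sketch; indices, padding and masking are
/// omitted), unpacking tensor<2x4x16x8xf32> with inner_dims_pos = [0, 1] and
/// inner_tiles = [16, 8] into tensor<32x32xf32> becomes approximately:
///
///   %read  = vector.transfer_read %src[...]
///            : tensor<2x4x16x8xf32>, vector<2x4x16x8xf32>
///   %tr    = vector.transpose %read, [0, 2, 1, 3]
///            : vector<2x4x16x8xf32> to vector<2x16x4x8xf32>
///   %sc    = vector.shape_cast %tr : vector<2x16x4x8xf32> to vector<32x32xf32>
///   %write = vector.transfer_write %sc, %dest[...]
///            : vector<32x32xf32>, tensor<32x32xf32>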
1808static LogicalResult
1809vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1810 ArrayRef<int64_t> inputVectorSizes,
1811 SmallVectorImpl<Value> &newResults) {
1812
1813 // TODO: Introduce a parent class that will handle the insertion point update.
1814 OpBuilder::InsertionGuard g(rewriter);
1815 rewriter.setInsertionPoint(unpackOp);
1816
1817 RankedTensorType unpackTensorType = unpackOp.getSourceType();
1818
1819 ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
1820 ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
1821 ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1822 bool useInBoundsInsteadOfMasking = false;
1823 ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1824
1825 auto destSize = unpackOp.getDestRank();
1826
1827 if (!inputVectorSizes.empty())
1828 assert(inputVectorSizes.size() == destSize &&
1829 "Incorrect number of input vector sizes");
1830
1831 // vectorSizes is the shape of the vector that will be used to do final
1832 // write on the destination tensor. It is set like this: Let's say the
1833 // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
1834 // Thus:
1835 // 1. vectorSizes = sourceShape.take_front(N)
1836 // 2. if outer_dims_perms is present: do that permutation on vectorSizes.
1837 // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
1838 // innerTiles attribute value.
1839 SmallVector<int64_t> vectorSizes(inputVectorSizes);
1840 if (vectorSizes.empty()) {
    llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
    if (!outerDimsPerm.empty())
      applyPermutationToVector(vectorSizes, outerDimsPerm);
1844 for (auto [i, pos] : llvm::enumerate(innerDimPos))
1845 vectorSizes[pos] *= innerTiles[i];
1846
1847 useInBoundsInsteadOfMasking = true;
1848 }
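
  // For example (illustrative numbers only): with sourceShape = <8x8x32x16>,
  // destSize = 2, outer_dims_perm = [1, 0], inner_dims_pos = [0, 1],
  // inner_tiles = [32, 16] and no user-provided sizes:
  //   vectorSizes (initial)            = [8, 8]
  //   after applying outer_dims_perm   = [8, 8]
  //   after multiplying by inner_tiles = [8 * 32, 8 * 16] = [256, 128]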
1849
  // readVectorSizes is the shape of the vector used to read the source tensor
  // and to build the read mask. It is computed as follows. Let the vectorSizes
  // (VS) array have size 'N', the sourceShape (SS) have size 'M' with M >= N,
  // and let the InnerTileSizes (IT) have size M-N.
  // Thus:
  // - initially: readVectorSizes = vectorSizes
  // - divide the readVectorSizes entries pointed to by innerDimPos by the
  //   corresponding innerTiles value,
  // - if outer_dims_perm is present: apply that permutation to readVectorSizes,
  // - append the remaining shape from SS.
  // E.g. let's say unpackTensorType.getShape() = <8x8x32x16>,
  // inner_dims_pos = [0, 1], inner_tiles = [32, 16], vector_sizes = [512, 128]
  // and outer_dims_perm = [1, 0]. Then the read shape is:
1863 // ReadVectorSizes(initial): [512, 128]
1864 // Final Value(after innerDim Adjustment): [512/32, 128/16]
1865 // = [16, 8]
1866 // After applying outer_dims_perm: [8, 16]
1867 // After appending the rest of the sourceShape: [8, 16, 32, 16]
1868
1869 SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
1870
1871 for (auto [index, size] : enumerate(innerTiles)) {
1872 readVectorSizes[innerDimPos[index]] =
1873 llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
1874 }
  if (!outerDimsPerm.empty()) {
    applyPermutationToVector(readVectorSizes, outerDimsPerm);
  }
  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
                         sourceShape.end());
1880
1881 ReifiedRankedShapedTypeDims reifiedRetShapes;
1882 LogicalResult status =
1883 cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
1884 .reifyResultShapes(rewriter, reifiedRetShapes);
1885 if (status.failed()) {
1886 LDBG("Unable to reify result shapes of " << unpackOp);
1887 return failure();
1888 }
1889 Location loc = unpackOp->getLoc();
1890
1891 auto padValue = rewriter.create<arith::ConstantOp>(
1892 loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1893
1894 // Read result, mask if necessary. If transferReadOp shape is not equal
1895 // to shape of source, then a mask is necessary.
  Value readResult = vector::createReadOrMaskedRead(
      rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
      /*useInBoundsInsteadOfMasking=*/false);
1899
1900 PackingMetadata packMetadata;
1901 SmallVector<int64_t> lastDimToInsertPosPerm =
1902 getUnPackInverseSrcPerm(unpackOp, packMetadata);
1903 ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
1904 SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
1905 mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
1907 RankedTensorType stripMineTensorType =
1908 RankedTensorType::get(stripMineShape, stripMineElemType);
1909 // Transpose the appropriate rows to match output.
1910 vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
1911 loc, readResult, lastDimToInsertPosPerm);
1912
1913 // Collapse the vector to the size required by result.
1914 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1915 stripMineTensorType, packMetadata.reassociations);
1916 mlir::VectorType vecCollapsedType =
1917 VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1918 vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
1919 loc, vecCollapsedType, transposeOp->getResult(0));
1920
  // For dynamic sizes, writeVectorSizes has to match the shape_cast result
  // shape; otherwise the verifier complains that the mask size is invalid.
1923 SmallVector<int64_t> writeVectorSizes(
1924 unpackOp.getDestType().hasStaticShape()
1925 ? vectorSizes
1926 : shapeCastOp.getResultVectorType().getShape());
1927 Value dest = rewriter.create<tensor::EmptyOp>(
1928 loc, reifiedRetShapes[0],
1929 shapeCastOp.getResult().getType().getElementType());
1930 Operation *write = createWriteOrMaskedWrite(
1931 rewriter, loc, shapeCastOp.getResult(), dest,
1932 /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
  newResults.push_back(write->getResult(0));
1934 return success();
1935}
1936
1937/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
1938/// and (3) all-zero lowPad to
1939/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
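///
/// For example (an illustrative sketch; indices, the pad region and the mask
/// computation are abbreviated):
///
///   %pad = tensor.pad %src low[0, 0] high[2, 3] { ... tensor.yield %cst : f32 }
///          : tensor<6x5xf32> to tensor<8x8xf32>
///
/// becomes approximately:
///
///   %read  = vector.mask %m {
///              vector.transfer_read %src[...], %cst
///            } : tensor<6x5xf32>, vector<8x8xf32>
///   %write = vector.transfer_write %read, %init[...]
///            {in_bounds = [true, true]} : vector<8x8xf32>, tensor<8x8xf32>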
1940static LogicalResult
1941vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1942 ArrayRef<int64_t> inputVectorSizes,
1943 SmallVectorImpl<Value> &newResults) {
1944 auto padValue = padOp.getConstantPaddingValue();
1945 Location loc = padOp.getLoc();
1946
1947 // TODO: Introduce a parent class that will handle the insertion point update.
1948 OpBuilder::InsertionGuard g(rewriter);
1949 rewriter.setInsertionPoint(padOp);
1950
1951 ReifiedRankedShapedTypeDims reifiedReturnShapes;
1952 LogicalResult status =
1953 cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1954 .reifyResultShapes(rewriter, reifiedReturnShapes);
1955 (void)status; // prevent unused variable warning on non-assert builds
1956 assert(succeeded(status) && "failed to reify result shapes");
  auto maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
      /*useInBoundsInsteadOfMasking=*/false);
1960
1961 // Create Xfer write Op
1962 Value dest = rewriter.create<tensor::EmptyOp>(
1963 loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
1964 Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
  newResults.push_back(write->getResult(0));
1966 return success();
1967}
1968
1969// TODO: probably need some extra checks for reduction followed by consumer
1970// ops that may not commute (e.g. linear reduction + non-linear instructions).
1971static LogicalResult reductionPreconditions(LinalgOp op) {
1972 if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
1973 LDBG("reduction precondition failed: no reduction iterator\n");
1974 return failure();
1975 }
1976 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1977 AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1978 if (indexingMap.isPermutation())
1979 continue;
1980
1981 Operation *reduceOp = matchLinalgReduction(&opOperand);
1982 if (!reduceOp || !getCombinerOpKind(reduceOp)) {
1983 LDBG("reduction precondition failed: reduction detection failed\n");
1984 return failure();
1985 }
1986 }
1987 return success();
1988}
1989
1990static LogicalResult
1991vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
1992 bool flatten1DDepthwiseConv) {
1993 if (flatten1DDepthwiseConv) {
1994 LDBG("Vectorization of flattened convs with dynamic shapes is not "
1995 "supported\n");
1996 return failure();
1997 }
1998
1999 if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
2000 LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
2001 return failure();
2002 }
2003
2004 // Support dynamic shapes in 1D depthwise convolution, but only in the
2005 // _channel_ dimension.
2006 Value lhs = conv.getDpsInputOperand(0)->get();
2007 ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
  auto shapeWithoutCh = lhsShape.drop_back(1);
2009 if (ShapedType::isDynamicShape(shapeWithoutCh)) {
2010 LDBG("Dynamically-shaped op vectorization precondition failed: only "
2011 "channel dim can be dynamic\n");
2012 return failure();
2013 }
2014
2015 return success();
2016}
2017
2018static LogicalResult
2019vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
2020 bool flatten1DDepthwiseConv) {
2021 if (isa<ConvolutionOpInterface>(op.getOperation()))
2022 return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
2023
2024 if (hasReductionIterator(op))
2025 return reductionPreconditions(op);
2026
2027 // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
2028 // linalg.copy ops and ops that implement ContractionOpInterface for now.
2029 if (!isElementwise(op) &&
2030 !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
2031 op.getOperation()))
2032 return failure();
2033
2034 LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
2035 return success();
2036}
2037
2038/// Need to check if the inner-tiles are static/constant.
2039static LogicalResult
2040vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
2041 ArrayRef<int64_t> inputVectorSizes) {
2042
2043 if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
        return !getConstantIntValue(res).has_value();
2045 })) {
2046 LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
2047 return failure();
2048 }
2049 ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
2050 bool satisfyEmptyCond = inputVectorSizes.empty() &&
2051 unpackOp.getDestType().hasStaticShape() &&
2052 unpackOp.getSourceType().hasStaticShape();
  if (!satisfyEmptyCond &&
      failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
2055 return failure();
2056
2057 return success();
2058}
2059
2060static LogicalResult
2061vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
2062 ArrayRef<int64_t> inputVectorSizes) {
2063
2064 TypedValue<RankedTensorType> source = sliceOp.getSource();
2065 auto sourceType = source.getType();
2066 if (!VectorType::isValidElementType(sourceType.getElementType()))
2067 return failure();
2068
2069 // Get the pad value.
2070 // TransferReadOp (which is used to vectorize InsertSliceOp), requires a
2071 // scalar padding value. Note that:
2072 // * for in-bounds accesses,
2073 // the value is actually irrelevant. There are 2 cases in which xfer.read
2074 // accesses are known to be in-bounds:
2075 // 1. The source shape is static (output vector sizes would be based on
2076 // the source shape and hence all memory accesses would be in-bounds),
2077 // 2. Masking is used, i.e. the output vector sizes are user-provided. In
2078 // this case it is safe to assume that all memory accesses are in-bounds.
2079 //
2080 // When the value is not known and not needed, use 0. Otherwise, bail out.
2081 Value padValue = getStaticPadVal(sliceOp);
2082 bool isOutOfBoundsRead =
2083 !sourceType.hasStaticShape() && inputVectorSizes.empty();
2084
2085 if (!padValue && isOutOfBoundsRead) {
2086 LDBG("Failed to get a pad value for out-of-bounds read access\n");
2087 return failure();
2088 }
2089 return success();
2090}
2091
2092namespace {
2093enum class ConvOperationKind { Conv, Pool };
2094} // namespace
2095
2096static bool isCastOfBlockArgument(Operation *op) {
2097 return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2098 isa<BlockArgument>(op->getOperand(0));
2099}
2100
2101// Returns the ConvOperationKind of the op using reduceOp of the generic
2102// payload. If it is neither a convolution nor a pooling, it returns
2103// std::nullopt.
2104//
2105// If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
2106// + yield) and rhs is not used) then it is the body of a pooling
2107// If conv, check for single `mul` predecessor. The `mul` operands must be
2108// block arguments or extension of block arguments.
2109// Otherwise, check for one or zero `ext` predecessor. The `ext` operands
2110// must be block arguments or extension of block arguments.
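//
// For instance (illustrative): a convolution body reduces with
//   %m = arith.mulf %in, %filter : f32
//   %r = arith.addf %acc, %m : f32
// (one block-argument operand on the reduction, fed by a MulOp), whereas a
// pooling body reduces with
//   %r = arith.maximumf %acc, %in : f32
// (two block-argument operands on the reduction).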
2111static std::optional<ConvOperationKind>
2112getConvOperationKind(Operation *reduceOp) {
  int numBlockArguments =
      llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
2115
2116 switch (numBlockArguments) {
2117 case 1: {
2118 // Will be convolution if feeder is a MulOp.
2119 // A strength reduced version of MulOp for i1 type is AndOp which is also
2120 // supported. Otherwise, it can be pooling. This strength reduction logic
2121 // is in `buildBinaryFn` helper in the Linalg dialect.
    auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                       llvm::IsaPred<BlockArgument>);
2124 assert(feedValIt != reduceOp->operand_end() &&
2125 "Expected a non-block argument operand");
2126 Operation *feedOp = (*feedValIt).getDefiningOp();
    if (isCastOfBlockArgument(feedOp)) {
2128 return ConvOperationKind::Pool;
2129 }
2130
2131 if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
2132 (isa<arith::AndIOp>(feedOp) &&
2133 feedOp->getResultTypes()[0].isInteger(1))) &&
2134 llvm::all_of(feedOp->getOperands(), [](Value v) {
2135 if (isa<BlockArgument>(v))
2136 return true;
2137 if (Operation *op = v.getDefiningOp())
2138 return isCastOfBlockArgument(op);
2139 return false;
2140 }))) {
2141 return std::nullopt;
2142 }
2143
2144 return ConvOperationKind::Conv;
2145 }
2146 case 2:
2147 // Must be pooling
2148 return ConvOperationKind::Pool;
2149 default:
2150 return std::nullopt;
2151 }
2152}
2153
2154static bool isSupportedPoolKind(vector::CombiningKind kind) {
2155 switch (kind) {
2156 case vector::CombiningKind::ADD:
2157 case vector::CombiningKind::MAXNUMF:
2158 case vector::CombiningKind::MAXIMUMF:
2159 case vector::CombiningKind::MAXSI:
2160 case vector::CombiningKind::MAXUI:
2161 case vector::CombiningKind::MINNUMF:
2162 case vector::CombiningKind::MINIMUMF:
2163 case vector::CombiningKind::MINSI:
2164 case vector::CombiningKind::MINUI:
2165 return true;
2166 default:
2167 return false;
2168 }
2169}
2170
2171static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
2172 auto getOperandType = [&](auto operand) {
2173 return dyn_cast<ShapedType>((operand->get()).getType());
2174 };
2175 ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
2176 ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
2177 ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
2178 // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2179 // (non-channeled convolution -> LHS and RHS both have single dimensions).
2180 // Note that this also ensures 2D and 3D convolutions are rejected.
2181 if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2182 (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2183 return failure();
2184
2185 Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
2186 if (!reduceOp)
2187 return failure();
2188
2189 auto maybeOper = getConvOperationKind(reduceOp);
2190 if (!maybeOper.has_value())
2191 return failure();
2192
2193 auto maybeKind = getCombinerOpKind(reduceOp);
2194 // Typically convolution will have a `Add` CombiningKind but for i1 type it
2195 // can get strength reduced to `OR` which is also supported. This strength
2196 // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
2197 if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
2198 *maybeKind != vector::CombiningKind::OR) &&
2199 (*maybeOper != ConvOperationKind::Pool ||
2200 !isSupportedPoolKind(*maybeKind)))) {
2201 return failure();
2202 }
2203
2204 auto rhsRank = rhsShapedType.getRank();
2205 if (*maybeOper == ConvOperationKind::Pool) {
2206 if (rhsRank != 1)
2207 return failure();
2208 } else {
2209 if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2210 return failure();
2211 }
2212
2213 return success();
2214}
2215
2216static LogicalResult vectorizeLinalgOpPrecondition(
2217 LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
2218 bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
2219 // tensor with dimension of 0 cannot be vectorized.
2220 if (llvm::is_contained(linalgOp.getStaticShape(), 0))
2221 return failure();
2222 // Check API contract for input vector sizes.
  if (!inputVectorSizes.empty() &&
      failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
                                              inputVectorSizes)))
2226 return failure();
2227
2228 if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
2229 linalgOp, flatten1DDepthwiseConv))) {
2230 LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
2231 return failure();
2232 }
2233
2234 SmallVector<CustomVectorizationPrecondition> customPreconditions;
2235
2236 // Register CustomVectorizationPrecondition for extractOp.
  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
2238
2239 // All types in the body should be a supported element type for VectorType.
2240 for (Operation &innerOp : linalgOp->getRegion(0).front()) {
2241 // Check if any custom hook can vectorize the inner op.
2242 if (llvm::any_of(
2243 customPreconditions,
2244 [&](const CustomVectorizationPrecondition &customPrecondition) {
2245 return succeeded(
2246 customPrecondition(&innerOp, vectorizeNDExtract));
2247 })) {
2248 continue;
2249 }
2250 if (!llvm::all_of(innerOp.getOperandTypes(),
2251 VectorType::isValidElementType)) {
2252 return failure();
2253 }
2254 if (!llvm::all_of(innerOp.getResultTypes(),
2255 VectorType::isValidElementType)) {
2256 return failure();
2257 }
2258 }
2259 if (isElementwise(linalgOp))
2260 return success();
2261
2262 // TODO: isaConvolutionOpInterface that can also infer from generic
2263 // features. But we will still need stride/dilation attributes that will be
2264 // annoying to reverse-engineer...
2265 if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
2266 return vectorizeConvOpPrecondition(linalgOp);
2267
2268 // TODO: the common vector shape is equal to the static loop sizes only when
2269 // all indexing maps are projected permutations. For convs and stencils the
2270 // logic will need to evolve.
2271 if (!allIndexingsAreProjectedPermutation(linalgOp)) {
2272 LDBG("precondition failed: not projected permutations\n");
2273 return failure();
2274 }
2275 if (failed(reductionPreconditions(linalgOp))) {
2276 LDBG("precondition failed: reduction preconditions\n");
2277 return failure();
2278 }
2279 return success();
2280}
2281
2282static LogicalResult
2283vectorizePackOpPrecondition(linalg::PackOp packOp,
2284 ArrayRef<int64_t> inputVectorSizes) {
2285 auto padValue = packOp.getPaddingValue();
2286 Attribute cstAttr;
  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
2288 LDBG("pad value is not constant: " << packOp << "\n");
2289 return failure();
2290 }
2291 ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
2292 bool satisfyEmptyCond = true;
2293 if (inputVectorSizes.empty()) {
2294 if (!packOp.getDestType().hasStaticShape() ||
2295 !packOp.getSourceType().hasStaticShape())
2296 satisfyEmptyCond = false;
2297 }
2298
  if (!satisfyEmptyCond &&
      failed(vector::isValidMaskedInputVector(
          resultTensorShape.take_front(packOp.getSourceRank()),
          inputVectorSizes)))
2303 return failure();
2304
2305 if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
        return !getConstantIntValue(v).has_value();
2307 })) {
2308 LDBG("inner_tiles must be constant: " << packOp << "\n");
2309 return failure();
2310 }
2311
2312 return success();
2313}
2314
2315static LogicalResult
2316vectorizePadOpPrecondition(tensor::PadOp padOp,
2317 ArrayRef<int64_t> inputVectorSizes) {
2318 auto padValue = padOp.getConstantPaddingValue();
2319 if (!padValue) {
2320 LDBG("pad value is not constant: " << padOp << "\n");
2321 return failure();
2322 }
2323
2324 ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
                                              inputVectorSizes)))
2327 return failure();
2328
  // Padding with non-zero low pad values is not supported, unless the
  // corresponding result dim is 1, as supporting it would require shifting the
  // results to the right for the low padded dims by the amount of low padding.
  // However, we do support low padding if the dims being low padded have
  // result sizes of 1. The reason is that when we have a low pad on a unit
  // result dim, the input size of that dimension will be dynamically zero (as
  // the sum of the low pad and input dim size has to be one), and hence we
  // will create a zero mask, as the lowering logic just makes the mask one for
  // the input dim size - which is zero here. Hence we will load the pad value,
  // which is what we want in this case. If the low pad is dynamically zero,
  // then the lowering is correct as well, as no shifts are necessary.
2340 if (llvm::any_of(llvm::enumerate(padOp.getLow()), [&](const auto &en) {
2341 Value padValue = en.value();
2342 unsigned pos = en.index();
        std::optional<int64_t> pad = getConstantIntValue(padValue);
2344 return (!pad.has_value() || pad.value() != 0) &&
2345 resultTensorShape[pos] != 1;
2346 })) {
2347 LDBG("low pad must all be zero for all non unit dims: " << padOp << "\n");
2348 return failure();
2349 }
2350
2351 return success();
2352}
2353
2354/// Preconditions for scalable vectors. This is quite restrictive - it models
2355/// the fact that in practice we would only make selected dimensions scalable.
2356static LogicalResult
2357vectorizeScalableVectorPrecondition(Operation *op,
2358 ArrayRef<int64_t> inputVectorSizes,
2359 ArrayRef<bool> inputScalableVecDims) {
2360 assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
2361 "Number of input vector sizes and scalable dims doesn't match");
2362
2363 size_t numOfScalableDims =
      llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
2365
2366 if (numOfScalableDims == 0)
2367 return success();
2368
2369 auto linalgOp = dyn_cast<LinalgOp>(op);
2370
2371 // Cond 1: There's been no need for scalable vectorisation of
2372 // non-linalg Ops so far
2373 if (!linalgOp)
2374 return failure();
2375
2376 // Cond 2: There's been no need for more than 2 scalable dims so far
2377 if (numOfScalableDims > 2)
2378 return failure();
2379
2380 // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
2381 // it matches one of the supported cases:
2382 // 1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
2383 // (*).
2384 // 2. Exactly 2 dims are scalable and those are the _last two adjacent_
2385 // parallel dims.
2386 // 3. Exactly 1 reduction dim is scalable and that's the last (innermost)
2387 // dim.
2388 // The 2nd restriction above means that only Matmul-like Ops are supported
2389 // when 2 dims are scalable, e.g. :
2390 // * iterators = [parallel, parallel, reduction]
2391 // * scalable flags = [true, true, false]
2392 //
2393 // (*) Non-unit dims get folded away in practice.
2394 // TODO: Relax these conditions as good motivating examples are identified.
2395
2396 // Find the first scalable flag.
2397 bool seenNonUnitParallel = false;
2398 auto iterators = linalgOp.getIteratorTypesArray();
2399 SmallVector<bool> scalableFlags(inputScalableVecDims);
2400 int64_t idx = scalableFlags.size() - 1;
2401 while (!scalableFlags[idx]) {
2402 bool isNonUnitDim = (inputVectorSizes[idx] != 1);
2403 seenNonUnitParallel |=
2404 (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
2405
2406 iterators.pop_back();
2407 scalableFlags.pop_back();
2408 --idx;
2409 }
2410
2411 // Analyze the iterator corresponding to the first scalable dim.
2412 switch (iterators.back()) {
2413 case utils::IteratorType::reduction: {
2414 // Check 3. above is met.
2415 if (iterators.size() != inputVectorSizes.size()) {
2416 LDBG("Non-trailing reduction dim requested for scalable "
2417 "vectorization\n");
2418 return failure();
2419 }
2420 if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
2421 LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
2422 "is not supported\n");
2423 return failure();
2424 }
2425 break;
2426 }
2427 case utils::IteratorType::parallel: {
2428 // Check 1. and 2. above are met.
2429 if (seenNonUnitParallel) {
2430 LDBG("Inner parallel dim not requested for scalable "
2431 "vectorization\n");
2432 return failure();
2433 }
2434 break;
2435 }
2436 }
2437
  // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
  // supported, for which we expect the following config:
2440 // * iterators = [parallel, parallel, reduction]
2441 // * scalable flags = [true, true, false]
2442 if (numOfScalableDims == 2) {
2443 // Disallow below case which breaks 3. above:
2444 // * iterators = [..., parallel, reduction]
2445 // * scalable flags = [..., true, true]
2446 if (iterators.back() == utils::IteratorType::reduction) {
2447 LDBG("Higher dim than the trailing reduction dim requested for scalable "
2448 "vectorization\n");
2449 return failure();
2450 }
2451 scalableFlags.pop_back();
2452 iterators.pop_back();
2453
2454 if (!scalableFlags.back() ||
2455 (iterators.back() != utils::IteratorType::parallel))
2456 return failure();
2457 }
2458
  // Do not let matmul ops with extended semantics (i.e. user-defined indexing
  // maps) through this transform; they are not supported.
2461 if (linalgOp.hasUserDefinedMaps())
2462 return failure();
2463
2464 // Cond 4: Only the following ops are supported in the
2465 // presence of scalable vectors
2466 return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2467 isa<linalg::MatmulTransposeAOp>(op) ||
2468 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2469 isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
2470}
2471
2472LogicalResult mlir::linalg::vectorizeOpPrecondition(
2473 Operation *op, ArrayRef<int64_t> inputVectorSizes,
2474 ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2475 bool flatten1DDepthwiseConv) {
2476
2477 if (!hasVectorizationImpl(op))
2478 return failure();
2479
  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2481 inputScalableVecDims)))
2482 return failure();
2483
2484 return TypeSwitch<Operation *, LogicalResult>(op)
2485 .Case<linalg::LinalgOp>([&](auto linalgOp) {
2486 return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
2487 vectorizeNDExtract,
2488 flatten1DDepthwiseConv);
2489 })
2490 .Case<tensor::PadOp>([&](auto padOp) {
2491 return vectorizePadOpPrecondition(padOp, inputVectorSizes);
2492 })
2493 .Case<linalg::PackOp>([&](auto packOp) {
2494 return vectorizePackOpPrecondition(packOp, inputVectorSizes);
2495 })
2496 .Case<linalg::UnPackOp>([&](auto unpackOp) {
2497 return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
2498 })
2499 .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2500 return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
2501 })
2502 .Default([](auto) { return failure(); });
2503}
2504
2505/// Converts affine.apply Ops to arithmetic operations.
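///
/// For example (illustrative), a payload op such as:
///   %i = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%a)[%b]
/// is expanded to:
///   %i = arith.addi %a, %b : index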
2506static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2507 OpBuilder::InsertionGuard g(rewriter);
2508 auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2509
2510 for (auto op : make_early_inc_range(toReplace)) {
2511 rewriter.setInsertionPoint(op);
2512 auto expanded = affine::expandAffineExpr(
2513 rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2514 op.getOperands().take_front(op.getAffineMap().getNumDims()),
2515 op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2516 rewriter.replaceOp(op, expanded);
2517 }
2518}
2519
2520bool mlir::linalg::hasVectorizationImpl(Operation *op) {
2521 return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
2522 tensor::InsertSliceOp>(op);
2523}
2524
2525/// Emit a suitable vector form for an operation. If provided,
2526/// `inputVectorSizes` are used to vectorize this operation.
2527/// `inputVectorSizes` must match the rank of the iteration space of the
2528/// operation and the input vector sizes must be greater than or equal to
2529/// their counterpart iteration space sizes, if static. `inputVectorShapes`
2530/// also allows the vectorization of operations with dynamic shapes.
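///
/// A minimal usage sketch (illustrative; in-tree this entry point is usually
/// driven by the transform dialect or a pass, and `linalgOp` is a hypothetical
/// op with a 2-D iteration space):
///
///   IRRewriter rewriter(linalgOp->getContext());
///   if (failed(linalg::vectorize(rewriter, linalgOp,
///                                /*inputVectorSizes=*/{8, 16},
///                                /*inputScalableVecDims=*/{false, false})))
///     return failure();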
2531LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2532 ArrayRef<int64_t> inputVectorSizes,
2533 ArrayRef<bool> inputScalableVecDims,
2534 bool vectorizeNDExtract,
2535 bool flatten1DDepthwiseConv) {
2536 LDBG("Attempting to vectorize:\n" << *op << "\n");
2537 LDBG("Input vector sizes: ");
2538 LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
2539 LLVM_DEBUG(llvm::dbgs() << "\n");
2540 LDBG("Input scalable vector dims: ");
2541 LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
2542 LLVM_DEBUG(llvm::dbgs() << "\n");
2543
  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
                                     vectorizeNDExtract,
                                     flatten1DDepthwiseConv))) {
2547 LDBG("Vectorization pre-conditions failed\n");
2548 return failure();
2549 }
2550
2551 // Initialize vectorization state.
2552 VectorizationState state(rewriter);
2553 if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
    if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
                               inputScalableVecDims))) {
2556 LDBG("Vectorization state couldn't be initialized\n");
2557 return failure();
2558 }
2559 }
2560
2561 SmallVector<Value> results;
2562 auto vectorizeResult =
2563 TypeSwitch<Operation *, LogicalResult>(op)
2564 .Case<linalg::LinalgOp>([&](auto linalgOp) {
2565 // TODO: isaConvolutionOpInterface that can also infer from
2566 // generic features. Will require stride/dilation attributes
2567 // inference.
2568 if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2569 FailureOr<Operation *> convOr = vectorizeConvolution(
2570 rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2571 flatten1DDepthwiseConv);
2572 if (succeeded(convOr)) {
2573 llvm::append_range(results, (*convOr)->getResults());
2574 return success();
2575 }
2576
2577 LDBG("Unsupported convolution can't be vectorized.\n");
2578 return failure();
2579 }
2580
2581 LDBG("Vectorize generic by broadcasting to the canonical vector "
2582 "shape\n");
2583
2584 // Pre-process before proceeding.
2585 convertAffineApply(rewriter, linalgOp);
2586
2587 // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2588 // to 'OpBuilder' when it is passed over to some methods like
2589 // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2590 // erase an op within these methods, the actual rewriter won't be
2591 // notified and we will end up with read-after-free issues!
2592 return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2593 })
2594 .Case<tensor::PadOp>([&](auto padOp) {
2595 return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2596 results);
2597 })
2598 .Case<linalg::PackOp>([&](auto packOp) {
2599 return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2600 results);
2601 })
2602 .Case<linalg::UnPackOp>([&](auto unpackOp) {
2603 return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2604 inputVectorSizes, results);
2605 })
2606 .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2607 return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
2608 results);
2609 })
2610 .Default([](auto) { return failure(); });
2611
2612 if (failed(vectorizeResult)) {
2613 LDBG("Vectorization failed\n");
2614 return failure();
2615 }
2616
2617 if (!results.empty())
    rewriter.replaceOp(op, results);
2619 else
2620 rewriter.eraseOp(op);
2621
2622 return success();
2623}
2624
2625LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2626 memref::CopyOp copyOp) {
2627 auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2628 auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2629 if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2630 return failure();
2631
2632 auto srcElementType = getElementTypeOrSelf(srcType);
2633 auto dstElementType = getElementTypeOrSelf(dstType);
2634 if (!VectorType::isValidElementType(srcElementType) ||
2635 !VectorType::isValidElementType(dstElementType))
2636 return failure();
2637
2638 auto readType = VectorType::get(srcType.getShape(), srcElementType);
2639 auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2640
2641 Location loc = copyOp->getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2643 SmallVector<Value> indices(srcType.getRank(), zero);
2644
2645 Value readValue = rewriter.create<vector::TransferReadOp>(
2646 loc, readType, copyOp.getSource(), indices,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
2648 if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2649 readValue =
2650 rewriter.create<vector::ExtractOp>(loc, readValue, ArrayRef<int64_t>());
2651 readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
2652 }
2653 Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
2654 loc, readValue, copyOp.getTarget(), indices,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
2656 rewriter.replaceOp(copyOp, writeValue->getResults());
2657 return success();
2658}
2659
2660//----------------------------------------------------------------------------//
2661// Misc. vectorization patterns.
2662//----------------------------------------------------------------------------//
2663/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2664/// given operation type OpTy.
2665template <typename OpTy>
2666struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2667 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2668
2669 LogicalResult matchAndRewrite(tensor::PadOp padOp,
2670 PatternRewriter &rewriter) const final {
2671 bool changed = false;
2672 // Insert users in vector, because some users may be replaced/removed.
2673 for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2674 if (auto op = dyn_cast<OpTy>(user))
2675 changed |= rewriteUser(rewriter, padOp, op).succeeded();
2676     return success(changed);
2677 }
2678
2679protected:
2680 virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2681 tensor::PadOp padOp, OpTy op) const = 0;
2682};
2683
2684/// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2685/// ```
2686/// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2687/// %r = vector.transfer_read %0[%c0, %c0], %cst
2688/// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2689/// ```
2690/// is rewritten to:
2691/// ```
2692/// %r = vector.transfer_read %src[%c0, %c0], %padding
2693/// {in_bounds = [true, true]}
2694/// : tensor<?x?xf32>, vector<17x5xf32>
2695/// ```
2696/// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2697/// sure that the original padding value %cst was never used.
2698///
2699/// This rewrite is possible if:
2700/// - `xferOp` has no out-of-bounds dims or mask.
2701/// - Low padding is static 0.
2702/// - Single, scalar padding value.
2703struct PadOpVectorizationWithTransferReadPattern
2704 : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2705 using VectorizePadOpUserPattern<
2706 vector::TransferReadOp>::VectorizePadOpUserPattern;
2707
2708 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2709 vector::TransferReadOp xferOp) const override {
2710 // Low padding must be static 0.
2711 if (!padOp.hasZeroLowPad())
2712 return failure();
2713 // Pad value must be a constant.
2714 auto padValue = padOp.getConstantPaddingValue();
2715 if (!padValue)
2716 return failure();
2717 // Padding value of existing `xferOp` is unused.
2718 if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2719 return failure();
2720
2721 rewriter.modifyOpInPlace(xferOp, [&]() {
2722 SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2723 xferOp->setAttr(xferOp.getInBoundsAttrName(),
2724 rewriter.getBoolArrayAttr(inBounds));
2725 xferOp.getBaseMutable().assign(padOp.getSource());
2726 xferOp.getPaddingMutable().assign(padValue);
2727 });
2728
2729 return success();
2730 }
2731};
2732
2733/// Rewrite use of tensor::PadOp result in TransferWriteOp.
2734/// This pattern rewrites TransferWriteOps that write to a padded tensor
2735/// value, where the same amount of padding is immediately removed again after
2736/// the write. In such cases, the TransferWriteOp can write to the non-padded
2737/// tensor value and apply out-of-bounds masking. E.g.:
2738/// ```
2739/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2740/// : tensor<...> to tensor<?x?xf32>
2741/// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2742/// %2 = vector.transfer_write %vec, %1[...]
2743/// : vector<17x5xf32>, tensor<17x5xf32>
2744/// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2745/// : tensor<17x5xf32> to tensor<?x?xf32>
2746/// ```
2747/// is rewritten to:
2748/// ```
2749/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2750/// : tensor<...> to tensor<?x?xf32>
2751/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2752/// tensor<?x?xf32>
2753/// ```
2754/// Note: It is important that the ExtractSliceOp %r resizes the result of the
2755/// TransferWriteOp to the same size as the input of the TensorPadOp (or an
2756/// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2757/// from %r's old dimensions.
2758///
2759/// This rewrite is possible if:
2760/// - Low padding is static 0.
2761/// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2762/// ExtractSliceOp trims the same amount of padding that was added
2763/// beforehand.
2764/// - Single, scalar padding value.
2765struct PadOpVectorizationWithTransferWritePattern
2766 : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2767 using VectorizePadOpUserPattern<
2768 vector::TransferWriteOp>::VectorizePadOpUserPattern;
2769
2770 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2771 vector::TransferWriteOp xferOp) const override {
2772 // TODO: support 0-d corner case.
2773 if (xferOp.getTransferRank() == 0)
2774 return failure();
2775
2776 // Low padding must be static 0.
2777 if (!padOp.hasZeroLowPad())
2778 return failure();
2779 // Pad value must be a constant.
2780 auto padValue = padOp.getConstantPaddingValue();
2781 if (!padValue)
2782 return failure();
2783 // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2784 if (!xferOp->hasOneUse())
2785 return failure();
2786 auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2787 if (!trimPadding)
2788 return failure();
2789 // Only static zero offsets supported when trimming padding.
2790 if (!trimPadding.hasZeroOffset())
2791 return failure();
2792 // trimPadding must remove the amount of padding that was added earlier.
2793     if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2794 return failure();
2795
2796 // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2797 rewriter.setInsertionPoint(xferOp);
2798
2799 SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2800 auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2801 xferOp, padOp.getSource().getType(), xferOp.getVector(),
2802 padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2803 xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2804 rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2805
2806 return success();
2807 }
2808
2809 /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2810 /// i.e., same dimensions.
2811 ///
2812 /// Dimensions may be static, dynamic or mix of both. In case of dynamic
2813 /// dimensions, this function tries to infer the (static) tensor size by
2814 /// looking at the defining op and utilizing op-specific knowledge.
2815 ///
2816 /// This is a conservative analysis. In case equal tensor sizes cannot be
2817 /// proven statically, this analysis returns `false` even though the tensor
2818 /// sizes may turn out to be equal at runtime.
2819 bool hasSameTensorSize(Value beforePadding,
2820 tensor::ExtractSliceOp afterTrimming) const {
2821 // If the input to tensor::PadOp is a CastOp, try with both CastOp
2822 // result and CastOp operand.
2823 if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2824       if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2825 return true;
2826
2827 auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
2828 auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2829 // Only RankedTensorType supported.
2830 if (!t1 || !t2)
2831 return false;
2832 // Rank of both values must be the same.
2833 if (t1.getRank() != t2.getRank())
2834 return false;
2835
2836 // All static dimensions must be the same. Mixed cases (e.g., dimension
2837 // static in `t1` but dynamic in `t2`) are not supported.
2838 for (unsigned i = 0; i < t1.getRank(); ++i) {
2839 if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2840 return false;
2841 if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2842 return false;
2843 }
2844
2845 // Nothing more to check if all dimensions are static.
2846 if (t1.getNumDynamicDims() == 0)
2847 return true;
2848
2849 // All dynamic sizes must be the same. The only supported case at the
2850 // moment is when `beforePadding` is an ExtractSliceOp (or a cast
2851 // thereof).
2852
2853 // Apart from CastOp, only ExtractSliceOp is supported.
2854 auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
2855 if (!beforeSlice)
2856 return false;
2857
2858 assert(static_cast<size_t>(t1.getRank()) ==
2859 beforeSlice.getMixedSizes().size());
2860 assert(static_cast<size_t>(t2.getRank()) ==
2861 afterTrimming.getMixedSizes().size());
2862
2863 for (unsigned i = 0; i < t1.getRank(); ++i) {
2864 // Skip static dimensions.
2865 if (!t1.isDynamicDim(i))
2866 continue;
2867 auto size1 = beforeSlice.getMixedSizes()[i];
2868 auto size2 = afterTrimming.getMixedSizes()[i];
2869
2870 // Case 1: Same value or same constant int.
2871 if (isEqualConstantIntOrValue(size1, size2))
2872 continue;
2873
2874 // Other cases: Take a deeper look at defining ops of values.
2875 auto v1 = llvm::dyn_cast_if_present<Value>(size1);
2876 auto v2 = llvm::dyn_cast_if_present<Value>(size2);
2877 if (!v1 || !v2)
2878 return false;
2879
2880 // Case 2: Both values are identical AffineMinOps. (Should not happen if
2881 // CSE is run.)
2882 auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
2883 auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
2884 if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2885 minOp1.getOperands() == minOp2.getOperands())
2886 continue;
2887
2888 // Add additional cases as needed.
2889 }
2890
2891 // All tests passed.
2892 return true;
2893 }
2894};
2895
2896/// Returns the effective Pad value for the input op, provided it's a scalar.
2897///
2898/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
2899/// this Op performs padding, retrieve the padding value provided that it's
2900/// a scalar and static/fixed for all the padded values. Returns an empty value
2901/// otherwise.
2902///
2903/// TODO: This is used twice (when checking vectorization pre-conditions and
2904/// when vectorizing). Cache results instead of re-running.
2905static Value getStaticPadVal(Operation *op) {
2906 if (!op)
2907 return {};
2908
2909 // 1. vector.broadcast (f32 -> vector <...xf32>) - return the value that's
2910 // being broadcast, provided that it's a scalar.
2911 if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
2912 auto source = bcast.getSource();
2913 if (llvm::dyn_cast<VectorType>(source.getType()))
2914 return {};
2915
2916 return source;
2917 }
2918
2919 // 2. linalg.fill - use the scalar input value that used to fill the output
2920 // tensor.
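  // For illustration (hypothetical IR): given
  //   %pad = arith.constant 0.0 : f32
  //   %filled = linalg.fill ins(%pad : f32) outs(%dest : tensor<8x8xf32>) -> tensor<8x8xf32>
  // this case returns %pad.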
2921 if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
2922 return fill.getInputs()[0];
2923 }
2924
2925   // 3. tensor.generate - we can't guarantee the value is fixed without
2926   // further analysis, so bail out.
2927 if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
2928 return {};
2929 }
2930
2931   // 4. vector.transfer_write - inspect the vector that's being written. If
2932   // it contains a single value that has been broadcast (e.g. via
2933   // vector.broadcast), extract it; fail otherwise.
2934 if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
2935 return getStaticPadVal(xferWrite.getVector().getDefiningOp());
2936
2937 // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
2938 // than the input tensor, then, provided it's constant, we'll extract the
2939 // value that was used to generate it (via e.g. linalg.fill), fail otherwise.
2940 // TODO: Clarify the semantics when the input tensor is larger than the
2941 // destination.
2942 if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
2943 return getStaticPadVal(slice.getDest().getDefiningOp());
2944
2945 return {};
2946}
2947
2948static LogicalResult
2949vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
2950 ArrayRef<int64_t> inputVectorSizes,
2951 SmallVectorImpl<Value> &newResults) {
2952 // TODO: Introduce a parent class that will handle the insertion point update.
2953 OpBuilder::InsertionGuard g(rewriter);
2954 rewriter.setInsertionPoint(sliceOp);
2955
2956 TypedValue<RankedTensorType> source = sliceOp.getSource();
2957 auto sourceType = source.getType();
2958 auto resultType = sliceOp.getResultType();
2959
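  // 1. Get the pad value.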
2960 Value padValue = getStaticPadVal(sliceOp);
2961
2962 if (!padValue) {
2963 auto elemType = sourceType.getElementType();
2964 padValue = rewriter.create<arith::ConstantOp>(
2965 sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
2966 }
2967
2968 // 2. Get the vector shape
2969 SmallVector<int64_t> vecShape;
2970 size_t rankDiff = resultType.getRank() - sourceType.getRank();
2971 for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
2972 if (!inputVectorSizes.empty()) {
2973       vecShape.push_back(inputVectorSizes[i]);
2974 } else if (!sourceType.isDynamicDim(i)) {
2975       vecShape.push_back(sourceType.getDimSize(i));
2976 } else if (!resultType.isDynamicDim(i)) {
2977 // Source shape is not statically known, but result shape is.
2978 // Vectorize with size of result shape. This may be larger than the
2979 // source size.
2980 // FIXME: Using rankDiff implies that the source tensor is inserted at
2981 // the end of the destination tensor. However, that's not required.
2982       vecShape.push_back(resultType.getDimSize(rankDiff + i));
2983 } else {
2984       // Neither the source nor the result dim of sliceOp is static. Cannot
2985       // vectorize the copy.
2986 return failure();
2987 }
2988 }
2989 auto vecType = VectorType::get(vecShape, sourceType.getElementType());
2990
2991 // 3. Generate TransferReadOp + TransferWriteOp
2992 auto loc = sliceOp.getLoc();
2993
2994 // Create read
2995 SmallVector<Value> readIndices(
2996 vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
2997 Value read = mlir::vector::createReadOrMaskedRead(
2998       rewriter, loc, source, vecType.getShape(), padValue,
2999 /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
3000
3001 // Create write
3002 auto writeIndices =
3003 getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
3004 Operation *write =
3005 createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
3006 writeIndices, inputVectorSizes.empty());
3007
3008 // 4. Finalize
3009   newResults.push_back(write->getResult(0));
3010
3011 return success();
3012}
3013
3014/// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
3015/// ```
3016/// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
3017/// %r = tensor.insert_slice %0
3018/// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
3019/// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
3020/// ```
3021/// is rewritten to:
3022/// ```
3023/// %0 = vector.transfer_read %src[%c0, %c0], %padding
3024/// : tensor<?x?xf32>, vector<17x5xf32>
3025/// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
3026/// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
3027/// ```
3028///
3029/// This rewrite is possible if:
3030/// - Low padding is static 0.
3031/// - `padOp` result shape is static.
3032/// - The entire padded tensor is inserted.
3033/// (Implies that sizes of `insertOp` are all static.)
3034/// - Only unit strides in `insertOp`.
3035/// - Single, scalar padding value.
3036/// - `padOp` result not used as destination.
3037struct PadOpVectorizationWithInsertSlicePattern
3038 : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
3039 using VectorizePadOpUserPattern<
3040 tensor::InsertSliceOp>::VectorizePadOpUserPattern;
3041
3042 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
3043 tensor::InsertSliceOp insertOp) const override {
3044 // Low padding must be static 0.
3045 if (!padOp.hasZeroLowPad())
3046 return failure();
3047 // Only unit stride supported.
3048 if (!insertOp.hasUnitStride())
3049 return failure();
3050 // Pad value must be a constant.
3051 auto padValue = padOp.getConstantPaddingValue();
3052 if (!padValue)
3053 return failure();
3054 // Dynamic shapes not supported.
3055 if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
3056 return failure();
3057 // Pad result not used as destination.
3058 if (insertOp.getDest() == padOp.getResult())
3059 return failure();
3060
3061 auto vecType = VectorType::get(padOp.getType().getShape(),
3062 padOp.getType().getElementType());
3063 unsigned vecRank = vecType.getRank();
3064 unsigned tensorRank = insertOp.getType().getRank();
3065
3066 // Check if sizes match: Insert the entire tensor into most minor dims.
3067 // (No permutations allowed.)
3068 SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
3069 expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
3070 if (!llvm::all_of(
3071 llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
3072 return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
3073 }))
3074 return failure();
3075
3076 // Insert the TransferReadOp and TransferWriteOp at the position of the
3077 // InsertSliceOp.
3078 rewriter.setInsertionPoint(insertOp);
3079
3080 // Generate TransferReadOp: Read entire source tensor and add high
3081 // padding.
3082 SmallVector<Value> readIndices(
3083 vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
3084 auto read = rewriter.create<vector::TransferReadOp>(
3085 padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
3086
3087 // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
3088     // specified offsets. The write is fully in-bounds because an InsertSliceOp's
3089 // source must fit into the destination at the specified offsets.
3090 auto writeIndices = getValueOrCreateConstantIndexOp(
3091 rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
3092 SmallVector<bool> inBounds(vecRank, true);
3093 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
3094 insertOp, read, insertOp.getDest(), writeIndices,
3095 ArrayRef<bool>{inBounds});
3096
3097 return success();
3098 }
3099};
3100
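/// Collect the tensor.pad vectorization patterns above. A minimal usage
/// sketch, assuming a surrounding pass (the driver call below is illustrative
/// and not mandated by this file):
/// ```
///   RewritePatternSet patterns(funcOp.getContext());
///   linalg::populatePadOpVectorizationPatterns(patterns);
///   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
///     signalPassFailure();
/// ```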
3101void mlir::linalg::populatePadOpVectorizationPatterns(
3102 RewritePatternSet &patterns, PatternBenefit baseBenefit) {
3103 patterns.add<PadOpVectorizationWithTransferReadPattern,
3104 PadOpVectorizationWithTransferWritePattern,
3105 PadOpVectorizationWithInsertSlicePattern>(
3106       patterns.getContext(), baseBenefit.getBenefit() + 1);
3107}
3108
3109//----------------------------------------------------------------------------//
3110// Forwarding patterns
3111//----------------------------------------------------------------------------//
3112
3113/// Check whether there is any interleaved use of any `values` between
3114/// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
3115/// is in a different block.
3116static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
3117 ValueRange values) {
3118 if (firstOp->getBlock() != secondOp->getBlock() ||
3119       !firstOp->isBeforeInBlock(secondOp)) {
3120 LDBG("interleavedUses precondition failed, firstOp: "
3121 << *firstOp << ", second op: " << *secondOp << "\n");
3122 return true;
3123 }
3124 for (auto v : values) {
3125 for (auto &u : v.getUses()) {
3126 Operation *owner = u.getOwner();
3127 if (owner == firstOp || owner == secondOp)
3128 continue;
3129 // TODO: this is too conservative, use dominance info in the future.
3130 if (owner->getBlock() == firstOp->getBlock() &&
3131           (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
3132 continue;
3133 LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
3134 << ", second op: " << *secondOp << "\n");
3135 return true;
3136 }
3137 }
3138 return false;
3139}
3140
3141/// Return the unique subview use of `v` if it is indeed unique, null
3142/// otherwise.
3143static memref::SubViewOp getSubViewUseIfUnique(Value v) {
3144 memref::SubViewOp subViewOp;
3145 for (auto &u : v.getUses()) {
3146 if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
3147 if (subViewOp)
3148 return memref::SubViewOp();
3149 subViewOp = newSubViewOp;
3150 }
3151 }
3152 return subViewOp;
3153}
3154
3155/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3156/// when available.
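/// Forward a read of a locally padded buffer directly to the original source.
/// Roughly (illustrative IR, names hypothetical):
/// ```
///   linalg.fill ins(%cst : f32) outs(%alloc : memref<32x32xf32>)
///   %sv = memref.subview %alloc[...]
///   memref.copy %in, %sv
///   %v = vector.transfer_read %alloc[%c0, %c0], %cst
/// ```
/// is rewritten so the transfer reads from %in directly (with conservative
/// in_bounds = false), and the fill and copy are erased.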
3157LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
3158 vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
3159
3160 // TODO: support mask.
3161 if (xferOp.getMask())
3162 return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3163
3164 // Transfer into `view`.
3165 Value viewOrAlloc = xferOp.getBase();
3166 if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3167 !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3168 return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3169
3170 // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3171 memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3172 if (!subViewOp)
3173 return rewriter.notifyMatchFailure(xferOp, "no subview found");
3174 Value subView = subViewOp.getResult();
3175
3176 // Find the copy into `subView` without interleaved uses.
3177 memref::CopyOp copyOp;
3178 for (auto &u : subView.getUses()) {
3179 if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3180 assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
3181 if (newCopyOp.getTarget() != subView)
3182 continue;
3183 if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
3184 continue;
3185 copyOp = newCopyOp;
3186 break;
3187 }
3188 }
3189 if (!copyOp)
3190 return rewriter.notifyMatchFailure(xferOp, "no copy found");
3191
3192 // Find the fill into `viewOrAlloc` without interleaved uses before the
3193 // copy.
3194 FillOp maybeFillOp;
3195 for (auto &u : viewOrAlloc.getUses()) {
3196 if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
3197 assert(isa<MemRefType>(newFillOp.output().getType()));
3198 if (newFillOp.output() != viewOrAlloc)
3199 continue;
3200 if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
3201 continue;
3202 maybeFillOp = newFillOp;
3203 break;
3204 }
3205 }
3206 // Ensure padding matches.
3207 if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
3208 return rewriter.notifyMatchFailure(xferOp,
3209 "padding value does not match fill");
3210
3211 // `in` is the subview that memref.copy reads. Replace it.
3212 Value in = copyOp.getSource();
3213
3214 // memref.copy + linalg.fill can be used to create a padded local buffer.
3215 // The `masked` attribute is only valid on this padded buffer.
3216 // When forwarding to vector.transfer_read, the attribute must be reset
3217 // conservatively.
3218 auto vectorType = xferOp.getVectorType();
3219 Value res = rewriter.create<vector::TransferReadOp>(
3220 xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
3221 xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
3222 rewriter.getBoolArrayAttr(
3223 SmallVector<bool>(vectorType.getRank(), false)));
3224
3225 if (maybeFillOp)
3226     rewriter.eraseOp(maybeFillOp);
3227   rewriter.eraseOp(copyOp);
3228 rewriter.replaceOp(xferOp, res);
3229
3230 return success();
3231}
3232
3233/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3234/// when available.
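/// Forward a write into a local buffer directly to the copy target. Roughly
/// (illustrative IR, names hypothetical):
/// ```
///   %sv = memref.subview %alloc[...]
///   vector.transfer_write %vec, %alloc[%c0, %c0]
///   memref.copy %sv, %out
/// ```
/// is rewritten into a single `vector.transfer_write %vec, %out[...]`; the
/// copy and the original transfer are erased.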
3235LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
3236 vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
3237 // TODO: support mask.
3238 if (xferOp.getMask())
3239 return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3240
3241 // Transfer into `viewOrAlloc`.
3242 Value viewOrAlloc = xferOp.getBase();
3243 if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3244 !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3245 return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3246
3247 // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3248 memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3249 if (!subViewOp)
3250 return rewriter.notifyMatchFailure(xferOp, "no subview found");
3251 Value subView = subViewOp.getResult();
3252
3253 // Find the copy from `subView` without interleaved uses.
3254 memref::CopyOp copyOp;
3255 for (auto &u : subViewOp.getResult().getUses()) {
3256 if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3257 if (newCopyOp.getSource() != subView)
3258 continue;
3259 if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
3260 continue;
3261 copyOp = newCopyOp;
3262 break;
3263 }
3264 }
3265 if (!copyOp)
3266 return rewriter.notifyMatchFailure(xferOp, "no copy found");
3267
3268 // `out` is the subview copied into that we replace.
3269 assert(isa<MemRefType>(copyOp.getTarget().getType()));
3270 Value out = copyOp.getTarget();
3271
3272 // Forward vector.transfer into copy.
3273 // memref.copy + linalg.fill can be used to create a padded local buffer.
3274 // The `masked` attribute is only valid on this padded buffer.
3275 // When forwarding to vector.transfer_write, the attribute must be reset
3276 // conservatively.
3277 auto vector = xferOp.getVector();
3278 rewriter.create<vector::TransferWriteOp>(
3279 xferOp.getLoc(), vector, out, xferOp.getIndices(),
3280 xferOp.getPermutationMapAttr(), xferOp.getMask(),
3281 rewriter.getBoolArrayAttr(SmallVector<bool>(
3282 dyn_cast<VectorType>(vector.getType()).getRank(), false)));
3283
3284   rewriter.eraseOp(copyOp);
3285   rewriter.eraseOp(xferOp);
3286
3287 return success();
3288}
3289
3290//===----------------------------------------------------------------------===//
3291// Convolution vectorization patterns
3292//===----------------------------------------------------------------------===//
3293
3294template <int N>
3295static void bindShapeDims(ShapedType shapedType) {}
3296
3297template <int N, typename IntTy, typename... IntTy2>
3298static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
3299 val = shapedType.getShape()[N];
3300 bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
3301}
3302
3303/// Bind a pack of int& to the leading dimensions of shapedType.getShape().
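/// E.g., for a shaped type with shape [4, 8, 16], `bindShapeDims(ty, a, b)`
/// sets a = 4 and b = 8.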
3304template <typename... IntTy>
3305static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
3306 bindShapeDims<0>(shapedType, vals...);
3307}
3308
3309namespace {
3310/// Generate a vector implementation for either:
3311/// ```
3312/// Op def: ( w, kw )
3313/// Iters: ({Par(), Red()})
3314/// Layout: {{w + kw}, {kw}, {w}}
3315/// ```
3316/// kw is unrolled.
3317///
3318/// or
3319///
3320/// ```
3321/// Op def: ( n, w, c, kw, f )
3322/// Iters: ({Par(), Par(), Par(), Red(), Red()})
3323/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3324/// ```
3325/// kw is unrolled, w is unrolled iff dilationW > 1.
3326///
3327/// or
3328///
3329/// ```
3330/// Op def: ( n, c, w, f, kw )
3331/// Iters: ({Par(), Par(), Par(), Red(), Red()})
3332/// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3333/// ```
3334/// kw is unrolled, w is unrolled iff dilationW > 1.
3335///
3336/// or
3337///
3338/// ```
3339/// Op def: ( n, w, c, kw )
3340/// Iters: ({Par(), Par(), Par(), Red()})
3341/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3342/// ```
3343/// kw is unrolled, w is unrolled iff dilationW > 1.
3344struct Conv1DGenerator
3345 : public StructuredGenerator<LinalgOp, utils::IteratorType> {
3346 Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
3347 : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
3348
3349 lhsShaped = linalgOp.getDpsInputOperand(0)->get();
3350 rhsShaped = linalgOp.getDpsInputOperand(1)->get();
3351 resShaped = linalgOp.getDpsInitOperand(0)->get();
3352 lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
3353 rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
3354 resShapedType = dyn_cast<ShapedType>(resShaped.getType());
3355
3356 Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
3357 redOp = reduceOp->getName().getIdentifier();
3358
3359 setConvOperationKind(reduceOp);
3360
3361 auto maybeKind = getCombinerOpKind(reduceOp);
3362 reductionKind = maybeKind.value();
3363
3364 // The ConvolutionOpInterface gives us guarantees of existence for
3365 // strides/dilations. However, we do not need to rely on those, we can
3366 // simply use them if present, otherwise use the default and let the generic
3367 // conv. matcher in the ConvGenerator succeed or fail.
3368 auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
3369 auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
3370 strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
3371 dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3372 }
3373
3374 /// Generate a vector implementation for:
3375 /// ```
3376 /// Op def: ( w, kw )
3377 /// Iters: ({Par(), Red()})
3378 /// Layout: {{w + kw}, {kw}, {w}}
3379 /// ```
3380 /// kw is always unrolled.
3381 ///
3382 /// or
3383 ///
3384 /// ```
3385 /// Op def: ( n, w, c, kw, f )
3386 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3387 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3388 /// ```
3389 /// kw is always unrolled.
3390   /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
3391   /// > 1.
3392 FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3393 int64_t nSize, wSize, cSize, kwSize, fSize;
3394 SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3395 bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3396 switch (conv1DOpOrder) {
3397 case Conv1DOpOrder::W:
3398 // Initialize unused dimensions
3399 nSize = fSize = cSize = 0;
3400 // out{W}
3401 bindShapeDims(resShapedType, wSize);
3402 // kernel{kw}
3403 bindShapeDims(rhsShapedType, kwSize);
3404 lhsShape = {// iw = ow + kw - 1
3405 // (i.e. 16 convolved with 3 -> 14)
3406 (wSize + kwSize - 1)};
3407 rhsShape = {kwSize};
3408 resShape = {wSize};
3409 break;
3410 case Conv1DOpOrder::Nwc:
3411 // out{n, w, f}
3412 bindShapeDims(resShapedType, nSize, wSize, fSize);
3413 switch (oper) {
3414 case ConvOperationKind::Conv:
3415 // kernel{kw, c, f}
3416 bindShapeDims(rhsShapedType, kwSize, cSize);
3417 break;
3418 case ConvOperationKind::Pool:
3419 // kernel{kw}
3420 bindShapeDims(rhsShapedType, kwSize);
3421 cSize = fSize;
3422 break;
3423 }
3424 lhsShape = {nSize,
3425 // iw = ow * sw + kw * dw - 1
3426 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3427 // Perform the proper inclusive -> exclusive -> inclusive.
3428 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3429 1,
3430 cSize};
3431 switch (oper) {
3432 case ConvOperationKind::Conv:
3433 rhsShape = {kwSize, cSize, fSize};
3434 break;
3435 case ConvOperationKind::Pool:
3436 rhsShape = {kwSize};
3437 break;
3438 }
3439 resShape = {nSize, wSize, fSize};
3440 break;
3441 case Conv1DOpOrder::Ncw:
3442 // out{n, f, w}
3443 bindShapeDims(resShapedType, nSize, fSize, wSize);
3444 switch (oper) {
3445 case ConvOperationKind::Conv:
3446 // kernel{f, c, kw}
3447 bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3448 break;
3449 case ConvOperationKind::Pool:
3450 // kernel{kw}
3451 bindShapeDims(rhsShapedType, kwSize);
3452 cSize = fSize;
3453 break;
3454 }
3455 lhsShape = {nSize, cSize,
3456 // iw = ow * sw + kw * dw - 1
3457 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3458 // Perform the proper inclusive -> exclusive -> inclusive.
3459 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3460 1};
3461 switch (oper) {
3462 case ConvOperationKind::Conv:
3463 rhsShape = {fSize, cSize, kwSize};
3464 break;
3465 case ConvOperationKind::Pool:
3466 rhsShape = {kwSize};
3467 break;
3468 }
3469 resShape = {nSize, fSize, wSize};
3470 break;
3471 }
3472
3473 vector::TransferWriteOp write;
3474 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3475
3476 // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3477 // When strideW == 1, we can batch the contiguous loads and avoid
3478 // unrolling
3479 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3480
3481 Type lhsEltType = lhsShapedType.getElementType();
3482 Type rhsEltType = rhsShapedType.getElementType();
3483 Type resEltType = resShapedType.getElementType();
3484 auto lhsType = VectorType::get(lhsShape, lhsEltType);
3485 auto rhsType = VectorType::get(rhsShape, rhsEltType);
3486 auto resType = VectorType::get(resShape, resEltType);
3487 // Zero padding with the corresponding dimensions for lhs, rhs and res.
3488 SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3489 SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3490 SmallVector<Value> resPadding(resShape.size(), zero);
3491
3492 // Read the whole lhs, rhs and res in one shot (with zero padding).
3493 Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
3494 lhsPadding);
3495 // This is needed only for Conv.
3496 Value rhs = nullptr;
3497 if (oper == ConvOperationKind::Conv)
3498 rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3499 rhsPadding);
3500 Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
3501 resPadding);
3502
3503     // The base vectorization case for channeled convolution is input:
3504     // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
3505     // vectorization case, we pre-transpose the input, weight, and output.
3506 switch (conv1DOpOrder) {
3507 case Conv1DOpOrder::W:
3508 case Conv1DOpOrder::Nwc:
3509 // Base case, so no transposes necessary.
3510 break;
3511 case Conv1DOpOrder::Ncw: {
3512 // To match base vectorization case, we pre-transpose current case.
3513 // ncw -> nwc
3514 static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3515 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
3516 // fcw -> wcf
3517 static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3518
3519 // This is needed only for Conv.
3520 if (oper == ConvOperationKind::Conv)
3521 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
3522 // nfw -> nwf
3523 static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3524 res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
3525 break;
3526 }
3527 }
3528
3529 //===------------------------------------------------------------------===//
3530 // Begin vector-only rewrite part
3531 //===------------------------------------------------------------------===//
3532 // Unroll along kw and read slices of lhs and rhs.
3533 SmallVector<Value> lhsVals, rhsVals, resVals;
3534 lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3535 kwSize, strideW, dilationW, wSizeStep,
3536 isSingleChanneled);
3537 // Do not do for pooling.
3538 if (oper == ConvOperationKind::Conv)
3539 rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3540 resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3541 wSizeStep, isSingleChanneled);
3542
3543 auto linearIndex = [&](int64_t kw, int64_t w) {
3544 return kw * (wSize / wSizeStep) + w;
3545 };
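    // E.g., with strideW > 1 (wSizeStep == 1) and wSize == 4, (kw, w) == (1, 2)
    // maps to slot 6; when wSizeStep == wSize, only w == 0 occurs and the index
    // reduces to kw.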
3546
3547     // Compute the contraction O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f},
3548     // or an outer product for non-channeled convolution, or a simple arith
3549     // operation for pooling.
3550 for (int64_t kw = 0; kw < kwSize; ++kw) {
3551 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3552 switch (oper) {
3553 case ConvOperationKind::Conv:
3554 if (isSingleChanneled) {
3555 resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3556 lhsVals[linearIndex(kw, w)],
3557 rhsVals[kw], resVals[w]);
3558 } else {
3559 resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3560 lhsVals[linearIndex(kw, w)],
3561 rhsVals[kw], resVals[w]);
3562 }
3563 break;
3564 case ConvOperationKind::Pool:
3565 resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3566 resVals[w]);
3567 break;
3568 }
3569 }
3570 }
3571
3572 res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3573 isSingleChanneled);
3574 //===------------------------------------------------------------------===//
3575 // End vector-only rewrite part
3576 //===------------------------------------------------------------------===//
3577
3578     // The base vectorization case for channeled convolution produces output
3579     // {n,w,f}. To reuse the result from the base pattern vectorization case,
3580     // we post-transpose the base case result.
3581 switch (conv1DOpOrder) {
3582 case Conv1DOpOrder::W:
3583 case Conv1DOpOrder::Nwc:
3584 // Base case, so no transposes necessary.
3585 break;
3586 case Conv1DOpOrder::Ncw: {
3587 // nwf -> nfw
3588 static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3589 res = rewriter.create<vector::TransposeOp>(loc, res, perm);
3590 break;
3591 }
3592 }
3593
3594 return rewriter
3595 .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
3596 .getOperation();
3597 }
3598
3599   // Take a value and widen it to have the same element type as `ty`.
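  // E.g., promoting a vector<4xi8> value against a vector<4xi32> `ty` is a
  // sign-extension (arith.extsi) to vector<4xi32>; an integer source against a
  // float `ty` goes through arith.sitofp.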
3600 Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3601     const Type srcElementType = getElementTypeOrSelf(val.getType());
3602     const Type dstElementType = getElementTypeOrSelf(ty);
3603 assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3604 if (srcElementType == dstElementType)
3605 return val;
3606
3607 const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3608 const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3609 const Type dstType =
3610 cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3611
3612     if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3613 return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
3614 }
3615
3616 if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3617 srcWidth < dstWidth)
3618 return rewriter.create<arith::ExtFOp>(loc, dstType, val);
3619
3620 if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3621 srcWidth < dstWidth)
3622 return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
3623
3624 assert(false && "unhandled promotion case");
3625 return nullptr;
3626 }
3627
3628 // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3629 Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3630 Value lhs, Value rhs, Value res) {
3631 vector::IteratorType par = vector::IteratorType::parallel;
3632 vector::IteratorType red = vector::IteratorType::reduction;
3633 AffineExpr n, w, f, c;
3634 bindDims(ctx, n, w, f, c);
3635     lhs = promote(rewriter, loc, lhs, res.getType());
3636     rhs = promote(rewriter, loc, rhs, res.getType());
3637     auto contractionOp = rewriter.create<vector::ContractionOp>(
3638         loc, lhs, rhs, res,
3639         /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3640         /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3641     contractionOp.setKind(reductionKind);
3642     return contractionOp;
3643 }
3644
3645 // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3646 // convolution.
3647 Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3648 Value lhs, Value rhs, Value res) {
3649 return rewriter.create<vector::OuterProductOp>(
3650 loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
3651 }
3652
3653 // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3654 Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3655 Value res) {
3656 if (isPoolExt)
3657 lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3658 return rewriter
3659 .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3660 ->getResult(0);
3661 }
3662
3663 /// Generate a vector implementation for:
3664 /// ```
3665 /// Op def: ( n, w, c, kw)
3666 /// Iters: ({Par(), Par(), Par(), Red()})
3667 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3668 /// ```
3669 /// kw is always unrolled.
3670   /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
3671   /// > 1.
3672 FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3673 bool channelDimScalableFlag,
3674 bool flatten) {
3675 bool scalableChDim = false;
3676 bool useMasking = false;
3677 int64_t nSize, wSize, cSize, kwSize;
3678 // kernel{kw, c}
3679 bindShapeDims(rhsShapedType, kwSize, cSize);
3680 if (ShapedType::isDynamic(cSize)) {
3681 assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3682 cSize = channelDimVecSize;
3683 // Scalable vectors are only used when both conditions are met:
3684 // 1. channel dim is dynamic
3685 // 2. channelDimScalableFlag is set
3686 scalableChDim = channelDimScalableFlag;
3687 useMasking = true;
3688 }
3689
3690 assert(!(useMasking && flatten) &&
3691 "Unsupported flattened conv with dynamic shapes");
3692
3693 // out{n, w, c}
3694 bindShapeDims(resShapedType, nSize, wSize);
3695
3696 vector::TransferWriteOp write;
3697 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3698
3699 // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3700 // When strideW == 1, we can batch the contiguous loads and avoid
3701 // unrolling
3702 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3703
3704 Type lhsEltType = lhsShapedType.getElementType();
3705 Type rhsEltType = rhsShapedType.getElementType();
3706 Type resEltType = resShapedType.getElementType();
3707 VectorType lhsType = VectorType::get(
3708 {nSize,
3709 // iw = ow * sw + kw * dw - 1
3710 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3711 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3712 cSize},
3713 lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3714 VectorType rhsType =
3715 VectorType::get({kwSize, cSize}, rhsEltType,
3716 /*scalableDims=*/{false, scalableChDim});
3717 VectorType resType =
3718 VectorType::get({nSize, wSize, cSize}, resEltType,
3719 /*scalableDims=*/{false, false, scalableChDim});
3720
3721     // Masks the input xfer Op along the channel dim, iff masking is needed
3722     // (i.e. the channel dim is dynamic).
3723 auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3724 ArrayRef<bool> scalableDims,
3725 Operation *opToMask) {
3726 if (!useMasking)
3727 return opToMask;
3728 auto maskType =
3729 VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3730
3731 SmallVector<bool> inBounds(maskShape.size(), true);
3732 auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3733 xferOp->setAttr(xferOp.getInBoundsAttrName(),
3734 rewriter.getBoolArrayAttr(inBounds));
3735
3736 SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3737 cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3738
3739 Value maskOp =
3740 rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);
3741
3742 return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3743 };
3744
3745 // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3746 // 0].
3747 Value lhs = rewriter.create<vector::TransferReadOp>(
3748 loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
3749 auto maybeMaskedLhs = maybeMaskXferOp(
3750 lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3751
3752 // Read rhs slice of size {kw, c} @ [0, 0].
3753 Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3754 ValueRange{zero, zero});
3755 auto maybeMaskedRhs = maybeMaskXferOp(
3756 rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3757
3758 // Read res slice of size {n, w, c} @ [0, 0, 0].
3759 Value res = rewriter.create<vector::TransferReadOp>(
3760 loc, resType, resShaped, ValueRange{zero, zero, zero});
3761 auto maybeMaskedRes = maybeMaskXferOp(
3762 resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3763
3764 //===------------------------------------------------------------------===//
3765 // Begin vector-only rewrite part
3766 //===------------------------------------------------------------------===//
3767 // Unroll along kw and read slices of lhs and rhs.
3768 SmallVector<Value> lhsVals, rhsVals, resVals;
3769 SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
3770 SmallVector<int64_t> inOutStrides = {1, 1, 1};
3771
3772 // Extract lhs slice of size {n, wSizeStep, c}
3773 // @ [0, sw * w + dw * kw, 0].
3774 for (int64_t kw = 0; kw < kwSize; ++kw) {
3775 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3776 lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3777 loc, maybeMaskedLhs->getResult(0),
3778 /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3779 inOutSliceSizes, inOutStrides));
3780 }
3781 }
3782 // Extract rhs slice of size {c} @ [kw].
3783 for (int64_t kw = 0; kw < kwSize; ++kw) {
3784 rhsVals.push_back(rewriter.create<vector::ExtractOp>(
3785 loc, maybeMaskedRhs->getResult(0),
3786 /*offsets=*/ArrayRef<int64_t>{kw}));
3787 }
3788 // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3789 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3790 resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3791 loc, maybeMaskedRes->getResult(0),
3792 /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3793 inOutStrides));
3794 }
3795
3796 auto linearIndex = [&](int64_t kw, int64_t w) {
3797 return kw * (wSize / wSizeStep) + w;
3798 };
3799
3800 // Note - the scalable flags are ignored as flattening combined with
3801 // scalable vectorization is not supported.
3802 SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
3803 auto lhsTypeAfterFlattening =
3804 VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3805 auto resTypeAfterFlattening =
3806 VectorType::get(inOutFlattenSliceSizes, resEltType);
3807
3808 // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3809 for (int64_t kw = 0; kw < kwSize; ++kw) {
3810 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3811 Value lhsVal = lhsVals[linearIndex(kw, w)];
3812 Value resVal = resVals[w];
3813 if (flatten) {
3814 // Flatten the input and output vectors (collapse the channel
3815 // dimension)
3816 lhsVal = rewriter.create<vector::ShapeCastOp>(
3817 loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
3818 resVal = rewriter.create<vector::ShapeCastOp>(
3819 loc, resTypeAfterFlattening, resVals[w]);
3820 }
3821 resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3822 rhsVals[kw], resVal, flatten);
3823 if (flatten) {
3824 // Un-flatten the output vector (restore the channel dimension)
3825 resVals[w] = rewriter.create<vector::ShapeCastOp>(
3826 loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
3827 }
3828 }
3829 }
3830
3831     // It's possible we failed to create the FMA.
3832     if (!llvm::all_of(resVals, [](Value v) { return v; })) {
3833 // Manually revert (in reverse order) to avoid leaving a bad IR state.
3834 for (auto &collection :
3835 {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3836 for (Value v : collection)
3837 rewriter.eraseOp(v.getDefiningOp());
3838 return rewriter.notifyMatchFailure(op, "failed to create FMA");
3839 }
3840
3841 // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
3842 // This does not depend on kw.
3843 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3844 maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
3845 loc, resVals[w], maybeMaskedRes->getResult(0),
3846 /*offsets=*/ArrayRef<int64_t>{0, w, 0},
3847 /*strides=*/ArrayRef<int64_t>{1, 1, 1});
3848 }
3849 //===------------------------------------------------------------------===//
3850 // End vector-only rewrite part
3851 //===------------------------------------------------------------------===//
3852
3853 // Write back res slice of size {n, w, c} @ [0, 0, 0].
3854 Operation *resOut = rewriter.create<vector::TransferWriteOp>(
3855 loc, maybeMaskedRes->getResult(0), resShaped,
3856 ValueRange{zero, zero, zero});
3857 return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3858 resOut);
3859 }
3860
3861 /// Lower:
3862 /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
3863 /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
3864 /// to MulAcc.
3865 Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
3866 Value lhs, Value rhs, Value res,
3867 bool flatten) {
3868 auto rhsTy = cast<ShapedType>(rhs.getType());
3869 auto resTy = cast<ShapedType>(res.getType());
3870
3871     // TODO(suderman): Change this to use a vector.fma intrinsic.
3872     lhs = promote(rewriter, loc, lhs, resTy);
3873
3874 if (flatten) {
3875 // NOTE: This following logic won't work for scalable vectors. For this
3876 // reason, "flattening" is not supported when shapes are dynamic (this
3877 // should be captured by one of the pre-conditions).
3878
3879 // There are two options for handling the filter:
3880 // * shape_cast(broadcast(filter))
3881 // * broadcast(shuffle(filter))
3882 // Opt for the option without shape_cast to simplify the codegen.
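      // E.g., with rhsSize == 3 and resSize == 6, `indices` becomes
      // [0, 1, 2, 0, 1, 2], i.e. the filter repeated along the flattened
      // w * c dimension.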
3883 auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
3884 auto resSize = cast<VectorType>(res.getType()).getShape()[1];
3885
3886 SmallVector<int64_t, 16> indices;
3887 for (int i = 0; i < resSize / rhsSize; ++i) {
3888 for (int j = 0; j < rhsSize; ++j)
3889           indices.push_back(j);
3890 }
3891
3892 rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
3893 }
3894 // Broadcast the filter to match the output vector
3895 rhs = rewriter.create<vector::BroadcastOp>(
3896 loc, resTy.clone(rhsTy.getElementType()), rhs);
3897
3898     rhs = promote(rewriter, loc, rhs, resTy);
3899
3900 if (!lhs || !rhs)
3901 return nullptr;
3902
3903 if (isa<FloatType>(resTy.getElementType()))
3904 return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
3905
3906 auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
3907 return rewriter.create<arith::AddIOp>(loc, mul, res);
3908 }
3909
3910 /// Entry point for non-channeled convolution:
3911 /// {{w + kw}, {kw}, {w}}
3912 FailureOr<Operation *> generateNonChanneledConv() {
3913 AffineExpr w, kw;
3914 bindDims(ctx, w, kw);
3915 if (!iters({Par(), Red()}))
3916 return rewriter.notifyMatchFailure(op,
3917 "failed to match conv::W 1-par 1-red");
3918
3919 // No transposition needed.
3920 if (layout({/*lhsIndex*/ {w + kw},
3921 /*rhsIndex*/ {kw},
3922 /*resIndex*/ {w}}))
3923       return conv(Conv1DOpOrder::W);
3924
3925 return rewriter.notifyMatchFailure(op, "not a conv::W layout");
3926 }
3927
3928 /// Entry point that transposes into the common form:
3929 /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3930 FailureOr<Operation *> generateNwcConv() {
3931 AffineExpr n, w, f, kw, c;
3932 bindDims(ctx, n, w, f, kw, c);
3933 if (!iters({Par(), Par(), Par(), Red(), Red()}))
3934 return rewriter.notifyMatchFailure(
3935 op, "failed to match conv::Nwc 3-par 2-red");
3936
3937 // No transposition needed.
3938 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3939 /*rhsIndex*/ {kw, c, f},
3940 /*resIndex*/ {n, w, f}}))
3941       return conv(Conv1DOpOrder::Nwc);
3942
3943 return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
3944 }
3945
3946 /// Entry point that transposes into the common form:
3947 /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3948 FailureOr<Operation *> generateNcwConv() {
3949 AffineExpr n, w, f, kw, c;
3950 bindDims(ctx, n, f, w, c, kw);
3951 if (!iters({Par(), Par(), Par(), Red(), Red()}))
3952 return rewriter.notifyMatchFailure(
3953 op, "failed to match conv::Ncw 3-par 2-red");
3954
3955 if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3956 /*rhsIndex*/ {f, c, kw},
3957 /*resIndex*/ {n, f, w}}))
3958       return conv(Conv1DOpOrder::Ncw);
3959
3960 return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
3961 }
3962
3963 /// Entry point that transposes into the common form:
3964 /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
3965 FailureOr<Operation *> generateNwcPooling() {
3966 AffineExpr n, w, c, kw;
3967 bindDims(ctx, n, w, c, kw);
3968 if (!iters({Par(), Par(), Par(), Red()}))
3969 return rewriter.notifyMatchFailure(op,
3970 "failed to match pooling 3-par 1-red");
3971
3972 // No transposition needed.
3973 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3974 /*rhsIndex*/ {kw},
3975 /*resIndex*/ {n, w, c}}))
3976       return conv(Conv1DOpOrder::Nwc);
3977
3978 return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
3979 }
3980
3981 /// Entry point that transposes into the common form:
3982 /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
3983 FailureOr<Operation *> generateNcwPooling() {
3984 AffineExpr n, w, c, kw;
3985 bindDims(ctx, n, c, w, kw);
3986 if (!iters({Par(), Par(), Par(), Red()}))
3987 return rewriter.notifyMatchFailure(op,
3988 "failed to match pooling 3-par 1-red");
3989
3990 if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3991 /*rhsIndex*/ {kw},
3992 /*resIndex*/ {n, c, w}}))
3993       return conv(Conv1DOpOrder::Ncw);
3994
3995 return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
3996 }
3997
3998 /// Entry point that transposes into the common form:
3999 /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
4000 FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
4001 bool vecChDimScalableFlag = false,
4002 bool flatten = false) {
4003 AffineExpr n, w, c, kw;
4004 bindDims(ctx, n, w, c, kw);
4005 if (!iters({Par(), Par(), Par(), Red()}))
4006 return rewriter.notifyMatchFailure(
4007 op, "failed to match depthwise::Nwc conv 3-par 1-red");
4008
4009 // No transposition needed.
4010 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4011 /*rhsIndex*/ {kw, c},
4012 /*resIndex*/ {n, w, c}}))
4013       return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
4014
4015 return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
4016 }
4017
4018private:
4019 ConvOperationKind oper = ConvOperationKind::Conv;
4020 StringAttr redOp;
4021 StringAttr poolExtOp;
4022 bool isPoolExt = false;
4023 int strideW, dilationW;
4024 Value lhsShaped, rhsShaped, resShaped;
4025 ShapedType lhsShapedType, rhsShapedType, resShapedType;
4026 vector::CombiningKind reductionKind;
4027
4028 // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
4029 void setConvOperationKind(Operation *reduceOp) {
4030 int numBlockArguments =
4031         llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
4032 if (numBlockArguments == 1) {
4033 // Will be convolution if feeder is a MulOp.
4034 // A strength reduced version of MulOp for i1 type is AndOp which is also
4035 // supported. Otherwise, it can be pooling. This strength reduction logic
4036 // is in `buildBinaryFn` helper in the Linalg dialect.
4037       auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
4038                                          llvm::IsaPred<BlockArgument>);
4039 Operation *feedOp = (*feedValIt).getDefiningOp();
4040       if (isCastOfBlockArgument(feedOp)) {
4041 oper = ConvOperationKind::Pool;
4042 isPoolExt = true;
4043 poolExtOp = feedOp->getName().getIdentifier();
4044 return;
4045 }
4046 oper = ConvOperationKind::Conv;
4047 return;
4048 }
4049     // numBlockArguments == 2 and this is a pooling op.
4050 oper = ConvOperationKind::Pool;
4051 isPoolExt = false;
4052 }
4053};
4054} // namespace
4055
4056/// Helper function to vectorize a LinalgOp with convolution semantics.
4057// TODO: extend the generic vectorization to support windows and drop this.
4058static FailureOr<Operation *> vectorizeConvolution(
4059 RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
4060 ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
4061 Conv1DGenerator conv1dGen(rewriter, op);
4062 auto res = conv1dGen.generateNonChanneledConv();
4063 if (succeeded(res))
4064 return res;
4065 res = conv1dGen.generateNwcConv();
4066 if (succeeded(res))
4067 return res;
4068 res = conv1dGen.generateNcwConv();
4069 if (succeeded(res))
4070 return res;
4071 res = conv1dGen.generateNwcPooling();
4072 if (succeeded(res))
4073 return res;
4074 res = conv1dGen.generateNcwPooling();
4075 if (succeeded(res))
4076 return res;
4077
4078 // Only depthwise 1D NWC convs are left - these can be vectorized using masks
4079 // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
4080 // masked/scalable) is the channel dim (i.e. the trailing dim).
4081 uint64_t vecChDimSize = ShapedType::kDynamic;
4082 bool vecChDimScalableFlag = false;
4083 if (!inputVecSizes.empty()) {
4084 // Only use the input vector size corresponding to the channel dim. Other
4085 // vector dims will be inferred from the Ops.
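    // E.g., for a DepthwiseConv1DNwcWcOp with inputVecSizes == [1, 8, 4]
    // (illustrative sizes), chDimIdx is 2 and vecChDimSize becomes 4.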
4086 assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
4087 isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
4088 "Not a 1D depthwise conv!");
4089 size_t chDimIdx =
4090 TypeSwitch<Operation *, size_t>(op)
4091 .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
4092 .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
4093
4094 vecChDimSize = inputVecSizes[chDimIdx];
4095 vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
4096 }
4097 return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
4098                                        flatten1DDepthwiseConv);
4099}
4100
4101struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
4102 using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
4103
4104 LogicalResult matchAndRewrite(LinalgOp op,
4105 PatternRewriter &rewriter) const override {
4106 FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
4107     if (failed(resultOrFail))
4108 return failure();
4109 Operation *newOp = *resultOrFail;
4110 if (newOp->getNumResults() == 0) {
4111       rewriter.eraseOp(op.getOperation());
4112 return success();
4113 }
4114 assert(newOp->getNumResults() == 1 && "expected single result");
4115     rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
4116 return success();
4117 }
4118};
4119
4120void mlir::linalg::populateConvolutionVectorizationPatterns(
4121 RewritePatternSet &patterns, PatternBenefit benefit) {
4122   patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
4123}
4124
