//===- Vectorization.cpp - Implementation of linalg Vectorization --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Vectorization transformations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/Utils.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
#include <type_traits>

using namespace mlir;
using namespace mlir::linalg;

#define DEBUG_TYPE "linalg-vectorization"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

/// Try to vectorize `convOp` as a convolution.
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
                     ArrayRef<int64_t> inputVecSizes = {},
                     ArrayRef<bool> inputVecScalableFlags = {},
                     bool flatten1DDepthwiseConv = false);

/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if zero or more than one instance exists.
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType res;
  block.walk([&](OpType op) {
    if (res) {
      res = nullptr;
      return WalkResult::interrupt();
    }
    res = op;
    return WalkResult::advance();
  });
  return res;
}

/// Helper function to extract the input slices after filter is unrolled along
/// kw.
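///
/// As a small worked example (illustrative values only): for a
/// single-channeled convolution with kwSize = 2, wSize = 4 and wSizeStep = 2,
/// the slices are taken at offsets [0], [2] (for kw = 0) and [1], [3]
/// (for kw = 1), each of size {wSizeStep}.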
static SmallVector<Value>
extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
                       int64_t nSize, int64_t wSize, int64_t cSize,
                       int64_t kwSize, int strideW, int dilationW,
                       int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
    // convolution.
    SmallVector<int64_t> sizes{wSizeStep};
    SmallVector<int64_t> strides{1};
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
      }
    }
  } else {
    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
    // for channeled convolution.
    SmallVector<int64_t> sizes{nSize, wSizeStep, cSize};
    SmallVector<int64_t> strides{1, 1, 1};
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, input,
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            sizes, strides));
      }
    }
  }
  return result;
}

/// Helper function to extract the filter slices after filter is unrolled along
/// kw.
static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
                                                  Location loc, Value filter,
                                                  int64_t kwSize) {
  SmallVector<Value> result;
  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
  // non-channeled convolution] @ [kw].
  for (int64_t kw = 0; kw < kwSize; ++kw) {
    result.push_back(rewriter.create<vector::ExtractOp>(
        loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
  }
  return result;
}

/// Helper function to extract the result slices after filter is unrolled along
/// kw.
static SmallVector<Value>
extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
                        int64_t nSize, int64_t wSize, int64_t fSize,
                        int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
    SmallVector<int64_t> sizes{wSizeStep};
    SmallVector<int64_t> strides{1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
    }
  } else {
    // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
    // convolution.
    SmallVector<int64_t> sizes{nSize, wSizeStep, fSize};
    SmallVector<int64_t> strides{1, 1, 1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
    }
  }
  return result;
}

/// Helper function to insert the computed result slices.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
                                    Value res, int64_t wSize, int64_t wSizeStep,
                                    SmallVectorImpl<Value> &resVals,
                                    bool isSingleChanneled) {

  if (isSingleChanneled) {
    // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
    // This does not depend on kw.
    SmallVector<int64_t> strides{1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
    }
  } else {
    // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
    // convolution. This does not depend on kw.
    SmallVector<int64_t> strides{1, 1, 1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          strides);
    }
  }
  return res;
}

/// Contains the vectorization state and related methods used across the
/// vectorization process of a given operation.
struct VectorizationState {
  VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}

  /// Initializes the vectorization state, including the computation of the
  /// canonical vector shape for vectorization.
  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          ArrayRef<bool> inputScalableVecDims);

  /// Returns the canonical vector shape used to vectorize the iteration space.
  ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }

  /// Returns a vector type of the provided `elementType` with the canonical
  /// vector shape and the corresponding fixed/scalable dimensions bit. If
  /// `dimPermutation` is provided, the canonical vector dimensions are
  /// permuted accordingly.
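  ///
  /// For example (illustrative only): with a canonical vector shape of
  /// [8, 16, 4] and `dimPermutation` = (d0, d1, d2) -> (d2, d0), this returns
  /// `vector<4x8xelementType>`.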
  VectorType getCanonicalVecType(
      Type elementType,
      std::optional<AffineMap> dimPermutation = std::nullopt) const {
    SmallVector<int64_t> vectorShape;
    SmallVector<bool> scalableDims;
    if (dimPermutation.has_value()) {
      vectorShape =
          applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
      scalableDims =
          applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
    } else {
      vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
      scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
    }

    return VectorType::get(vectorShape, elementType, scalableDims);
  }

  /// Masks an operation with the canonical vector mask if the operation needs
  /// masking. Returns the masked operation or the original operation if
  /// masking is not needed. If provided, the canonical mask for this operation
  /// is permuted using `maybeMaskingMap`.
  Operation *
  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                std::optional<AffineMap> maybeMaskingMap = std::nullopt);

private:
  /// Initializes the iteration space static sizes using the Linalg op
  /// information. This may become more complicated in the future.
  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
    iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
  }

  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
  /// all the static and dynamic dimensions of the iteration space to be
  /// vectorized and stores them in `iterSpaceValueSizes`.
  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                              LinalgOp linalgOp);

  /// Create or retrieve an existing mask value to mask `opToMask` in the
  /// canonical vector iteration space. If `maybeMaskingMap` is provided, the
  /// mask is permuted using that permutation map. If a new mask is created,
  /// it will be cached for future users.
  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                           LinalgOp linalgOp,
                           std::optional<AffineMap> maybeMaskingMap);

  // Holds the compile-time static sizes of the iteration space to vectorize.
  // Dynamic dimensions are represented using ShapedType::kDynamic.
  SmallVector<int64_t> iterSpaceStaticSizes;

  /// Holds the value sizes of the iteration space to vectorize. Static
  /// dimensions are represented by 'arith.constant' and dynamic
  /// dimensions by 'tensor/memref.dim'.
  SmallVector<Value> iterSpaceValueSizes;

  /// Holds the canonical vector shape used to vectorize the iteration space.
  SmallVector<int64_t> canonicalVecShape;

  /// Holds the vector dimensions that are scalable in the canonical vector
  /// shape.
  SmallVector<bool> scalableVecDims;

  /// Holds the active masks for permutations of the canonical vector iteration
  /// space.
  DenseMap<AffineMap, Value> activeMaskCache;

  /// Global vectorization guard for the incoming rewriter. It's initialized
  /// when the vectorization state is initialized.
  OpBuilder::InsertionGuard rewriterGuard;
};

LogicalResult
VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                                  LinalgOp linalgOp) {
  // TODO: Support 0-d vectors.
  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
    if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
      // Create constant index op for static dimensions.
      iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
          linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
      continue;
    }

    // Find an operand defined on this dimension of the iteration space to
    // extract the runtime dimension size.
    Value operand;
    unsigned operandDimPos;
    if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
                                                         operandDimPos)))
      return failure();

    Value dynamicDim = linalgOp.hasPureTensorSemantics()
                           ? (Value)rewriter.create<tensor::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos)
                           : (Value)rewriter.create<memref::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos);
    iterSpaceValueSizes.push_back(dynamicDim);
  }

  return success();
}

/// Initializes the vectorization state, including the computation of the
/// canonical vector shape for vectorization.
// TODO: Move this to the constructor when we can remove the failure cases.
LogicalResult
VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes,
                              ArrayRef<bool> inputScalableVecDims) {
  // Initialize the insertion point.
  rewriter.setInsertionPoint(linalgOp);

  if (!inputVectorSizes.empty()) {
    // Get the canonical vector shape from the input vector sizes provided.
    // This path should be taken to vectorize code with dynamic shapes and
    // when using vector sizes greater than the iteration space sizes.
    canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    scalableVecDims.append(inputScalableVecDims.begin(),
                           inputScalableVecDims.end());
  } else {
    // Compute the canonical vector shape from the operation shape. If there
    // are dynamic shapes, the operation won't be vectorized. We assume all
    // the vector dimensions are fixed.
    canonicalVecShape = linalgOp.getStaticLoopRanges();
    scalableVecDims.append(linalgOp.getNumLoops(), false);
  }

  LDBG("Canonical vector shape: ");
  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  LDBG("Scalable vector dims: ");
  LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (ShapedType::isDynamicShape(canonicalVecShape))
    return failure();

  // Initialize iteration space static sizes.
  initIterSpaceStaticSizes(linalgOp);

  // Generate 'arith.constant' and 'tensor/memref.dim' operations for
  // all the static and dynamic dimensions of the iteration space, needed to
  // compute a mask during vectorization.
  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
    return failure();

  return success();
}

/// Create or retrieve an existing mask value to mask `opToMask` in the
/// canonical vector iteration space. If `maybeMaskingMap` is provided, the
/// mask is permuted using that permutation map. If a new mask is created, it
/// will be cached for future users.
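///
/// As a rough, illustrative example: for a canonical vector shape of [8, 16]
/// where the second iteration-space dimension is dynamic, the created mask is
/// roughly of the form
///   %mask = vector.create_mask %c8, %dim : vector<8x16xi1>
/// where %dim holds the runtime size of that dimension.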
Value VectorizationState::getOrCreateMaskFor(
    RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
    std::optional<AffineMap> maybeMaskingMap) {
  // No mask is needed if the operation is not maskable.
  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
  if (!maskableOp)
    return Value();

  assert(!maskableOp.isMasked() &&
         "Masking an operation that is already masked");

  // If no masking map was provided, use an identity map with the loop dims.
  assert((!maybeMaskingMap || *maybeMaskingMap) &&
         "Unexpected null mask permutation map");
  AffineMap maskingMap =
      maybeMaskingMap ? *maybeMaskingMap
                      : AffineMap::getMultiDimIdentityMap(
                            linalgOp.getNumLoops(), rewriter.getContext());

  LDBG("Masking map: " << maskingMap << "\n");

  // Return the active mask for the masking map of this operation if it was
  // already created.
  auto activeMaskIt = activeMaskCache.find(maskingMap);
  if (activeMaskIt != activeMaskCache.end()) {
    Value mask = activeMaskIt->second;
    LDBG("Reusing mask: " << mask << "\n");
    return mask;
  }

  // Compute permuted projection of the iteration space to be masked and the
  // corresponding mask shape. If the resulting iteration space dimensions are
  // static and identical to the mask shape, masking is not needed for this
  // operation.
  // TODO: Improve this check. Only projected permutation indexing maps are
  // supported.
  SmallVector<int64_t> permutedStaticSizes =
      applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
  auto maskShape = maskType.getShape();

  LDBG("Mask shape: ");
  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (permutedStaticSizes == maskShape) {
    LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
    activeMaskCache[maskingMap] = Value();
    return Value();
  }

  // Permute the iteration space value sizes to compute the mask upper bounds.
  SmallVector<Value> upperBounds =
      applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
  assert(!maskShape.empty() && !upperBounds.empty() &&
         "Masked 0-d vectors are not supported yet");

  // Create the mask based on the dimension values.
  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
                                                     maskType, upperBounds);
  LDBG("Creating new mask: " << mask << "\n");
  activeMaskCache[maskingMap] = mask;
  return mask;
}

/// Masks an operation with the canonical vector mask if the operation needs
/// masking. Returns the masked operation or the original operation if masking
/// is not needed. If provided, the canonical mask for this operation is
/// permuted using `maybeMaskingMap`.
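///
/// Schematically (illustrative only), a masked transfer read produced by this
/// helper looks like:
///   %0 = vector.mask %mask { vector.transfer_read ... } :
///          vector<8x16xi1> -> vector<8x16xf32>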
Operation *
VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
                                  LinalgOp linalgOp,
                                  std::optional<AffineMap> maybeMaskingMap) {
  LDBG("Trying to mask: " << *opToMask << "\n");

  // Create or retrieve mask for this operation.
  Value mask =
      getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);

  if (!mask) {
    LDBG("No mask required\n");
    return opToMask;
  }

  // Wrap the operation with a new `vector.mask` and update D-U chain.
  assert(opToMask && "Expected a valid operation to mask");
  auto maskOp = cast<vector::MaskOp>(
      mlir::vector::maskOperation(rewriter, opToMask, mask));
  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();

  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
    rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
                                  maskOpTerminator);

  LDBG("Masked operation: " << *maskOp << "\n");
  return maskOp;
}

/// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
/// projectedPermutation, compress the unused dimensions to serve as a
/// permutation_map for a vector transfer operation.
/// For example, given a linalg op such as:
///
/// ```
///   %0 = linalg.generic {
///          indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
///          indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
///        }
///        ins(%0 : tensor<2x3x4xf32>)
///        outs(%1 : tensor<5x6xf32>)
/// ```
///
/// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
/// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
/// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
static AffineMap reindexIndexingMap(AffineMap map) {
  assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
         "expected projected permutation");
  auto res = compressUnusedDims(map);
  assert(res.getNumDims() == res.getNumResults() &&
         "expected reindexed map with same number of dims and results");
  return res;
}

/// Helper enum to represent conv1d input traversal order.
enum class Conv1DOpOrder {
  W,   // Corresponds to non-channeled 1D convolution operation.
  Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
  Nwc  // Corresponds to operation that traverses the input in (n, w, c) order.
};

/// Helper data structure to represent the result of vectorization.
/// In certain specific cases, like terminators, we do not want to propagate.
enum VectorizationStatus {
  /// Op failed to vectorize.
  Failure = 0,
  /// Op vectorized and custom function took care of replacement logic.
  NoReplace,
  /// Op vectorized into a new Op whose results will replace original Op's
  /// results.
  NewOp
  // TODO: support values if Op vectorized to Many-Ops whose results we need to
  // aggregate for replacement.
};
struct VectorizationResult {
  /// Return status from vectorizing the current op.
  enum VectorizationStatus status = VectorizationStatus::Failure;
  /// New vectorized operation to replace the current op.
  /// Replacement behavior is specified by `status`.
  Operation *newOp;
};

std::optional<vector::CombiningKind>
mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
  using ::mlir::vector::CombiningKind;

  if (!combinerOp)
    return std::nullopt;
  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
      .Case<arith::AddIOp, arith::AddFOp>(
          [&](auto op) { return CombiningKind::ADD; })
      .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
      .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
      .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
      .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
      .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
      .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
      .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
      .Case<arith::MulIOp, arith::MulFOp>(
          [&](auto op) { return CombiningKind::MUL; })
      .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
      .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
      .Default([&](auto op) { return std::nullopt; });
}

/// Check whether `outputOperand` is a reduction with a single combiner
/// operation. Return the combiner operation of the reduction. Return
/// nullptr otherwise. Multiple reduction operations would impose an
/// ordering between reduction dimensions, which is currently unsupported in
/// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
/// max(min(X)).
// TODO: use in LinalgOp verification, there is a circular dependency atm.
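//
// For example (illustrative only), for an output operand whose region
// combines the values as:
//
//   ^bb0(%in: f32, %out: f32):
//     %sum = arith.addf %in, %out : f32
//     linalg.yield %sum : f32
//
// this returns the `arith.addf` operation.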
static Operation *matchLinalgReduction(OpOperand *outputOperand) {
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  unsigned outputPos =
      outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
  // Only single combiner operations are supported for now.
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
      combinerOps.size() != 1)
    return nullptr;

  // Return the combiner operation.
  return combinerOps[0];
}

/// Broadcast `value` to a vector of `shape` if possible. Return `value`
/// otherwise.
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
  auto dstVecType = dyn_cast<VectorType>(dstType);
  // If no shape to broadcast to, just return `value`.
  if (dstVecType.getRank() == 0)
    return value;
  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
      vector::BroadcastableToResult::Success)
    return value;
  Location loc = b.getInsertionPoint()->getLoc();
  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
}

/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
/// assumes that `reductionOp` has two operands and one of them is the
/// reduction initial value.
// Note: this is a true builder that notifies the OpBuilder listener.
// TODO: Consider moving as a static helper on the ReduceOp.
static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
                                      Value valueToReduce, Value acc,
                                      ArrayRef<bool> dimsToMask) {
  auto maybeKind = getCombinerOpKind(reduceOp);
  assert(maybeKind && "Failed precondition: could not get reduction kind");
  return b.create<vector::MultiDimReductionOp>(
      reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
}

static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
  return llvm::to_vector(
      llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
}

/// Build a vector.transfer_write of `value` into `outputOperand` at indices
/// set to all `0`; where `outputOperand` is an output operand of the LinalgOp
/// currently being vectorized. If `dest` has null rank, build a memref.store.
/// Return the produced value or null if no value is produced.
// Note: this is a true builder that notifies the OpBuilder listener.
// TODO: Consider moving as a static helper on the ReduceOp.
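//
// For example (illustrative only), writing a `vector<4x8xf32>` value into a
// `tensor<4x8xf32>` output operand produces roughly:
//
//   %w = vector.transfer_write %val, %init[%c0, %c0]
//          : vector<4x8xf32>, tensor<4x8xf32>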
static Value buildVectorWrite(RewriterBase &rewriter, Value value,
                              OpOperand *outputOperand,
                              VectorizationState &state) {
  Location loc = value.getLoc();
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);

  // Compute the vector type of the value to store. This type should be an
  // identity or projection of the canonical vector type without any
  // permutation applied, given that any permutation in a transfer write
  // happens as part of the write itself.
  AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
      opOperandMap.getContext(), opOperandMap.getNumInputs(),
      [&](AffineDimExpr dimExpr) -> bool {
        return llvm::is_contained(opOperandMap.getResults(), dimExpr);
      });
  auto vectorType = state.getCanonicalVecType(
      getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);

  Operation *write;
  if (vectorType.getRank() > 0) {
    AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
    SmallVector<Value> indices(linalgOp.getRank(outputOperand),
                               rewriter.create<arith::ConstantIndexOp>(loc, 0));
    value = broadcastIfNeeded(rewriter, value, vectorType);
    assert(value.getType() == vectorType && "Incorrect type");
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), indices, writeMap);
  } else {
    // 0-d case is still special: do not invert the reindexing writeMap.
    if (!isa<VectorType>(value.getType()))
      value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
    assert(value.getType() == vectorType && "Incorrect type");
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), ValueRange{});
  }

  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);

  // If masked, set in-bounds to true. Masking guarantees that the access will
  // be in-bounds.
  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
    auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
    SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
    maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
  }

  LDBG("vectorized op: " << *write << "\n");
  if (!write->getResults().empty())
    return write->getResult(0);
  return Value();
}

// Custom vectorization precondition function type. This is intended to be used
// with CustomVectorizationHook. Returns success if the corresponding custom
// hook can vectorize the op.
using CustomVectorizationPrecondition =
    std::function<LogicalResult(Operation *, bool)>;

// Custom vectorization function type. Produce a vector form of Operation*
// assuming all its vectorized operands are already in the IRMapping.
// Return nullptr if the Operation cannot be vectorized.
using CustomVectorizationHook =
    std::function<VectorizationResult(Operation *, const IRMapping &)>;

/// Helper function to vectorize the terminator of a `linalgOp`. New result
/// vector values are appended to `newResults`. Return
/// VectorizationStatus::NoReplace to signal the vectorization algorithm that
/// it should not try to map produced operations and instead return the results
/// using the `newResults` vector, making them available to the vectorization
/// algorithm for RAUW. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationResult
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
                     const IRMapping &bvm, VectorizationState &state,
                     LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  if (!yieldOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
    // TODO: Scan for an opportunity for reuse.
    // TODO: use a map.
    Value vectorValue = bvm.lookup(output.value());
    Value newResult =
        buildVectorWrite(rewriter, vectorValue,
                         linalgOp.getDpsInitOperand(output.index()), state);
    if (newResult)
      newResults.push_back(newResult);
  }

  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
}

/// Helper function to vectorize the index operations of a `linalgOp`. Return
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
/// CustomVectorizationHook.
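///
/// For example (illustrative only), for `linalg.index 0` in a 2-D iteration
/// space with canonical vector shape [4, 8], a constant `vector<4xindex>`
/// holding [0, 1, 2, 3] is created, broadcast to `vector<8x4xindex>` and then
/// transposed back to `vector<4x8xindex>` so that the values vary along dim 0.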
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
                                                VectorizationState &state,
                                                Operation *op,
                                                LinalgOp linalgOp) {
  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
  if (!indexOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  auto loc = indexOp.getLoc();
  // Compute the static loop sizes of the index op.
  auto targetShape = state.getCanonicalVecShape();
  // Compute a one-dimensional index vector for the index op dimension.
  auto constantSeq =
      llvm::to_vector(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()]));
  auto indexSteps = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getIndexVectorAttr(constantSeq));
  // Return the one-dimensional index vector if it lives in the trailing
  // dimension of the iteration space since the vectorization algorithm in this
  // case can handle the broadcast.
  if (indexOp.getDim() == targetShape.size() - 1)
    return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
  // Otherwise permute the targetShape to move the index dimension last,
  // broadcast the one-dimensional index vector to the permuted shape, and
  // finally transpose the broadcasted index vector to undo the permutation.
  auto permPattern =
      llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
  std::swap(permPattern[indexOp.getDim()], permPattern.back());
  auto permMap =
      AffineMap::getPermutationMap(permPattern, linalgOp.getContext());

  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
      loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
      indexSteps);
  SmallVector<int64_t> transposition =
      llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
  std::swap(transposition.back(), transposition[indexOp.getDim()]);
  auto transposeOp =
      rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
}

/// Helper function to check if the tensor.extract can be vectorized by the
/// custom hook vectorizeTensorExtract.
static LogicalResult
tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return failure();

  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
    return failure();

  // Check the index type, but only for non 0-d tensors (for which we do need
  // access indices).
  if (not extractOp.getIndices().empty()) {
    if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
      return failure();
  }

  if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
        return !VectorType::isValidElementType(type);
      })) {
    return failure();
  }

  return success();
}

/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
/// generated from `tensor.extract`. The offset is calculated as follows
/// (example using scalar values):
///
///   offset = extractOp.indices[0]
///   for (i = 1; i < numIndices; i++)
///     offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
///
/// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
///   offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
static Value calculateGatherOffset(RewriterBase &rewriter,
                                   VectorizationState &state,
                                   tensor::ExtractOp extractOp,
                                   const IRMapping &bvm) {
  // The vector of indices for GatherOp should be shaped as the output vector.
  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
  auto loc = extractOp.getLoc();

  Value offset = broadcastIfNeeded(
      rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);

  const size_t numIndices = extractOp.getIndices().size();
  for (size_t i = 1; i < numIndices; i++) {
    Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);

    auto dimSize = broadcastIfNeeded(
        rewriter,
        rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
        indexVecType);

    offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);

    auto extractOpIndex = broadcastIfNeeded(
        rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);

    offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
  }

  return offset;
}

enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };

/// Checks whether \p val can be used for calculating a loop invariant index.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {

  auto targetShape = linalgOp.getStaticLoopRanges();
  assert(((llvm::count_if(targetShape,
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");
  assert(targetShape.back() != 1 &&
         "1-D vectors with the trailing dim equal 1 are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  // However, analysing block arguments for _this_ linalg.generic Op is a bit
  // tricky. Just bail out in the latter case.
  // TODO: We could try analysing the corresponding affine map here.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return llvm::all_of(block->getArguments(),
                        [&val](Value v) { return (v != val); });

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  // IndexOp is loop invariant as long as its result remains constant across
  // iterations. Given the assumptions on the loop ranges above, only the
  // trailing loop dim ever changes.
  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
    return (indexOp.getDim() != trailingLoopDim);

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  // Values defined outside `linalgOp` are loop invariant.
  if (!ancestor)
    return true;

  // Values defined inside `linalgOp`, which are constant, are loop invariant.
  if (isa<arith::ConstantOp>(ancestor))
    return true;

  bool result = true;
  for (auto op : ancestor->getOperands())
    result &= isLoopInvariantIdx(linalgOp, op);

  return result;
}

/// Check whether \p val could be used for calculating the trailing index for a
/// contiguous load operation.
///
/// There are currently 3 types of values that are allowed here:
///   1. loop-invariant values,
///   2. values that increment by 1 with every loop iteration,
///   3. results of basic arithmetic operations (linear and continuous)
///      involving 1., 2. and 3.
/// This method returns True if indeed only such values are used in calculating
/// \p val.
///
/// Additionally, the trailing index for a contiguous load operation should
/// increment by 1 with every loop iteration, i.e. be based on:
///   * `linalg.index <dim>`,
/// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
/// updated to `true` when such an op is found.
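///
/// For example (illustrative only), with `%i = linalg.index 1 : index` being
/// the trailing dim, both `%i` and `arith.addi %i, %c4 : index` qualify,
/// whereas an index computed with e.g. `arith.muli` is rejected because it may
/// introduce a stride other than 1.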
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
                                bool &foundIndexOp) {

  auto targetShape = linalgOp.getStaticLoopRanges();
  assert(((llvm::count_if(targetShape,
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");
  assert(targetShape.back() != 1 &&
         "1-D vectors with the trailing dim 1 are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  // However, analysing block arguments for _this_ linalg.generic Op is a bit
  // tricky. Just bail out in the latter case.
  // TODO: We could try analysing the corresponding affine map here.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return llvm::all_of(block->getArguments(),
                        [&val](Value v) { return (v != val); });

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  // Given the assumption on the loop ranges above, only the trailing loop
  // index is not constant.
  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    foundIndexOp = (indexOp.getDim() == trailingLoopDim);
    return true;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  if (!ancestor)
    return false;

  // Conservatively reject Ops that could lead to indices with stride other
  // than 1.
  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
    return false;

  bool result = false;
  for (auto op : ancestor->getOperands())
    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);

  return result;
}

/// Infer the memory access pattern for the input ExtractOp.
///
/// Based on the operation shapes and indices (usually based on the iteration
/// space of the parent `linalgOp` operation), decides whether the input
/// ExtractOp is a contiguous load (including a broadcast of a scalar) or a
/// gather load.
///
/// Note that it is always safe to use gather load operations for contiguous
/// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
/// that `extractOp` is a gather load.
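///
/// For example (illustrative only), inside a `linalg.generic` whose trailing
/// iteration dimension produces `%i = linalg.index 1 : index`,
///
///   %e = tensor.extract %src[%c0, %i] : tensor<3x8xf32>
///
/// is classified as a contiguous load, whereas an extract whose trailing index
/// is itself read from another tensor is classified as a gather load.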
static VectorMemoryAccessKind
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
                                    LinalgOp &linalgOp) {

  auto targetShape = linalgOp.getStaticLoopRanges();
  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());

  // 0.1 Is this a 0-D vector? If yes then this is a scalar broadcast.
  if (inputShape.getShape().empty())
    return VectorMemoryAccessKind::ScalarBroadcast;

  // 0.2 In the case of dynamic shapes just bail-out and assume that it's a
  // gather load.
  // TODO: Relax this condition.
  if (linalgOp.hasDynamicShape())
    return VectorMemoryAccessKind::Gather;

  // 1. Assume that it's a gather load when reading _into_:
  //    * an n-D "vector", like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or
  //    * a 1-D "vector" with the trailing dim equal 1, e.g.
  //      `tensor<1x4x1xi32>`.
  // TODO: Relax these conditions.
  // FIXME: This condition assumes non-dynamic sizes.
  if ((llvm::count_if(targetShape,
                      [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
      targetShape.back() == 1)
    return VectorMemoryAccessKind::Gather;

  // 2. Assume that it's a gather load when reading _from_ a tensor for which
  // the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
  // TODO: Relax this condition.
  if (inputShape.getShape().back() == 1)
    return VectorMemoryAccessKind::Gather;

  bool leadingIdxsLoopInvariant = true;

  // 3. Analyze the leading indices of `extractOp`.
  // Look at the way each index is calculated and decide whether it is suitable
  // for a contiguous load, i.e. whether it's loop invariant.
  auto indices = extractOp.getIndices();
  auto leadIndices = indices.drop_back(1);

  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
    if (inputShape.getShape()[i] == 1)
      continue;

    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal);
  }

  if (!leadingIdxsLoopInvariant) {
    LDBG("Found gather load: " << extractOp);
    return VectorMemoryAccessKind::Gather;
  }

  // 4. Analyze the trailing index for `extractOp`.
  // At this point we know that the leading indices are loop invariant. This
  // means that this is potentially a scalar or a contiguous load. We can
  // decide based on the trailing idx.
  auto extractOpTrailingIdx = indices.back();

  // 4a. Scalar broadcast load
  // If the trailing index is loop invariant then this is a scalar load.
  if (leadingIdxsLoopInvariant &&
      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
    LDBG("Found scalar broadcast load: " << extractOp);

    return VectorMemoryAccessKind::ScalarBroadcast;
  }

  // 4b. Contiguous loads
  // The trailing `extractOp` index should increment with every loop iteration.
  // This effectively means that it must be based on the trailing loop index.
  // This is what the following bool captures.
  bool foundIndexOp = false;
  bool isContiguousLoad =
      isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
  isContiguousLoad &= foundIndexOp;

  if (isContiguousLoad) {
    LDBG("Found contiguous load: " << extractOp);
    return VectorMemoryAccessKind::Contiguous;
  }

  // 5. Fallback case - gather load.
  LDBG("Found gather load: " << extractOp);
  return VectorMemoryAccessKind::Gather;
}

/// Helper function to vectorize the tensor.extract operations. Returns
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationResult
vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                       Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  auto loc = extractOp.getLoc();

  // Compute the static loop sizes of the extract op.
  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
                                /*value=*/true));
  auto passThruConstantOp =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));

  // Base indices are currently set to 0. We will need to re-visit if more
  // generic scenarios are to be supported.
  SmallVector<Value> baseIndices(
      extractOp.getIndices().size(),
      rewriter.create<arith::ConstantIndexOp>(loc, 0));

  VectorMemoryAccessKind memAccessKind =
      getTensorExtractMemoryAccessPattern(extractOp, linalgOp);

  // 1. Handle gather access
  if (memAccessKind == VectorMemoryAccessKind::Gather) {
    Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);

    // Generate the gather load
    Operation *gatherOp = rewriter.create<vector::GatherOp>(
        loc, resultType, extractOp.getTensor(), baseIndices, offset,
        maskConstantOp, passThruConstantOp);
    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);

    LDBG("Vectorised as gather load: " << extractOp << "\n");
    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
  }

  // 2. Handle:
  //  a. scalar loads + broadcast,
  //  b. contiguous loads.
  // Both cases use vector.transfer_read.

  // Collect indices for `vector.transfer_read`. At this point, the indices
  // will either be scalars or would have been broadcast to vectors matching
  // the result type. For indices that are vectors, there are two options:
  //    * for non-trailing indices, all elements are identical (contiguous
  //      loads are identified by looking for non-trailing indices that are
  //      invariant with respect to the corresponding linalg.generic), or
  //    * for trailing indices, the index vector will contain values with
  //      stride one, but for `vector.transfer_read` only the first (i.e. 0th)
  //      index is needed.
  // This means that
  //   * for scalar indices - just re-use it,
  //   * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
  //     (0th) element and use that.
  SmallVector<Value> transferReadIdxs;
  auto resTrailingDim = resultType.getShape().back();
  auto zero = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
    auto idx = bvm.lookup(extractOp.getIndices()[i]);
    if (idx.getType().isIndex()) {
      transferReadIdxs.push_back(idx);
      continue;
    }

    auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
        loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
        bvm.lookup(extractOp.getIndices()[i]));
    transferReadIdxs.push_back(
        rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
  }

  // `tensor.extract_element` is always in-bounds, hence the following holds.
  auto dstRank = resultType.getRank();
  auto srcRank = extractOp.getTensor().getType().getRank();
  SmallVector<bool> inBounds(dstRank, true);

  // 2a. Handle scalar broadcast access.
  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
    MLIRContext *ctx = rewriter.getContext();
    SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
    auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);

    auto transferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, resultType, extractOp.getTensor(), transferReadIdxs,
        permutationMap, inBounds);

    LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
    return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
  }

  // 2b. Handle contiguous access.
  auto permutationMap = AffineMap::getMinorIdentityMap(
      srcRank, std::min(dstRank, srcRank), rewriter.getContext());

  int32_t rankDiff = dstRank - srcRank;
  // When dstRank > srcRank, broadcast the source tensor to the unitary leading
  // dims so that the ranks match. This is done by extending the map with 0s.
  // For example, for dstRank = 3, srcRank = 2, the following map created
  // above:
  //    (d0, d1) --> (d0, d1)
  // is extended as:
  //    (d0, d1) --> (0, d0, d1)
  while (rankDiff > 0) {
    permutationMap = permutationMap.insertResult(
        mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
    rankDiff--;
  }

  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
      loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
      inBounds);

  LDBG("Vectorised as contiguous load: " << extractOp);
  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
}

/// Emit reduction operations if the shape of the value to reduce is different
/// from the result shape.
// Note: this is a true builder that notifies the OpBuilder listener.
// TODO: Consider moving as a static helper on the ReduceOp.
static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
                                 Value reduceValue, Value initialValue,
                                 const IRMapping &bvm) {
  Value reduceVec = bvm.lookup(reduceValue);
  Value outputVec = bvm.lookup(initialValue);
  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
  auto outputType = dyn_cast<VectorType>(outputVec.getType());
  // Reduce only if needed as the value may already have been reduced for
  // contraction vectorization.
  if (!reduceType ||
      (outputType && reduceType.getShape() == outputType.getShape()))
    return nullptr;
  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
}

/// Generic vectorization for a single operation `op`, given already vectorized
/// operands carried by `bvm`. Vectorization occurs as follows:
///   1. Try to apply any of the `customVectorizationHooks` and return its
///      result on success.
///   2. Clone any constant in the current scope without vectorization: each
///      consumer of the constant will later determine the shape to which the
///      constant needs to be broadcast.
///   3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
///      of the `customVectorizationHooks` to cover such cases.
///   4. Clone `op` in vector form to a vector of shape prescribed by the first
///      operand of maximal rank. Other operands have smaller rank and are
///      broadcast accordingly. It is assumed this broadcast is always legal,
///      otherwise, it means one of the `customVectorizationHooks` is
///      incorrect.
///
/// This function assumes all operands of `op` have been vectorized and are in
/// the `bvm` mapping. As a consequence, this function is meant to be called on
/// a topologically-sorted list of ops.
/// This function does not update `bvm` but returns a VectorizationStatus that
/// instructs the caller what `bvm` update needs to occur.
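///
/// For example (illustrative only), with a canonical vector shape of [4, 8],
/// a scalar `arith.addf %a, %b : f32` in the body is rewritten to an
/// `arith.addf` on `vector<4x8xf32>` operands (broadcasting lower-ranked
/// operands as needed).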
static VectorizationResult
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
               LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
               ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
  LDBG("vectorize op " << *op << "\n");

  // 1. Try to apply any CustomVectorizationHook.
  if (!customVectorizationHooks.empty()) {
    for (auto &customFunc : customVectorizationHooks) {
      VectorizationResult result = customFunc(op, bvm);
      if (result.status == VectorizationStatus::Failure)
        continue;
      return result;
    }
  }

  // 2. Constant ops don't get vectorized but rather broadcasted at their
  // users. Clone so that the constant is not confined to the linalgOp block.
  if (isa<arith::ConstantOp, func::ConstantOp>(op))
    return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};

  // 3. Only ElementwiseMappable are allowed in the generic vectorization.
  if (!OpTrait::hasElementwiseMappableTraits(op))
    return VectorizationResult{VectorizationStatus::Failure, nullptr};

  // 4. Check if the operation is a reduction.
  SmallVector<std::pair<Value, Value>> reductionOperands;
  for (Value operand : op->getOperands()) {
    auto blockArg = dyn_cast<BlockArgument>(operand);
    if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
        blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
      continue;
    SmallVector<Operation *> reductionOps;
    Value reduceValue = matchReduction(
        linalgOp.getRegionOutputArgs(),
        blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
    if (!reduceValue)
      continue;
    reductionOperands.push_back(std::make_pair(reduceValue, operand));
  }
  if (!reductionOperands.empty()) {
    assert(reductionOperands.size() == 1);
    Operation *reduceOp =
        reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
                       reductionOperands[0].second, bvm);
    if (reduceOp)
      return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
  }

  // 5. Generic vectorization path for ElementwiseMappable ops.
  //   a. Get the first max ranked shape.
  VectorType firstMaxRankedType;
  for (Value operand : op->getOperands()) {
    auto vecOperand = bvm.lookup(operand);
    assert(vecOperand && "Vector operand couldn't be found");

    auto vecType = dyn_cast<VectorType>(vecOperand.getType());
    if (vecType && (!firstMaxRankedType ||
                    firstMaxRankedType.getRank() < vecType.getRank()))
      firstMaxRankedType = vecType;
  }
  //   b. Broadcast each op if needed.
  SmallVector<Value> vecOperands;
  for (Value scalarOperand : op->getOperands()) {
    Value vecOperand = bvm.lookup(scalarOperand);
    assert(vecOperand && "Vector operand couldn't be found");

    if (firstMaxRankedType) {
      auto vecType = VectorType::get(firstMaxRankedType.getShape(),
                                     getElementTypeOrSelf(vecOperand.getType()),
                                     firstMaxRankedType.getScalableDims());
      vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
    } else {
      vecOperands.push_back(vecOperand);
    }
  }
  //   c. For elementwise, the result is the vector with the
  //      firstMaxRankedShape.
  SmallVector<Type> resultTypes;
  for (Type resultType : op->getResultTypes()) {
    resultTypes.push_back(
        firstMaxRankedType
            ? VectorType::get(firstMaxRankedType.getShape(), resultType,
                              firstMaxRankedType.getScalableDims())
            : resultType);
  }
  //   d. Build and return the new op.
  return VectorizationResult{
      VectorizationStatus::NewOp,
      rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
                      resultTypes, op->getAttrs())};
}

/// Generic vectorization function that rewrites the body of a `linalgOp` into
/// vector form. Generic vectorization proceeds as follows:
///   1. Verify the `linalgOp` has one non-empty region.
///   2. Values defined above the region are mapped to themselves and will be
///      broadcasted on a per-need basis by their consumers.
///   3. Each region argument is vectorized into a vector.transfer_read (or 0-d
///      load).
///      TODO: Reuse opportunities for RAR dependencies.
///   4a. Register CustomVectorizationHook for YieldOp to capture the results.
///   4b. Register CustomVectorizationHook for IndexOp to access the
///       iteration indices.
///   5. Iteratively call vectorizeOneOp on the region operations.
///
/// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
/// performed to the maximal common vector size implied by the `linalgOp`
/// iteration space. This eager broadcasting is introduced in the
/// permutation_map of the vector.transfer_read operations. The eager
/// broadcasting makes it trivial to determine where broadcasts, transposes and
/// reductions should occur, without any bookkeeping. The tradeoff is that, in
/// the absence of good canonicalizations, the amount of work increases.
/// This is not deemed a problem as we expect canonicalizations and foldings to
/// aggressively clean up the useless work.
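///
/// For example (illustrative only), a tensor input bbarg with an identity
/// indexing map and canonical vector shape [4, 8] is turned into roughly:
///
///   %r = vector.transfer_read %arg[%c0, %c0], %pad
///          : tensor<4x8xf32>, vector<4x8xf32>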
1282static LogicalResult
1283vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1284 LinalgOp linalgOp,
1285 SmallVectorImpl<Value> &newResults) {
1286 LDBG("Vectorizing operation as linalg generic\n");
1287 Block *block = linalgOp.getBlock();
1288
1289 // 2. Values defined above the region can only be broadcast for now. Make them
1290 // map to themselves.
1291 IRMapping bvm;
1292 SetVector<Value> valuesSet;
1293 mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1294 bvm.map(from: valuesSet.getArrayRef(), to: valuesSet.getArrayRef());
1295
1296 if (linalgOp.getNumDpsInits() == 0)
1297 return failure();
1298
1299 // 3. Turn all BBArgs into vector.transfer_read / load.
1300 Location loc = linalgOp.getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1302 for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1303 BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1304 if (linalgOp.isScalar(opOperand)) {
1305 bvm.map(bbarg, opOperand->get());
1306 continue;
1307 }
1308
1309 // 3.a. Convert the indexing map for this input/output to a transfer read
1310 // permutation map and masking map.
1311 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1312
1313 // Remove zeros from indexing map to use it as masking map.
1314 SmallVector<int64_t> zeroPos;
1315 auto results = indexingMap.getResults();
1316 for (const auto &result : llvm::enumerate(results)) {
1317 if (isa<AffineConstantExpr>(result.value())) {
1318 zeroPos.push_back(result.index());
1319 }
1320 }
1321 AffineMap maskingMap = indexingMap.dropResults(zeroPos);
1322
1323 AffineMap readMap;
1324 VectorType readType;
1325 Type elemType = getElementTypeOrSelf(opOperand->get());
1326 if (linalgOp.isDpsInput(opOperand)) {
1327 // 3.a.i. For input reads we use the canonical vector shape.
1328 readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1329 readType = state.getCanonicalVecType(elemType);
1330 } else {
1331 // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1332 // reductions), the vector shape is computed by mapping the canonical
1333 // vector shape to the output domain and back to the canonical domain.
1334 readMap = inversePermutation(reindexIndexingMap(indexingMap));
1335 readType =
1336 state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1337 }
1338
1339 SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1340
1341 Operation *read = rewriter.create<vector::TransferReadOp>(
1342 loc, readType, opOperand->get(), indices, readMap);
1343 read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
1344 Value readValue = read->getResult(0);
1345
1346 // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1347 // will be in-bounds.
1348 if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1349 SmallVector<bool> inBounds(readType.getRank(), true);
1350 cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1351 .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1352 }
1353
1354 // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1355 // TODO: remove this.
1356 if (readType.getRank() == 0)
1357 readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
1358
1359 LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
1360 << "\n");
1361 bvm.map(bbarg, readValue);
1362 bvm.map(opOperand->get(), readValue);
1363 }
1364
1365 SmallVector<CustomVectorizationHook> hooks;
1366 // 4a. Register CustomVectorizationHook for yieldOp.
1367 CustomVectorizationHook vectorizeYield =
1368 [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1369 return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1370 };
  hooks.push_back(vectorizeYield);
1372
1373 // 4b. Register CustomVectorizationHook for indexOp.
1374 CustomVectorizationHook vectorizeIndex =
1375 [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1376 return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1377 };
  hooks.push_back(vectorizeIndex);
1379
1380 // 4c. Register CustomVectorizationHook for extractOp.
1381 CustomVectorizationHook vectorizeExtract =
1382 [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1383 return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1384 };
  hooks.push_back(vectorizeExtract);
1386
  // 5. Iteratively call `vectorizeOneOp` on each op in the slice.
1388 for (Operation &op : block->getOperations()) {
1389 VectorizationResult result =
1390 vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1391 if (result.status == VectorizationStatus::Failure) {
1392 LDBG("failed to vectorize: " << op << "\n");
1393 return failure();
1394 }
1395 if (result.status == VectorizationStatus::NewOp) {
1396 Operation *maybeMaskedOp =
1397 state.maskOperation(rewriter, result.newOp, linalgOp);
1398 LDBG("New vector op: " << *maybeMaskedOp << "\n");
1399 bvm.map(op.getResults(), maybeMaskedOp->getResults());
1400 }
1401 }
1402
1403 return success();
1404}
1405
1406/// Given a tensor::PackOp, return the `dest` shape before any packing
1407/// permutations.
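/// For example (cf. the tensor.pack example further below, with
/// inner_dims_pos = [2, 1] and inner_tiles = [16, 2]): a `destShape` of
/// [32, 4, 1, 16, 2] maps back to [32, 4, 2, 1, 16], i.e. the shape that the
/// shape_cast produces before the final transpose.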
1408static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
1409 ArrayRef<int64_t> destShape) {
1410 return applyPermutation(destShape, tensor::getPackInverseDestPerm(packOp));
1411}
1412
/// Given an input, the mixed destSizes, and the vector sizes for
/// vectorization, create an empty destination tensor and a TransferWriteOp
/// that writes the input into it. If the destination shape is not the same as
/// the inputVectorSizes for the first rank(inputVectorSizes) dims, then create
/// a mask for the write.
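///
/// A rough sketch of the masked case (illustrative, hypothetical values): for
/// an `input` of type vector<16x16xf32>, destSizes = (%d0, 16) and
/// inputVectorSizes = [16], the generated IR is approximately:
///   %empty = tensor.empty(%d0) : tensor<?x16xf32>
///   %mask = vector.create_mask %d0, %c16 : vector<16x16xi1>
///   vector.mask %mask {
///     vector.transfer_write %input, %empty[%c0, %c0]
///       {in_bounds = [true, true]} : vector<16x16xf32>, tensor<?x16xf32>
///   } : vector<16x16xi1> -> tensor<?x16xf32>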
1418static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
1419 Value input,
1420 SmallVector<OpFoldResult> destSizes,
1421 ArrayRef<int64_t> inputVectorSizes) {
1422 auto inputType = cast<VectorType>(input.getType());
1423 Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
1424 inputType.getElementType());
1425 int64_t rank = cast<ShapedType>(dest.getType()).getRank();
  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1427 Operation *write = builder.create<vector::TransferWriteOp>(
1428 loc,
1429 /*vector=*/input,
1430 /*source=*/dest,
1431 /*indices=*/SmallVector<Value>(rank, zero),
1432 /*inBounds=*/SmallVector<bool>(rank, true));
1433 auto destShape = cast<ShapedType>(dest.getType()).getShape();
1434 assert(llvm::none_of(
1435 destShape.drop_front(inputVectorSizes.size()),
1436 [](int64_t size) { return size == ShapedType::kDynamic; }) &&
1437 "Only dims aligned with inputVectorSizes may be dynamic");
1438 bool needMaskForWrite = !llvm::equal(
1439 inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
1440 if (needMaskForWrite) {
1441 SmallVector<int64_t> writeMaskShape;
    writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
1443 writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
1444 destShape.end());
1445 auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
1446 Value maskForWrite =
1447 builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
    write = mlir::vector::maskOperation(builder, write, maskForWrite);
1449 }
1450 return write;
1451}
1452
1453/// Vectorize tensor::PackOp with (1) static innerTiles (2) constant
1454/// padding value and (3) input vector sizes into:
1455/// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
1456/// As in the following example:
1457/// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
///     into %dst : tensor<32x7x16xf32> -> tensor<32x4x1x16x2xf32>
1459///
1460/// This pack would be vectorized to:
1461///
1462/// %load = vector.mask %mask {
1463/// vector.transfer_read %arg0[%c0, %c0, %c0], %cst
1464/// {in_bounds = [true, true, true]} :
1465/// tensor<32x7x16xf32>, vector<32x8x16xf32>
1466/// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
1467/// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
1468/// to vector<32x4x2x1x16xf32>
1469/// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
1470/// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1471/// %write = vector.transfer_write %transpose,
1472/// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
1473/// {in_bounds = [true, true, true, true, true]}
1474/// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1475///
1476/// If the (3) input vector sizes are not provided, the vector sizes are
1477/// determined by the result tensor shape. Also, we update the inBounds
1478/// attribute instead of masking.
1479static LogicalResult
1480vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
1481 ArrayRef<int64_t> inputVectorSizes,
1482 SmallVectorImpl<Value> &newResults) {
1483 OpBuilder::InsertionGuard g(rewriter);
1484 rewriter.setInsertionPoint(packOp);
1485
1486 Location loc = packOp.getLoc();
1487 auto padValue = packOp.getPaddingValue();
1488 if (!padValue) {
1489 padValue = rewriter.create<arith::ConstantOp>(
1490 loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
1491 }
1492 ReifiedRankedShapedTypeDims reifiedReturnShapes;
1493 LogicalResult status =
1494 cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1495 .reifyResultShapes(rewriter, reifiedReturnShapes);
1496 (void)status; // prevent unused variable warning on non-assert builds.
1497 assert(succeeded(status) && "failed to reify result shapes");
1498
1499 // If the input vector sizes are not provided, then the vector sizes are
1500 // determined by the result tensor shape. In case the vector sizes aren't
1501 // provided, we update the inBounds attribute instead of masking.
1502 bool useInBoundsInsteadOfMasking = true;
1503 if (inputVectorSizes.empty()) {
1504 ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1506 useInBoundsInsteadOfMasking = false;
1507 }
1508
1509 // Create masked TransferReadOp.
1510 SmallVector<int64_t> inputShape(inputVectorSizes);
1511 auto innerTiles = packOp.getStaticInnerTiles();
1512 auto innerDimsPos = packOp.getInnerDimsPos();
1513 auto outerDimsPerm = packOp.getOuterDimsPerm();
1514 if (!outerDimsPerm.empty())
1515 applyPermutationToVector(inputShape,
1516 invertPermutationVector(outerDimsPerm));
1517 for (auto [idx, size] : enumerate(innerTiles))
1518 inputShape[innerDimsPos[idx]] *= size;
1519 auto maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, packOp.getSource(), inputShape, padValue,
1521 useInBoundsInsteadOfMasking);
1522
1523 // Create ShapeCastOp.
1524 SmallVector<int64_t> destShape(inputVectorSizes);
1525 destShape.append(innerTiles.begin(), innerTiles.end());
1526 auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
1527 packOp.getDestType().getElementType());
1528 auto shapeCastOp =
1529 rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
1530
1531 // Create TransposeOp.
1532 auto destPermutation =
1533 invertPermutationVector(tensor::getPackInverseDestPerm(packOp));
1534 auto transposeOp = rewriter.create<vector::TransposeOp>(
1535 loc, shapeCastOp.getResult(), destPermutation);
1536
1537 // Create TransferWriteOp.
1538 Operation *write =
1539 createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(),
1540 reifiedReturnShapes[0], inputVectorSizes);
  newResults.push_back(write->getResult(0));
1542 return success();
1543}
1544
/// Vectorize a `tensor::UnPackOp` to these 4 Ops:
///   vector::TransferReadOp - Reads a vector from the source tensor
///   vector::TransposeOp - Transposes the read vector
///   vector::ShapeCastOp - Reshapes the data based on the target shape
///   vector::TransferWriteOp - Writes the result vector back to the
///                             destination tensor
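///
/// The flow mirrors the tensor.pack vectorization above in reverse: the
/// (possibly masked) read pulls the tiled layout into a vector, the transpose
/// and shape_cast recover the unpacked layout, and the write stores into the
/// reified destination shape.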
1551static LogicalResult
1552vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
1553 ArrayRef<int64_t> inputVectorSizes,
1554 SmallVectorImpl<Value> &newResults) {
1555
1556 OpBuilder::InsertionGuard g(rewriter);
1557 rewriter.setInsertionPoint(unpackOp);
1558
1559 RankedTensorType unpackTensorType = unpackOp.getSourceType();
1560
1561 ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
1562 ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
1563
1564 SmallVector<int64_t> readMaskShape(inputVectorSizes.begin(),
1565 inputVectorSizes.end());
1566 ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1567 ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1568
  // readMaskShape is the shape of the vector used for the (possibly masked)
  // read. Let the input vector sizes (VS) array have size 'N', the
  // sourceShape (SS) have size 'M' (with M >= N), and the InnerTileSizes (IT)
  // have size M-N. Then:
  // - Initially: readMaskShape = inputVectorSizes.
  // - Divide every readMaskShape entry pointed to by innerDimPos by the
  //   corresponding innerTileSize attribute value.
  // - If outer_dims_perm is present, apply that permutation to readMaskShape.
  // - Append the remaining shape from SS.
  // E.g., let unpackTensorType.getShape() = <8x8x32x16>, inner_dims_pos =
  // [0, 1], inner_tiles = [32, 16], vector_sizes = [512, 128] and
  // outer_dims_perm = [1, 0]. Then the read shape is:
  //   readMaskShape (initial): [512, 128]
  //   After the inner-tile adjustment: [512/32, 128/16] = [16, 8]
  //   After applying outer_dims_perm: [8, 16]
  //   After appending the rest of sourceShape: [8, 16, 32, 16]
1587
1588 for (auto [index, size] : enumerate(innerTiles)) {
1589 readMaskShape[innerDimPos[index]] =
1590 llvm::divideCeil(readMaskShape[innerDimPos[index]], size);
1591 }
1592 if (!outerDimsPerm.empty()) {
    applyPermutationToVector(readMaskShape, outerDimsPerm);
1594 }
  readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(),
                       sourceShape.end());
1597
1598 ReifiedRankedShapedTypeDims reifiedRetShapes;
1599 LogicalResult status =
1600 cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
1601 .reifyResultShapes(rewriter, reifiedRetShapes);
1602 if (status.failed()) {
1603 LDBG("Unable to reify result shapes of " << unpackOp);
1604 return failure();
1605 }
1606 Location loc = unpackOp->getLoc();
1607
1608 auto padValue = rewriter.create<arith::ConstantOp>(
1609 loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1610
  // Read the source into `readResult`, masking if necessary: a mask is needed
  // whenever the transfer_read shape does not match the source shape.
1613 Value readResult = vector::createReadOrMaskedRead(
      rewriter, loc, unpackOp.getSource(),
      ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue);
1616
1617 PackingMetadata packMetadata;
1618 SmallVector<int64_t> lastDimToInsertPosPerm =
1619 tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata);
1620 ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
1621 SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
1622 mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
1624 RankedTensorType stripMineTensorType =
1625 RankedTensorType::get(stripMineShape, stripMineElemType);
1626 // Transpose the appropriate rows to match output.
1627 vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
1628 loc, readResult, lastDimToInsertPosPerm);
1629
1630 // Collapse the vector to the size required by result.
1631 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1632 stripMineTensorType, packMetadata.reassociations);
1633 mlir::VectorType vecCollapsedType =
1634 VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1635 vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
1636 loc, vecCollapsedType, transposeOp->getResult(0));
1637
  // For dynamic sizes, the write-mask shape has to match the shape_cast result
  // shape; otherwise the verifier complains that the mask size is invalid.
1640 SmallVector<int64_t> writeMaskShape(
1641 unpackOp.getDestType().hasStaticShape()
1642 ? inputVectorSizes
1643 : shapeCastOp.getResultVectorType().getShape());
1644 Operation *write =
1645 createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(),
1646 reifiedRetShapes[0], writeMaskShape);
  newResults.push_back(write->getResult(0));
1648 return success();
1649}
1650
1651/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
1652/// and (3) all-zero lowPad to
1653/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
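/// A sketch with illustrative shapes (all SSA names are hypothetical): for
///   %pad = tensor.pad %src low[0, 0] high[...] { tensor.yield %cst }
///       : tensor<?x5xf32> to tensor<8x5xf32>
/// and inputVectorSizes = [8, 5], this produces roughly:
///   %read = vector.mask %m {
///     vector.transfer_read %src[%c0, %c0], %cst
///       : tensor<?x5xf32>, vector<8x5xf32>
///   } : vector<8x5xi1> -> vector<8x5xf32>
///   vector.transfer_write %read, %empty[%c0, %c0]
///       {in_bounds = [true, true]} : vector<8x5xf32>, tensor<8x5xf32>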
1654static LogicalResult
1655vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1656 ArrayRef<int64_t> inputVectorSizes,
1657 SmallVectorImpl<Value> &newResults) {
1658 auto padValue = padOp.getConstantPaddingValue();
1659 Location loc = padOp.getLoc();
1660
1661 // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
1662 OpBuilder::InsertionGuard g(rewriter);
1663 rewriter.setInsertionPoint(padOp);
1664
1665 ReifiedRankedShapedTypeDims reifiedReturnShapes;
1666 LogicalResult status =
1667 cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1668 .reifyResultShapes(rewriter, reifiedReturnShapes);
1669 (void)status; // prevent unused variable warning on non-assert builds
1670 assert(succeeded(status) && "failed to reify result shapes");
1671 auto maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, padOp.getSource(), inputVectorSizes, padValue);
1673 Operation *write = createWriteOrMaskedWrite(
1674 rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes);
  newResults.push_back(write->getResult(0));
1676 return success();
1677}
1678
1679// TODO: probably need some extra checks for reduction followed by consumer
1680// ops that may not commute (e.g. linear reduction + non-linear instructions).
1681static LogicalResult reductionPreconditions(LinalgOp op) {
1682 if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
1683 LDBG("reduction precondition failed: no reduction iterator\n");
1684 return failure();
1685 }
1686 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1687 AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1688 if (indexingMap.isPermutation())
1689 continue;
1690
1691 Operation *reduceOp = matchLinalgReduction(&opOperand);
1692 if (!reduceOp || !getCombinerOpKind(reduceOp)) {
1693 LDBG("reduction precondition failed: reduction detection failed\n");
1694 return failure();
1695 }
1696 }
1697 return success();
1698}
1699
1700static LogicalResult
1701vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
1702 bool flatten1DDepthwiseConv) {
1703 if (flatten1DDepthwiseConv) {
1704 LDBG("Vectorization of flattened convs with dynamic shapes is not "
1705 "supported\n");
1706 return failure();
1707 }
1708
1709 if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
1710 LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
1711 return failure();
1712 }
1713
1714 // Support dynamic shapes in 1D depthwise convolution, but only in the
1715 // _channel_ dimension.
1716 Value lhs = conv.getDpsInputOperand(0)->get();
1717 ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
  auto shapeWithoutCh = lhsShape.drop_back(1);
1719 if (ShapedType::isDynamicShape(shapeWithoutCh)) {
1720 LDBG("Dynamically-shaped op vectorization precondition failed: only "
1721 "channel dim can be dynamic\n");
1722 return failure();
1723 }
1724
1725 return success();
1726}
1727
1728static LogicalResult
1729vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
1730 bool flatten1DDepthwiseConv) {
1731 if (isa<ConvolutionOpInterface>(op.getOperation()))
1732 return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
1733
1734 // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
1735 // linalg.copy ops and ops that implement ContractionOpInterface for now.
1736 if (!isElementwise(op) &&
1737 !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
1738 op.getOperation()))
1739 return failure();
1740
1741 LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
1742 return success();
1743}
1744
1745/// Need to check if the inner-tiles are static/constant.
1746static LogicalResult
1747vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
1748 ArrayRef<int64_t> inputVectorSizes) {
1749
1750 if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
        return !getConstantIntValue(res).has_value();
1752 })) {
1753 LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
1754 return failure();
1755 }
1756 llvm::ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
1757 if (!inputVectorSizes.empty() &&
      failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
1759 return failure();
1760
1761 return success();
1762}
1763
1764static LogicalResult vectorizeLinalgOpPrecondition(
1765 LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
1766 bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
  // A tensor with a dimension of size 0 cannot be vectorized.
1768 if (llvm::is_contained(linalgOp.getStaticShape(), 0))
1769 return failure();
1770 // Check API contract for input vector sizes.
1771 if (!inputVectorSizes.empty() &&
      failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
1773 inputVectorSizes)))
1774 return failure();
1775
1776 if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
1777 linalgOp, flatten1DDepthwiseConv))) {
1778 LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
1779 return failure();
1780 }
1781
1782 SmallVector<CustomVectorizationPrecondition> customPreconditions;
1783
1784 // Register CustomVectorizationPrecondition for extractOp.
  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
1786
1787 // All types in the body should be a supported element type for VectorType.
1788 for (Operation &innerOp : linalgOp->getRegion(0).front()) {
1789 // Check if any custom hook can vectorize the inner op.
1790 if (llvm::any_of(
1791 customPreconditions,
1792 [&](const CustomVectorizationPrecondition &customPrecondition) {
1793 return succeeded(
1794 customPrecondition(&innerOp, vectorizeNDExtract));
1795 })) {
1796 continue;
1797 }
1798 if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
1799 return !VectorType::isValidElementType(type);
1800 })) {
1801 return failure();
1802 }
1803 if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
1804 return !VectorType::isValidElementType(type);
1805 })) {
1806 return failure();
1807 }
1808 }
1809 if (isElementwise(linalgOp))
1810 return success();
1811
1812 // TODO: isaConvolutionOpInterface that can also infer from generic
1813 // features. But we will still need stride/dilation attributes that will be
1814 // annoying to reverse-engineer...
1815 if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
1816 return success();
1817 // TODO: the common vector shape is equal to the static loop sizes only when
1818 // all indexing maps are projected permutations. For convs and stencils the
1819 // logic will need to evolve.
1820 if (!allIndexingsAreProjectedPermutation(linalgOp)) {
1821 LDBG("precondition failed: not projected permutations\n");
1822 return failure();
1823 }
1824 if (failed(reductionPreconditions(linalgOp))) {
1825 LDBG("precondition failed: reduction preconditions\n");
1826 return failure();
1827 }
1828 return success();
1829}
1830
1831static LogicalResult
1832vectorizePackOpPrecondition(tensor::PackOp packOp,
1833 ArrayRef<int64_t> inputVectorSizes) {
1834 auto padValue = packOp.getPaddingValue();
1835 Attribute cstAttr;
  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
1837 LDBG("pad value is not constant: " << packOp << "\n");
1838 return failure();
1839 }
1840 ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
1841 bool satisfyEmptyCond = true;
1842 if (inputVectorSizes.empty()) {
1843 if (!packOp.getDestType().hasStaticShape() ||
1844 !packOp.getSourceType().hasStaticShape())
1845 satisfyEmptyCond = false;
1846 }
1847
1848 if (!satisfyEmptyCond &&
1849 failed(vector::isValidMaskedInputVector(
          resultTensorShape.take_front(packOp.getSourceRank()),
1851 inputVectorSizes)))
1852 return failure();
1853
1854 if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
        return !getConstantIntValue(v).has_value();
1856 })) {
1857 LDBG("inner_tiles must be constant: " << packOp << "\n");
1858 return failure();
1859 }
1860
1861 return success();
1862}
1863
1864static LogicalResult
1865vectorizePadOpPrecondition(tensor::PadOp padOp,
1866 ArrayRef<int64_t> inputVectorSizes) {
1867 auto padValue = padOp.getConstantPaddingValue();
1868 if (!padValue) {
1869 LDBG("pad value is not constant: " << padOp << "\n");
1870 return failure();
1871 }
1872
1873 ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
                                              inputVectorSizes)))
1876 return failure();
1877
1878 if (llvm::any_of(padOp.getLow(), [](Value v) {
        std::optional<int64_t> res = getConstantIntValue(v);
1880 return !res.has_value() || res.value() != 0;
1881 })) {
1882 LDBG("low pad must all be zero: " << padOp << "\n");
1883 return failure();
1884 }
1885
1886 return success();
1887}
1888
1889/// Preconditions for scalable vectors.
1890static LogicalResult
1891vectorizeScalableVectorPrecondition(Operation *op,
1892 ArrayRef<int64_t> inputVectorSizes,
1893 ArrayRef<bool> inputScalableVecDims) {
1894 assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1895 "Number of input vector sizes and scalable dims doesn't match");
1896
1897 if (inputVectorSizes.empty())
1898 return success();
1899
1900 bool isScalable = inputScalableVecDims.back();
1901 if (!isScalable)
1902 return success();
1903
1904 // Only element-wise and 1d depthwise conv ops supported in the presence of
1905 // scalable dims.
1906 auto linalgOp = dyn_cast<LinalgOp>(op);
1907 return success(linalgOp && (isElementwise(linalgOp) ||
1908 isa<linalg::DepthwiseConv1DNwcWcOp>(op)));
1909}
1910
1911LogicalResult mlir::linalg::vectorizeOpPrecondition(
1912 Operation *op, ArrayRef<int64_t> inputVectorSizes,
1913 ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
1914 bool flatten1DDepthwiseConv) {
  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
1916 inputScalableVecDims)))
1917 return failure();
1918
1919 return TypeSwitch<Operation *, LogicalResult>(op)
1920 .Case<linalg::LinalgOp>([&](auto linalgOp) {
1921 return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
1922 vectorizeNDExtract,
1923 flatten1DDepthwiseConv);
1924 })
1925 .Case<tensor::PadOp>([&](auto padOp) {
1926 return vectorizePadOpPrecondition(padOp, inputVectorSizes);
1927 })
1928 .Case<tensor::PackOp>([&](auto packOp) {
1929 return vectorizePackOpPrecondition(packOp, inputVectorSizes);
1930 })
1931 .Case<tensor::UnPackOp>([&](auto unpackOp) {
1932 return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
1933 })
1934 .Default([](auto) { return failure(); });
1935}
1936
1937/// Converts affine.apply Ops to arithmetic operations.
1938static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
1939 OpBuilder::InsertionGuard g(rewriter);
1940 auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
1941
1942 for (auto op : make_early_inc_range(toReplace)) {
1943 rewriter.setInsertionPoint(op);
1944 auto expanded = affine::expandAffineExpr(
1945 rewriter, op->getLoc(), op.getAffineMap().getResult(0),
1946 op.getOperands().take_front(op.getAffineMap().getNumDims()),
1947 op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
1948 rewriter.replaceOp(op, expanded);
1949 }
1950}
1951
1952/// Emit a suitable vector form for an operation. If provided,
1953/// `inputVectorSizes` are used to vectorize this operation.
1954/// `inputVectorSizes` must match the rank of the iteration space of the
1955/// operation and the input vector sizes must be greater than or equal to
/// their counterpart iteration space sizes, if static. `inputVectorSizes`
1957/// also allows the vectorization of operations with dynamic shapes.
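///
/// A minimal usage sketch (`rewriter` and `candidateOp` are hypothetical
/// names for an IRRewriter and a previously matched op):
///   if (failed(linalg::vectorize(rewriter, candidateOp,
///                                /*inputVectorSizes=*/{},
///                                /*inputScalableVecDims=*/{})))
///     return failure();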
1958LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
1959 ArrayRef<int64_t> inputVectorSizes,
1960 ArrayRef<bool> inputScalableVecDims,
1961 bool vectorizeNDExtract,
1962 bool flatten1DDepthwiseConv) {
1963 LDBG("Attempting to vectorize:\n" << *op << "\n");
1964 LDBG("Input vector sizes: ");
1965 LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
1966 LLVM_DEBUG(llvm::dbgs() << "\n");
1967 LDBG("Input scalable vector dims: ");
1968 LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
1969 LLVM_DEBUG(llvm::dbgs() << "\n");
1970
  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
1972 vectorizeNDExtract,
1973 flatten1DDepthwiseConv))) {
1974 LDBG("Vectorization pre-conditions failed\n");
1975 return failure();
1976 }
1977
1978 // Initialize vectorization state.
1979 VectorizationState state(rewriter);
1980 if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
    if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
1982 inputScalableVecDims))) {
1983 LDBG("Vectorization state couldn't be initialized\n");
1984 return failure();
1985 }
1986 }
1987
1988 SmallVector<Value> results;
1989 auto vectorizeResult =
1990 TypeSwitch<Operation *, LogicalResult>(op)
1991 .Case<linalg::LinalgOp>([&](auto linalgOp) {
1992 // TODO: isaConvolutionOpInterface that can also infer from
1993 // generic features. Will require stride/dilation attributes
1994 // inference.
1995 if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
1996 FailureOr<Operation *> convOr = vectorizeConvolution(
1997 rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
1998 flatten1DDepthwiseConv);
1999 if (succeeded(convOr)) {
2000 llvm::append_range(results, (*convOr)->getResults());
2001 return success();
2002 }
2003
2004 LDBG("Unsupported convolution can't be vectorized.\n");
2005 return failure();
2006 }
2007
2008 LDBG("Vectorize generic by broadcasting to the canonical vector "
2009 "shape\n");
2010
2011 // Pre-process before proceeding.
2012 convertAffineApply(rewriter, linalgOp);
2013
2014 // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2015 // to 'OpBuilder' when it is passed over to some methods like
2016 // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2017 // erase an op within these methods, the actual rewriter won't be
2018 // notified and we will end up with read-after-free issues!
2019 return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2020 })
2021 .Case<tensor::PadOp>([&](auto padOp) {
2022 return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2023 results);
2024 })
2025 .Case<tensor::PackOp>([&](auto packOp) {
2026 return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2027 results);
2028 })
2029 .Case<tensor::UnPackOp>([&](auto unpackOp) {
2030 return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2031 inputVectorSizes, results);
2032 })
2033 .Default([](auto) { return failure(); });
2034
2035 if (failed(vectorizeResult)) {
2036 LDBG("Vectorization failed\n");
2037 return failure();
2038 }
2039
2040 if (!results.empty())
    rewriter.replaceOp(op, results);
2042 else
2043 rewriter.eraseOp(op);
2044
2045 return success();
2046}
2047
2048LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2049 memref::CopyOp copyOp) {
2050 auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2051 auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2052 if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2053 return failure();
2054
2055 auto srcElementType = getElementTypeOrSelf(srcType);
2056 auto dstElementType = getElementTypeOrSelf(dstType);
2057 if (!VectorType::isValidElementType(srcElementType) ||
2058 !VectorType::isValidElementType(dstElementType))
2059 return failure();
2060
2061 auto readType = VectorType::get(srcType.getShape(), srcElementType);
2062 auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2063
2064 Location loc = copyOp->getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2066 SmallVector<Value> indices(srcType.getRank(), zero);
2067
2068 Value readValue = rewriter.create<vector::TransferReadOp>(
2069 loc, readType, copyOp.getSource(), indices,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
2071 if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2072 readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
2073 readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
2074 }
2075 Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
2076 loc, readValue, copyOp.getTarget(), indices,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
2078 rewriter.replaceOp(copyOp, writeValue->getResults());
2079 return success();
2080}
2081
2082//----------------------------------------------------------------------------//
2083// Misc. vectorization patterns.
2084//----------------------------------------------------------------------------//
2085
2086/// Helper function that retrieves the value of an IntegerAttr.
2087static int64_t getIntFromAttr(Attribute attr) {
2088 return cast<IntegerAttr>(attr).getInt();
2089}
2090
2091/// Given an ArrayRef of OpFoldResults, return a vector of Values.
2092/// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
2093/// not supported.
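/// For example (illustrative): {%v, IntegerAttr 3} becomes
/// {%v, arith.constant 3 : index}.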
2094static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
2095 ArrayRef<OpFoldResult> ofrs) {
2096 SmallVector<Value> result;
2097 for (auto o : ofrs) {
    if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
      result.push_back(val);
2100 } else {
      result.push_back(rewriter.create<arith::ConstantIndexOp>(
          loc, getIntFromAttr(o.template get<Attribute>())));
2103 }
2104 }
2105 return result;
2106}
2107
2108/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
2109/// InsertSliceOp. For now, only constant padding values are supported.
2110/// If there is enough static type information, TransferReadOps and
2111/// TransferWriteOps may be generated instead of InsertSliceOps.
2112struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
2113 GenericPadOpVectorizationPattern(MLIRContext *context,
2114 PatternBenefit benefit = 1)
2115 : GeneralizePadOpPattern(context, tryVectorizeCopy, benefit) {}
2116 /// Vectorize the copying of a tensor::PadOp's source. This is possible if
  /// each dimension size is statically known in the source type or the result
2118 /// type (or both).
2119 static LogicalResult tryVectorizeCopy(RewriterBase &rewriter,
2120 tensor::PadOp padOp, Value dest) {
2121 auto sourceType = padOp.getSourceType();
2122 auto resultType = padOp.getResultType();
2123 if (!VectorType::isValidElementType(sourceType.getElementType()))
2124 return failure();
2125
2126 // Copy cannot be vectorized if pad value is non-constant and source shape
2127 // is dynamic. In case of a dynamic source shape, padding must be appended
2128 // by TransferReadOp, but TransferReadOp supports only constant padding.
2129 auto padValue = padOp.getConstantPaddingValue();
2130 if (!padValue) {
2131 if (!sourceType.hasStaticShape())
2132 return failure();
2133 // Create dummy padding value.
2134 auto elemType = sourceType.getElementType();
2135 padValue = rewriter.create<arith::ConstantOp>(
2136 padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
2137 }
2138
2139 SmallVector<int64_t> vecShape;
2140 SmallVector<bool> readInBounds;
2141 SmallVector<bool> writeInBounds;
2142 for (unsigned i = 0; i < sourceType.getRank(); ++i) {
2143 if (!sourceType.isDynamicDim(i)) {
        vecShape.push_back(sourceType.getDimSize(i));
        // Source shape is statically known: neither read nor write is
        // out-of-bounds.
        readInBounds.push_back(true);
        writeInBounds.push_back(true);
2149 } else if (!resultType.isDynamicDim(i)) {
2150 // Source shape is not statically known, but result shape is.
2151 // Vectorize with size of result shape. This may be larger than the
2152 // source size.
        vecShape.push_back(resultType.getDimSize(i));
2154 // Read may be out-of-bounds because the result size could be larger
2155 // than the source size.
        readInBounds.push_back(false);
2157 // Write is out-of-bounds if low padding > 0.
2158 writeInBounds.push_back(
            getConstantIntValue(padOp.getMixedLowPad()[i]) ==
2160 static_cast<int64_t>(0));
2161 } else {
2162 // Neither source nor result dim of padOp is static. Cannot vectorize
2163 // the copy.
2164 return failure();
2165 }
2166 }
2167 auto vecType = VectorType::get(vecShape, sourceType.getElementType());
2168
2169 // Generate TransferReadOp.
2170 SmallVector<Value> readIndices(
2171 vecType.getRank(),
2172 rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2173 auto read = rewriter.create<vector::TransferReadOp>(
2174 padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
2175 ArrayRef<bool>{readInBounds});
2176
2177 // If `dest` is a FillOp and the TransferWriteOp would overwrite the
2178 // entire tensor, write directly to the FillOp's operand.
2179 if (llvm::equal(vecShape, resultType.getShape()) &&
        llvm::all_of(writeInBounds, [](bool b) { return b; }))
2181 if (auto fill = dest.getDefiningOp<FillOp>())
2182 dest = fill.output();
2183
2184 // Generate TransferWriteOp.
2185 auto writeIndices =
2186 ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
2187 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2188 padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});
2189
2190 return success();
2191 }
2192};
2193
2194/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2195/// given operation type OpTy.
2196template <typename OpTy>
2197struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2198 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2199
2200 LogicalResult matchAndRewrite(tensor::PadOp padOp,
2201 PatternRewriter &rewriter) const final {
2202 bool changed = false;
    // Collect users into a vector first; some may be replaced or removed.
2204 for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2205 if (auto op = dyn_cast<OpTy>(user))
2206 changed |= rewriteUser(rewriter, padOp, op).succeeded();
    return success(changed);
2208 }
2209
2210protected:
2211 virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2212 tensor::PadOp padOp, OpTy op) const = 0;
2213};
2214
2215/// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2216/// ```
2217/// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2218/// %r = vector.transfer_read %0[%c0, %c0], %cst
2219/// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2220/// ```
2221/// is rewritten to:
2222/// ```
2223/// %r = vector.transfer_read %src[%c0, %c0], %padding
2224/// {in_bounds = [true, true]}
2225/// : tensor<?x?xf32>, vector<17x5xf32>
2226/// ```
2227/// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2228/// sure that the original padding value %cst was never used.
2229///
2230/// This rewrite is possible if:
2231/// - `xferOp` has no out-of-bounds dims or mask.
2232/// - Low padding is static 0.
2233/// - Single, scalar padding value.
2234struct PadOpVectorizationWithTransferReadPattern
2235 : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2236 using VectorizePadOpUserPattern<
2237 vector::TransferReadOp>::VectorizePadOpUserPattern;
2238
2239 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2240 vector::TransferReadOp xferOp) const override {
2241 // Low padding must be static 0.
2242 if (!padOp.hasZeroLowPad())
2243 return failure();
2244 // Pad value must be a constant.
2245 auto padValue = padOp.getConstantPaddingValue();
2246 if (!padValue)
2247 return failure();
2248 // Padding value of existing `xferOp` is unused.
2249 if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2250 return failure();
2251
2252 rewriter.modifyOpInPlace(xferOp, [&]() {
2253 SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2254 xferOp->setAttr(xferOp.getInBoundsAttrName(),
2255 rewriter.getBoolArrayAttr(inBounds));
2256 xferOp.getSourceMutable().assign(padOp.getSource());
2257 xferOp.getPaddingMutable().assign(padValue);
2258 });
2259
2260 return success();
2261 }
2262};
2263
2264/// Rewrite use of tensor::PadOp result in TransferWriteOp.
2265/// This pattern rewrites TransferWriteOps that write to a padded tensor
2266/// value, where the same amount of padding is immediately removed again after
2267/// the write. In such cases, the TransferWriteOp can write to the non-padded
2268/// tensor value and apply out-of-bounds masking. E.g.:
2269/// ```
2270/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2271/// : tensor<...> to tensor<?x?xf32>
2272/// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2273/// %2 = vector.transfer_write %vec, %1[...]
2274/// : vector<17x5xf32>, tensor<17x5xf32>
2275/// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2276/// : tensor<17x5xf32> to tensor<?x?xf32>
2277/// ```
2278/// is rewritten to:
2279/// ```
2280/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2281/// : tensor<...> to tensor<?x?xf32>
2282/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2283/// tensor<?x?xf32>
2284/// ```
2285/// Note: It is important that the ExtractSliceOp %r resizes the result of the
2286/// TransferWriteOp to the same size as the input of the TensorPadOp (or an
2287/// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2288/// from %r's old dimensions.
2289///
2290/// This rewrite is possible if:
2291/// - Low padding is static 0.
2292/// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2293/// ExtractSliceOp trims the same amount of padding that was added
2294/// beforehand.
2295/// - Single, scalar padding value.
2296struct PadOpVectorizationWithTransferWritePattern
2297 : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2298 using VectorizePadOpUserPattern<
2299 vector::TransferWriteOp>::VectorizePadOpUserPattern;
2300
2301 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2302 vector::TransferWriteOp xferOp) const override {
2303 // TODO: support 0-d corner case.
2304 if (xferOp.getTransferRank() == 0)
2305 return failure();
2306
2307 // Low padding must be static 0.
2308 if (!padOp.hasZeroLowPad())
2309 return failure();
2310 // Pad value must be a constant.
2311 auto padValue = padOp.getConstantPaddingValue();
2312 if (!padValue)
2313 return failure();
2314 // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2315 if (!xferOp->hasOneUse())
2316 return failure();
2317 auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2318 if (!trimPadding)
2319 return failure();
2320 // Only static zero offsets supported when trimming padding.
2321 if (!trimPadding.hasZeroOffset())
2322 return failure();
2323 // trimPadding must remove the amount of padding that was added earlier.
    if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2325 return failure();
2326
2327 // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2328 rewriter.setInsertionPoint(xferOp);
2329
2330 SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2331 auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2332 xferOp, padOp.getSource().getType(), xferOp.getVector(),
2333 padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2334 xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2335 rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2336
2337 return success();
2338 }
2339
2340 /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2341 /// i.e., same dimensions.
2342 ///
2343 /// Dimensions may be static, dynamic or mix of both. In case of dynamic
2344 /// dimensions, this function tries to infer the (static) tensor size by
2345 /// looking at the defining op and utilizing op-specific knowledge.
2346 ///
2347 /// This is a conservative analysis. In case equal tensor sizes cannot be
2348 /// proven statically, this analysis returns `false` even though the tensor
2349 /// sizes may turn out to be equal at runtime.
2350 bool hasSameTensorSize(Value beforePadding,
2351 tensor::ExtractSliceOp afterTrimming) const {
2352 // If the input to tensor::PadOp is a CastOp, try with both CastOp
2353 // result and CastOp operand.
2354 if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
      if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2356 return true;
2357
2358 auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
2359 auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2360 // Only RankedTensorType supported.
2361 if (!t1 || !t2)
2362 return false;
2363 // Rank of both values must be the same.
2364 if (t1.getRank() != t2.getRank())
2365 return false;
2366
2367 // All static dimensions must be the same. Mixed cases (e.g., dimension
2368 // static in `t1` but dynamic in `t2`) are not supported.
2369 for (unsigned i = 0; i < t1.getRank(); ++i) {
2370 if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2371 return false;
2372 if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2373 return false;
2374 }
2375
2376 // Nothing more to check if all dimensions are static.
2377 if (t1.getNumDynamicDims() == 0)
2378 return true;
2379
2380 // All dynamic sizes must be the same. The only supported case at the
2381 // moment is when `beforePadding` is an ExtractSliceOp (or a cast
2382 // thereof).
2383
2384 // Apart from CastOp, only ExtractSliceOp is supported.
2385 auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
2386 if (!beforeSlice)
2387 return false;
2388
2389 assert(static_cast<size_t>(t1.getRank()) ==
2390 beforeSlice.getMixedSizes().size());
2391 assert(static_cast<size_t>(t2.getRank()) ==
2392 afterTrimming.getMixedSizes().size());
2393
2394 for (unsigned i = 0; i < t1.getRank(); ++i) {
2395 // Skip static dimensions.
2396 if (!t1.isDynamicDim(i))
2397 continue;
2398 auto size1 = beforeSlice.getMixedSizes()[i];
2399 auto size2 = afterTrimming.getMixedSizes()[i];
2400
2401 // Case 1: Same value or same constant int.
2402 if (isEqualConstantIntOrValue(size1, size2))
2403 continue;
2404
2405 // Other cases: Take a deeper look at defining ops of values.
2406 auto v1 = llvm::dyn_cast_if_present<Value>(size1);
2407 auto v2 = llvm::dyn_cast_if_present<Value>(size2);
2408 if (!v1 || !v2)
2409 return false;
2410
2411 // Case 2: Both values are identical AffineMinOps. (Should not happen if
2412 // CSE is run.)
2413 auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
2414 auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
2415 if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2416 minOp1.getOperands() == minOp2.getOperands())
2417 continue;
2418
2419 // Add additional cases as needed.
2420 }
2421
2422 // All tests passed.
2423 return true;
2424 }
2425};
2426
2427/// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
2428/// ```
2429/// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2430/// %r = tensor.insert_slice %0
2431/// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
2432/// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
2433/// ```
2434/// is rewritten to:
2435/// ```
2436/// %0 = vector.transfer_read %src[%c0, %c0], %padding
2437/// : tensor<?x?xf32>, vector<17x5xf32>
2438/// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
2439/// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
2440/// ```
2441///
2442/// This rewrite is possible if:
2443/// - Low padding is static 0.
2444/// - `padOp` result shape is static.
2445/// - The entire padded tensor is inserted.
2446/// (Implies that sizes of `insertOp` are all static.)
2447/// - Only unit strides in `insertOp`.
2448/// - Single, scalar padding value.
2449/// - `padOp` result not used as destination.
2450struct PadOpVectorizationWithInsertSlicePattern
2451 : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
2452 using VectorizePadOpUserPattern<
2453 tensor::InsertSliceOp>::VectorizePadOpUserPattern;
2454
2455 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2456 tensor::InsertSliceOp insertOp) const override {
2457 // Low padding must be static 0.
2458 if (!padOp.hasZeroLowPad())
2459 return failure();
2460 // Only unit stride supported.
2461 if (!insertOp.hasUnitStride())
2462 return failure();
2463 // Pad value must be a constant.
2464 auto padValue = padOp.getConstantPaddingValue();
2465 if (!padValue)
2466 return failure();
2467 // Dynamic shapes not supported.
2468 if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
2469 return failure();
2470 // Pad result not used as destination.
2471 if (insertOp.getDest() == padOp.getResult())
2472 return failure();
2473
2474 auto vecType = VectorType::get(padOp.getType().getShape(),
2475 padOp.getType().getElementType());
2476 unsigned vecRank = vecType.getRank();
2477 unsigned tensorRank = insertOp.getType().getRank();
2478
2479 // Check if sizes match: Insert the entire tensor into most minor dims.
2480 // (No permutations allowed.)
2481 SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
2482 expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
2483 if (!llvm::all_of(
2484 llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
2485 return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
2486 }))
2487 return failure();
2488
2489 // Insert the TransferReadOp and TransferWriteOp at the position of the
2490 // InsertSliceOp.
2491 rewriter.setInsertionPoint(insertOp);
2492
2493 // Generate TransferReadOp: Read entire source tensor and add high
2494 // padding.
2495 SmallVector<Value> readIndices(
2496 vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2497 auto read = rewriter.create<vector::TransferReadOp>(
2498 padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
2499
2500 // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
    // specified offsets. Write is fully in-bounds because an InsertSliceOp's
2502 // source must fit into the destination at the specified offsets.
2503 auto writeIndices =
2504 ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
2505 SmallVector<bool> inBounds(vecRank, true);
2506 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2507 insertOp, read, insertOp.getDest(), writeIndices,
2508 ArrayRef<bool>{inBounds});
2509
2510 return success();
2511 }
2512};
2513
2514void mlir::linalg::populatePadOpVectorizationPatterns(
2515 RewritePatternSet &patterns, PatternBenefit baseBenefit) {
  patterns.add<GenericPadOpVectorizationPattern>(patterns.getContext(),
                                                 baseBenefit);
2518 // Try these specialized patterns first before resorting to the generic one.
2519 patterns.add<PadOpVectorizationWithTransferReadPattern,
2520 PadOpVectorizationWithTransferWritePattern,
2521 PadOpVectorizationWithInsertSlicePattern>(
      patterns.getContext(), baseBenefit.getBenefit() + 1);
2523}
2524
2525//----------------------------------------------------------------------------//
2526// Forwarding patterns
2527//----------------------------------------------------------------------------//
2528
2529/// Check whether there is any interleaved use of any `values` between
2530/// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
2531/// is in a different block.
2532static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
2533 ValueRange values) {
2534 if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
2536 LDBG("interleavedUses precondition failed, firstOp: "
2537 << *firstOp << ", second op: " << *secondOp << "\n");
2538 return true;
2539 }
2540 for (auto v : values) {
2541 for (auto &u : v.getUses()) {
2542 Operation *owner = u.getOwner();
2543 if (owner == firstOp || owner == secondOp)
2544 continue;
2545 // TODO: this is too conservative, use dominance info in the future.
2546 if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
2548 continue;
2549 LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
2550 << ", second op: " << *secondOp << "\n");
2551 return true;
2552 }
2553 }
2554 return false;
2555}
2556
2557/// Return the unique subview use of `v` if it is indeed unique, null
2558/// otherwise.
2559static memref::SubViewOp getSubViewUseIfUnique(Value v) {
2560 memref::SubViewOp subViewOp;
2561 for (auto &u : v.getUses()) {
2562 if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
2563 if (subViewOp)
2564 return memref::SubViewOp();
2565 subViewOp = newSubViewOp;
2566 }
2567 }
2568 return subViewOp;
2569}
2570
2571/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2572/// when available.
2573LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
2574 vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
2575
2576 // TODO: support mask.
2577 if (xferOp.getMask())
2578 return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2579
2580 // Transfer into `view`.
2581 Value viewOrAlloc = xferOp.getSource();
2582 if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2583 !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2584 return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2585
2586 // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2587 memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2588 if (!subViewOp)
2589 return rewriter.notifyMatchFailure(xferOp, "no subview found");
2590 Value subView = subViewOp.getResult();
2591
2592 // Find the copy into `subView` without interleaved uses.
2593 memref::CopyOp copyOp;
2594 for (auto &u : subView.getUses()) {
2595 if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2596 assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
2597 if (newCopyOp.getTarget() != subView)
2598 continue;
2599 if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
2600 continue;
2601 copyOp = newCopyOp;
2602 break;
2603 }
2604 }
2605 if (!copyOp)
2606 return rewriter.notifyMatchFailure(xferOp, "no copy found");
2607
2608 // Find the fill into `viewOrAlloc` without interleaved uses before the
2609 // copy.
2610 FillOp maybeFillOp;
2611 for (auto &u : viewOrAlloc.getUses()) {
2612 if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
2613 assert(isa<MemRefType>(newFillOp.output().getType()));
2614 if (newFillOp.output() != viewOrAlloc)
2615 continue;
2616 if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
2617 continue;
2618 maybeFillOp = newFillOp;
2619 break;
2620 }
2621 }
2622 // Ensure padding matches.
2623 if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
2624 return rewriter.notifyMatchFailure(xferOp,
2625 "padding value does not match fill");
2626
2627 // `in` is the subview that memref.copy reads. Replace it.
2628 Value in = copyOp.getSource();
2629
2630 // memref.copy + linalg.fill can be used to create a padded local buffer.
2631 // The `masked` attribute is only valid on this padded buffer.
2632 // When forwarding to vector.transfer_read, the attribute must be reset
2633 // conservatively.
2634 Value res = rewriter.create<vector::TransferReadOp>(
2635 xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.getIndices(),
2636 xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
2637 // in_bounds is explicitly reset
2638 /*inBoundsAttr=*/ArrayAttr());
2639
2640 if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
2643 rewriter.replaceOp(xferOp, res);
2644
2645 return success();
2646}
2647
2648/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
2649/// when available.
2650LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
2651 vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
2652 // TODO: support mask.
2653 if (xferOp.getMask())
2654 return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
2655
2656 // Transfer into `viewOrAlloc`.
2657 Value viewOrAlloc = xferOp.getSource();
2658 if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
2659 !viewOrAlloc.getDefiningOp<memref::AllocOp>())
2660 return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
2661
2662 // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
2663 memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
2664 if (!subViewOp)
2665 return rewriter.notifyMatchFailure(xferOp, "no subview found");
2666 Value subView = subViewOp.getResult();
2667
2668 // Find the copy from `subView` without interleaved uses.
2669 memref::CopyOp copyOp;
2670 for (auto &u : subViewOp.getResult().getUses()) {
2671 if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2672 if (newCopyOp.getSource() != subView)
2673 continue;
2674 if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
2675 continue;
2676 copyOp = newCopyOp;
2677 break;
2678 }
2679 }
2680 if (!copyOp)
2681 return rewriter.notifyMatchFailure(xferOp, "no copy found");
2682
  // `out` is the target that the copy writes into; the forwarded
  // vector.transfer_write will write into it directly.
2684 assert(isa<MemRefType>(copyOp.getTarget().getType()));
2685 Value out = copyOp.getTarget();
2686
2687 // Forward vector.transfer into copy.
  // memref.copy + linalg.fill can be used to create a padded local buffer.
  // The `in_bounds` attribute is only valid on this padded buffer; when
  // forwarding to vector.transfer_write, the attribute must be reset
  // conservatively.
2692 rewriter.create<vector::TransferWriteOp>(
2693 xferOp.getLoc(), xferOp.getVector(), out, xferOp.getIndices(),
2694 xferOp.getPermutationMapAttr(), xferOp.getMask(),
2695 // in_bounds is explicitly reset
2696 /*inBoundsAttr=*/ArrayAttr());
2697
  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);
2700
2701 return success();
2702}
2703
2704//===----------------------------------------------------------------------===//
2705// Convolution vectorization patterns
2706//===----------------------------------------------------------------------===//
2707
2708template <int N>
2709static void bindShapeDims(ShapedType shapedType) {}
2710
2711template <int N, typename IntTy, typename... IntTy2>
2712static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
2713 val = shapedType.getShape()[N];
2714 bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
2715}
2716
2717/// Bind a pack of int& to the leading dimensions of shapedType.getShape().
2718template <typename... IntTy>
2719static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
2720 bindShapeDims<0>(shapedType, vals...);
2721}
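
// Example (illustrative): given a `shapedType` of memref<4x16x8xf32>,
//   int64_t n, w, c;
//   bindShapeDims(shapedType, n, w, c);
// binds n = 4, w = 16 and c = 8, i.e. the three leading dimensions.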
2722
2723namespace {
2724bool isCastOfBlockArgument(Operation *op) {
2725 return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2726 isa<BlockArgument>(op->getOperand(0));
2727}
2728
2729bool isSupportedPoolKind(vector::CombiningKind kind) {
2730 switch (kind) {
2731 case vector::CombiningKind::ADD:
2732 case vector::CombiningKind::MAXNUMF:
2733 case vector::CombiningKind::MAXIMUMF:
2734 case vector::CombiningKind::MAXSI:
2735 case vector::CombiningKind::MAXUI:
2736 case vector::CombiningKind::MINNUMF:
2737 case vector::CombiningKind::MINIMUMF:
2738 case vector::CombiningKind::MINSI:
2739 case vector::CombiningKind::MINUI:
2740 return true;
2741 default:
2742 return false;
2743 }
2744}
2745
2746/// Generate a vector implementation for either:
2747/// ```
2748/// Op def: ( w, kw )
2749/// Iters: ({Par(), Red()})
2750/// Layout: {{w + kw}, {kw}, {w}}
2751/// ```
2752/// kw is unrolled.
2753///
2754/// or
2755///
2756/// ```
2757/// Op def: ( n, w, c, kw, f )
2758/// Iters: ({Par(), Par(), Par(), Red(), Red()})
2759/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
2760/// ```
2761/// kw is unrolled, w is unrolled iff dilationW > 1.
2762///
2763/// or
2764///
2765/// ```
2766/// Op def: ( n, c, w, f, kw )
2767/// Iters: ({Par(), Par(), Par(), Red(), Red()})
2768/// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
2769/// ```
2770/// kw is unrolled, w is unrolled iff dilationW > 1.
2771///
2772/// or
2773///
2774/// ```
2775/// Op def: ( n, w, c, kw )
2776/// Iters: ({Par(), Par(), Par(), Red()})
2777/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
2778/// ```
2779/// kw is unrolled, w is unrolled iff dilationW > 1.
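///
/// These four forms correspond, for example, to linalg.conv_1d,
/// linalg.conv_1d_nwc_wcf, linalg.conv_1d_ncw_fcw and
/// linalg.depthwise_conv_1d_nwc_wc (an illustrative, non-exhaustive mapping;
/// pooling ops with matching layouts are handled through the same forms).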
2780struct Conv1DGenerator
2781 : public StructuredGenerator<LinalgOp, utils::IteratorType> {
2782 Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
2783 int dilationW)
2784 : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
2785 strideW(strideW), dilationW(dilationW) {
2786 // Determine whether `linalgOp` can be generated with this generator
2787 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
2788 return;
2789 lhsShaped = linalgOp.getDpsInputOperand(0)->get();
2790 rhsShaped = linalgOp.getDpsInputOperand(1)->get();
2791 resShaped = linalgOp.getDpsInitOperand(0)->get();
2792 lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
2793 rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
2794 resShapedType = dyn_cast<ShapedType>(resShaped.getType());
2795 if (!lhsShapedType || !rhsShapedType || !resShapedType)
2796 return;
2797 // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2798 // (non-channeled convolution -> LHS and RHS both have single dimensions).
2799 if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2800 (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2801 return;
2802
2803 Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
2804 if (!reduceOp)
2805 return;
2806 redOp = reduceOp->getName().getIdentifier();
2807
2808 if (!setOperKind(reduceOp))
2809 return;
2810 auto maybeKind = getCombinerOpKind(reduceOp);
2811 if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD &&
2812 (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
2813 return;
2814 }
2815
2816 auto rhsRank = rhsShapedType.getRank();
2817 switch (oper) {
2818 case Conv:
2819 if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2820 return;
2821 break;
2822 case Pool:
2823 if (rhsRank != 1)
2824 return;
2825 break;
2826 }
2827 // The op is now known to be valid.
2828 valid = true;
2829 }
2830
2831 /// Generate a vector implementation for:
2832 /// ```
2833 /// Op def: ( w, kw )
2834 /// Iters: ({Par(), Red()})
2835 /// Layout: {{w + kw}, {kw}, {w}}
2836 /// ```
2837 /// kw is always unrolled.
2838 ///
2839 /// or
2840 ///
2841 /// ```
2842 /// Op def: ( n, w, c, kw, f )
2843 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
2844 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
2845 /// ```
2846 /// kw is always unrolled.
  /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
  /// > 1.
2849 FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
2850 if (!valid)
2851 return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");
2852
2853 int64_t nSize, wSize, cSize, kwSize, fSize;
2854 SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
2855 bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
2856 switch (conv1DOpOrder) {
2857 case Conv1DOpOrder::W:
2858 // Initialize unused dimensions
2859 nSize = fSize = cSize = 0;
2860 // out{W}
2861 bindShapeDims(resShapedType, wSize);
2862 // kernel{kw}
2863 bindShapeDims(rhsShapedType, kwSize);
2864 lhsShape = {// iw = ow + kw - 1
2865 // (i.e. 16 convolved with 3 -> 14)
2866 (wSize + kwSize - 1)};
2867 rhsShape = {kwSize};
2868 resShape = {wSize};
2869 break;
2870 case Conv1DOpOrder::Nwc:
2871 // out{n, w, f}
2872 bindShapeDims(resShapedType, nSize, wSize, fSize);
2873 switch (oper) {
2874 case Conv:
2875 // kernel{kw, c, f}
2876 bindShapeDims(rhsShapedType, kwSize, cSize);
2877 break;
2878 case Pool:
2879 // kernel{kw}
2880 bindShapeDims(rhsShapedType, kwSize);
2881 cSize = fSize;
2882 break;
2883 }
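      // Worked example for the iw formula below (assumed values): ow = 8,
      // kw = 3, strideW = 2, dilationW = 2
      //   => iw = ((8 - 1) * 2 + 1) + ((3 - 1) * 2 + 1) - 1 = 15 + 5 - 1 = 19.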
2884 lhsShape = {nSize,
2885 // iw = ow * sw + kw * dw - 1
2886 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
2887 // Perform the proper inclusive -> exclusive -> inclusive.
2888 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
2889 1,
2890 cSize};
2891 switch (oper) {
2892 case Conv:
2893 rhsShape = {kwSize, cSize, fSize};
2894 break;
2895 case Pool:
2896 rhsShape = {kwSize};
2897 break;
2898 }
2899 resShape = {nSize, wSize, fSize};
2900 break;
2901 case Conv1DOpOrder::Ncw:
2902 // out{n, f, w}
2903 bindShapeDims(resShapedType, nSize, fSize, wSize);
2904 switch (oper) {
2905 case Conv:
2906 // kernel{f, c, kw}
2907 bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
2908 break;
2909 case Pool:
2910 // kernel{kw}
2911 bindShapeDims(rhsShapedType, kwSize);
2912 cSize = fSize;
2913 break;
2914 }
2915 lhsShape = {nSize, cSize,
2916 // iw = ow * sw + kw * dw - 1
2917 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
2918 // Perform the proper inclusive -> exclusive -> inclusive.
2919 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
2920 1};
2921 switch (oper) {
2922 case Conv:
2923 rhsShape = {fSize, cSize, kwSize};
2924 break;
2925 case Pool:
2926 rhsShape = {kwSize};
2927 break;
2928 }
2929 resShape = {nSize, fSize, wSize};
2930 break;
2931 }
2932
2933 vector::TransferWriteOp write;
2934 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
2935
2936 // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
2937 // When strideW == 1, we can batch the contiguous loads and avoid
2938 // unrolling
2939 int64_t wSizeStep = strideW == 1 ? wSize : 1;
2940
2941 Type lhsEltType = lhsShapedType.getElementType();
2942 Type rhsEltType = rhsShapedType.getElementType();
2943 Type resEltType = resShapedType.getElementType();
2944 auto lhsType = VectorType::get(lhsShape, lhsEltType);
2945 auto rhsType = VectorType::get(rhsShape, rhsEltType);
2946 auto resType = VectorType::get(resShape, resEltType);
2947 // Zero padding with the corresponding dimensions for lhs, rhs and res.
2948 SmallVector<Value> lhsPadding(lhsShape.size(), zero);
2949 SmallVector<Value> rhsPadding(rhsShape.size(), zero);
2950 SmallVector<Value> resPadding(resShape.size(), zero);
2951
2952 // Read the whole lhs, rhs and res in one shot (with zero padding).
2953 Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
2954 lhsPadding);
2955 // This is needed only for Conv.
2956 Value rhs = nullptr;
2957 if (oper == Conv)
2958 rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
2959 rhsPadding);
2960 Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
2961 resPadding);
2962
    // The base vectorization case for channeled convolution is input:
    // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse this base pattern
    // vectorization case, we pre-transpose the input, weight, and output.
2966 switch (conv1DOpOrder) {
2967 case Conv1DOpOrder::W:
2968 case Conv1DOpOrder::Nwc:
2969 // Base case, so no transposes necessary.
2970 break;
2971 case Conv1DOpOrder::Ncw: {
2972 // To match base vectorization case, we pre-transpose current case.
2973 // ncw -> nwc
2974 static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
2975 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
2976 // fcw -> wcf
2977 static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
2978
2979 // This is needed only for Conv.
2980 if (oper == Conv)
2981 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
2982 // nfw -> nwf
2983 static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
2984 res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
2985 break;
2986 }
2987 }
2988
2989 //===------------------------------------------------------------------===//
2990 // Begin vector-only rewrite part
2991 //===------------------------------------------------------------------===//
2992 // Unroll along kw and read slices of lhs and rhs.
2993 SmallVector<Value> lhsVals, rhsVals, resVals;
2994 lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
2995 kwSize, strideW, dilationW, wSizeStep,
2996 isSingleChanneled);
2997 // Do not do for pooling.
2998 if (oper == Conv)
2999 rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3000 resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3001 wSizeStep, isSingleChanneled);
3002
3003 auto linearIndex = [&](int64_t kw, int64_t w) {
3004 return kw * (wSize / wSizeStep) + w;
3005 };
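
    // `linearIndex` maps the (kw, w) loop below to the order in which the lhs
    // slices were extracted above. E.g. (assumed values) with wSize = 4 and
    // wSizeStep = 1, (kw, w) = (1, 2) maps to 1 * 4 + 2 = 6.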
3006
3007 // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
3008 // or perform outerproduct for non-channeled convolution or perform simple
3009 // arith operation for pooling
3010 for (int64_t kw = 0; kw < kwSize; ++kw) {
3011 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3012 switch (oper) {
3013 case Conv:
3014 if (isSingleChanneled) {
3015 resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3016 lhsVals[linearIndex(kw, w)],
3017 rhsVals[kw], resVals[w]);
3018 } else {
3019 resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3020 lhsVals[linearIndex(kw, w)],
3021 rhsVals[kw], resVals[w]);
3022 }
3023 break;
3024 case Pool:
3025 resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3026 resVals[w]);
3027 break;
3028 }
3029 }
3030 }
3031
3032 res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3033 isSingleChanneled);
3034 //===------------------------------------------------------------------===//
3035 // End vector-only rewrite part
3036 //===------------------------------------------------------------------===//
3037
    // The base vectorization case for channeled convolution is output:
    // {n,w,f}. To reuse the result from the base pattern vectorization case,
    // we post-transpose the base case result.
3041 switch (conv1DOpOrder) {
3042 case Conv1DOpOrder::W:
3043 case Conv1DOpOrder::Nwc:
3044 // Base case, so no transposes necessary.
3045 break;
3046 case Conv1DOpOrder::Ncw: {
3047 // nwf -> nfw
3048 static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3049 res = rewriter.create<vector::TransposeOp>(loc, res, perm);
3050 break;
3051 }
3052 }
3053
3054 return rewriter
3055 .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
3056 .getOperation();
3057 }
3058
  // Take a value and widen it to have the same element type as `ty`.
  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
    const Type srcElementType = getElementTypeOrSelf(val.getType());
    const Type dstElementType = getElementTypeOrSelf(ty);
3063 assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3064 if (srcElementType == dstElementType)
3065 return val;
3066
3067 const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3068 const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3069 const Type dstType =
3070 cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3071
    if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3073 return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
3074 }
3075
3076 if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3077 srcWidth < dstWidth)
3078 return rewriter.create<arith::ExtFOp>(loc, dstType, val);
3079
3080 if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3081 srcWidth < dstWidth)
3082 return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
3083
3084 assert(false && "unhandled promotion case");
3085 return nullptr;
3086 }
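
  // For instance (illustrative, not exercised here directly): promoting a
  // vector<4x3xi8> value to match an f32-typed accumulator goes through the
  // arith.sitofp branch above and yields a vector<4x3xf32>.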
3087
3088 // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3089 Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3090 Value lhs, Value rhs, Value res) {
3091 vector::IteratorType par = vector::IteratorType::parallel;
3092 vector::IteratorType red = vector::IteratorType::reduction;
3093 AffineExpr n, w, f, c;
3094 bindDims(ctx, n, w, f, c);
    lhs = promote(rewriter, loc, lhs, res.getType());
    rhs = promote(rewriter, loc, rhs, res.getType());
3097 return rewriter.create<vector::ContractionOp>(
3098 loc, lhs, rhs, res,
3099 /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3100 /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3101 }
3102
3103 // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3104 // convolution.
3105 Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3106 Value lhs, Value rhs, Value res) {
3107 return rewriter.create<vector::OuterProductOp>(
3108 loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
3109 }
3110
3111 // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3112 Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3113 Value res) {
3114 if (isPoolExt)
3115 lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3116 return rewriter
3117 .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3118 ->getResult(0);
3119 }
3120
3121 /// Generate a vector implementation for:
3122 /// ```
3123 /// Op def: ( n, w, c, kw)
3124 /// Iters: ({Par(), Par(), Par(), Red()})
3125 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3126 /// ```
3127 /// kw is always unrolled.
  /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
  /// > 1.
3130 FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3131 bool channelDimScalableFlag,
3132 bool flatten) {
3133 if (!valid)
3134 return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
3135
3136 bool scalableChDim = false;
3137 bool useMasking = false;
3138 int64_t nSize, wSize, cSize, kwSize;
3139 // kernel{kw, c}
3140 bindShapeDims(rhsShapedType, kwSize, cSize);
3141 if (ShapedType::isDynamic(cSize)) {
3142 assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3143 cSize = channelDimVecSize;
3144 // Scalable vectors are only used when both conditions are met:
3145 // 1. channel dim is dynamic
3146 // 2. channelDimScalableFlag is set
3147 scalableChDim = channelDimScalableFlag;
3148 useMasking = true;
3149 }
3150
3151 assert(!(useMasking && flatten) &&
3152 "Unsupported flattened conv with dynamic shapes");
3153
3154 // out{n, w, c}
3155 bindShapeDims(resShapedType, nSize, wSize);
3156
3157 vector::TransferWriteOp write;
3158 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3159
3160 // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3161 // When strideW == 1, we can batch the contiguous loads and avoid
3162 // unrolling
3163 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3164
3165 Type lhsEltType = lhsShapedType.getElementType();
3166 Type rhsEltType = rhsShapedType.getElementType();
3167 Type resEltType = resShapedType.getElementType();
3168 VectorType lhsType = VectorType::get(
3169 {nSize,
3170 // iw = ow * sw + kw * dw - 1
3171 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3172 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3173 cSize},
3174 lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3175 VectorType rhsType =
3176 VectorType::get({kwSize, cSize}, rhsEltType,
3177 /*scalableDims=*/{false, scalableChDim});
3178 VectorType resType =
3179 VectorType::get({nSize, wSize, cSize}, resEltType,
3180 /*scalableDims=*/{false, false, scalableChDim});
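
    // E.g. (assumed values) with nSize = 1, wSize = 8, kwSize = 3,
    // strideW = dilationW = 1, f32 elements and a scalable channel dim of 4,
    // lhsType is vector<1x10x[4]xf32> and resType is vector<1x8x[4]xf32>.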
3181
    // Masks the input xfer Op along the channel dim, iff masking is required
    // (i.e. the channel dim is dynamic).
3184 auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3185 ArrayRef<bool> scalableDims,
3186 Operation *opToMask) {
3187 if (!useMasking)
3188 return opToMask;
3189 auto maskType =
3190 VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3191
3192 SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3193 cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3194
3195 Value maskOp =
3196 rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);
3197
3198 return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3199 };
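
    // Illustrative shape of a masked read produced by the lambda above (exact
    // syntax and types are approximate, not taken from this file):
    //   %mask = vector.create_mask %c1, %c10, %ch : vector<1x10x[4]xi1>
    //   %lhs  = vector.mask %mask { vector.transfer_read ... } : ...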
3200
3201 // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3202 // 0].
3203 Value lhs = rewriter.create<vector::TransferReadOp>(
3204 loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
3205 auto maybeMaskedLhs = maybeMaskXferOp(
3206 lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3207
3208 // Read rhs slice of size {kw, c} @ [0, 0].
3209 Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3210 ValueRange{zero, zero});
3211 auto maybeMaskedRhs = maybeMaskXferOp(
3212 rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3213
3214 // Read res slice of size {n, w, c} @ [0, 0, 0].
3215 Value res = rewriter.create<vector::TransferReadOp>(
3216 loc, resType, resShaped, ValueRange{zero, zero, zero});
3217 auto maybeMaskedRes = maybeMaskXferOp(
3218 resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3219
3220 //===------------------------------------------------------------------===//
3221 // Begin vector-only rewrite part
3222 //===------------------------------------------------------------------===//
3223 // Unroll along kw and read slices of lhs and rhs.
3224 SmallVector<Value> lhsVals, rhsVals, resVals;
3225 auto inOutSliceSizes = SmallVector<int64_t>{nSize, wSizeStep, cSize};
3226 auto inOutStrides = SmallVector<int64_t>{1, 1, 1};
3227
3228 // Extract lhs slice of size {n, wSizeStep, c}
3229 // @ [0, sw * w + dw * kw, 0].
3230 for (int64_t kw = 0; kw < kwSize; ++kw) {
3231 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3232 lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3233 loc, maybeMaskedLhs->getResult(0),
3234 /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3235 inOutSliceSizes, inOutStrides));
3236 }
3237 }
3238 // Extract rhs slice of size {c} @ [kw].
3239 for (int64_t kw = 0; kw < kwSize; ++kw) {
3240 rhsVals.push_back(rewriter.create<vector::ExtractOp>(
3241 loc, maybeMaskedRhs->getResult(0),
3242 /*offsets=*/ArrayRef<int64_t>{kw}));
3243 }
3244 // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3245 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3246 resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
3247 loc, maybeMaskedRes->getResult(0),
3248 /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3249 inOutStrides));
3250 }
3251
3252 auto linearIndex = [&](int64_t kw, int64_t w) {
3253 return kw * (wSize / wSizeStep) + w;
3254 };
3255
3256 // Note - the scalable flags are ignored as flattening combined with
3257 // scalable vectorization is not supported.
3258 auto inOutFlattenSliceSizes =
3259 SmallVector<int64_t>{nSize, wSizeStep * cSize};
3260 auto lhsTypeAfterFlattening =
3261 VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3262 auto resTypeAfterFlattening =
3263 VectorType::get(inOutFlattenSliceSizes, resEltType);
3264
3265 // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3266 for (int64_t kw = 0; kw < kwSize; ++kw) {
3267 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3268 Value lhsVal = lhsVals[linearIndex(kw, w)];
3269 Value resVal = resVals[w];
3270 if (flatten) {
3271 // Flatten the input and output vectors (collapse the channel
3272 // dimension)
3273 lhsVal = rewriter.create<vector::ShapeCastOp>(
3274 loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
3275 resVal = rewriter.create<vector::ShapeCastOp>(
3276 loc, resTypeAfterFlattening, resVals[w]);
3277 }
3278 resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3279 rhsVals[kw], resVal, flatten);
3280 if (flatten) {
3281 // Un-flatten the output vector (restore the channel dimension)
3282 resVals[w] = rewriter.create<vector::ShapeCastOp>(
3283 loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
3284 }
3285 }
3286 }
3287
    // It's possible we failed to create the FMA.
    if (!llvm::all_of(resVals, [](Value v) { return v; })) {
3290 // Manually revert (in reverse order) to avoid leaving a bad IR state.
3291 for (auto &collection :
3292 {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3293 for (Value v : collection)
3294 rewriter.eraseOp(v.getDefiningOp());
3295 return rewriter.notifyMatchFailure(op, "failed to create FMA");
3296 }
3297
3298 // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
3299 // This does not depend on kw.
3300 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3301 maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
3302 loc, resVals[w], maybeMaskedRes->getResult(0),
3303 /*offsets=*/ArrayRef<int64_t>{0, w, 0},
3304 /*strides=*/ArrayRef<int64_t>{1, 1, 1});
3305 }
3306 //===------------------------------------------------------------------===//
3307 // End vector-only rewrite part
3308 //===------------------------------------------------------------------===//
3309
3310 // Write back res slice of size {n, w, c} @ [0, 0, 0].
3311 Operation *resOut = rewriter.create<vector::TransferWriteOp>(
3312 loc, maybeMaskedRes->getResult(0), resShaped,
3313 ValueRange{zero, zero, zero});
3314 return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3315 resOut);
3316 }
3317
3318 /// Lower:
3319 /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
3320 /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
3321 /// to MulAcc.
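  /// E.g. (illustrative types, non-flattened case): lhs vector<1x8x4xf32>,
  /// rhs vector<4xf32> broadcast to vector<1x8x4xf32>, res vector<1x8x4xf32>,
  /// lowered to a single vector.fma.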
3322 Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
3323 Value lhs, Value rhs, Value res,
3324 bool flatten) {
3325 auto rhsTy = cast<ShapedType>(rhs.getType());
3326 auto resTy = cast<ShapedType>(res.getType());
3327
    // TODO(suderman): Change this to use a vector.fma intrinsic.
    lhs = promote(rewriter, loc, lhs, resTy);
3330
3331 if (flatten) {
      // NOTE: The following logic won't work for scalable vectors. For this
      // reason, "flattening" is not supported when shapes are dynamic (this
      // should be captured by one of the pre-conditions).
3335
3336 // There are two options for handling the filter:
3337 // * shape_cast(broadcast(filter))
3338 // * broadcast(shuffle(filter))
3339 // Opt for the option without shape_cast to simplify the codegen.
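      // E.g. (assumed sizes) with rhsSize = 4 and resSize = 8, the shuffle
      // indices become [0, 1, 2, 3, 0, 1, 2, 3], i.e. the filter is repeated
      // along the flattened w * c dimension.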
3340 auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
3341 auto resSize = cast<VectorType>(res.getType()).getShape()[1];
3342
      SmallVector<int64_t, 16> indices;
      for (int i = 0; i < resSize / rhsSize; ++i) {
        for (int j = 0; j < rhsSize; ++j)
          indices.push_back(j);
      }

      rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
3350 }
3351 // Broadcast the filter to match the output vector
3352 rhs = rewriter.create<vector::BroadcastOp>(
3353 loc, resTy.clone(rhsTy.getElementType()), rhs);
3354
    rhs = promote(rewriter, loc, rhs, resTy);
3356
3357 if (!lhs || !rhs)
3358 return nullptr;
3359
3360 if (isa<FloatType>(resTy.getElementType()))
3361 return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
3362
3363 auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
3364 return rewriter.create<arith::AddIOp>(loc, mul, res);
3365 }
3366
3367 /// Entry point for non-channeled convolution:
3368 /// {{w + kw}, {kw}, {w}}
3369 FailureOr<Operation *> generateNonChanneledConv() {
3370 AffineExpr w, kw;
3371 bindDims(ctx, w, kw);
3372 if (!iters({Par(), Red()}))
3373 return rewriter.notifyMatchFailure(op,
3374 "failed to match conv::W 1-par 1-red");
3375
3376 // No transposition needed.
3377 if (layout({/*lhsIndex*/ {w + kw},
3378 /*rhsIndex*/ {kw},
3379 /*resIndex*/ {w}}))
      return conv(Conv1DOpOrder::W);
3381
3382 return rewriter.notifyMatchFailure(op, "not a conv::W layout");
3383 }
3384
3385 /// Entry point that transposes into the common form:
3386 /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3387 FailureOr<Operation *> generateNwcConv() {
3388 AffineExpr n, w, f, kw, c;
3389 bindDims(ctx, n, w, f, kw, c);
3390 if (!iters({Par(), Par(), Par(), Red(), Red()}))
3391 return rewriter.notifyMatchFailure(
3392 op, "failed to match conv::Nwc 3-par 2-red");
3393
3394 // No transposition needed.
3395 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3396 /*rhsIndex*/ {kw, c, f},
3397 /*resIndex*/ {n, w, f}}))
      return conv(Conv1DOpOrder::Nwc);
3399
3400 return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
3401 }
3402
3403 /// Entry point that transposes into the common form:
3404 /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3405 FailureOr<Operation *> generateNcwConv() {
3406 AffineExpr n, w, f, kw, c;
3407 bindDims(ctx, n, f, w, c, kw);
3408 if (!iters({Par(), Par(), Par(), Red(), Red()}))
3409 return rewriter.notifyMatchFailure(
3410 op, "failed to match conv::Ncw 3-par 2-red");
3411
3412 if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3413 /*rhsIndex*/ {f, c, kw},
3414 /*resIndex*/ {n, f, w}}))
      return conv(Conv1DOpOrder::Ncw);
3416
3417 return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
3418 }
3419
3420 /// Entry point that transposes into the common form:
3421 /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
3422 FailureOr<Operation *> generateNwcPooling() {
3423 AffineExpr n, w, c, kw;
3424 bindDims(ctx, n, w, c, kw);
3425 if (!iters({Par(), Par(), Par(), Red()}))
3426 return rewriter.notifyMatchFailure(op,
3427 "failed to match pooling 3-par 1-red");
3428
3429 // No transposition needed.
3430 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3431 /*rhsIndex*/ {kw},
3432 /*resIndex*/ {n, w, c}}))
      return conv(Conv1DOpOrder::Nwc);
3434
3435 return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
3436 }
3437
3438 /// Entry point that transposes into the common form:
3439 /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
3440 FailureOr<Operation *> generateNcwPooling() {
3441 AffineExpr n, w, c, kw;
3442 bindDims(ctx, n, c, w, kw);
3443 if (!iters({Par(), Par(), Par(), Red()}))
3444 return rewriter.notifyMatchFailure(op,
3445 "failed to match pooling 3-par 1-red");
3446
3447 if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
3448 /*rhsIndex*/ {kw},
3449 /*resIndex*/ {n, c, w}}))
      return conv(Conv1DOpOrder::Ncw);
3451
3452 return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
3453 }
3454
3455 /// Entry point that transposes into the common form:
3456 /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3457 FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
3458 bool vecChDimScalableFlag = false,
3459 bool flatten = false) {
3460 AffineExpr n, w, c, kw;
3461 bindDims(ctx, n, w, c, kw);
3462 if (!iters({Par(), Par(), Par(), Red()}))
3463 return rewriter.notifyMatchFailure(
3464 op, "failed to match depthwise::Nwc conv 3-par 1-red");
3465
3466 // No transposition needed.
3467 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3468 /*rhsIndex*/ {kw, c},
3469 /*resIndex*/ {n, w, c}}))
      return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
3471
3472 return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
3473 }
3474
3475private:
3476 enum OperKind { Conv, Pool };
3477 bool valid = false;
3478 OperKind oper = Conv;
3479 StringAttr redOp;
3480 StringAttr poolExtOp;
3481 bool isPoolExt = false;
3482 int strideW, dilationW;
3483 Value lhsShaped, rhsShaped, resShaped;
3484 ShapedType lhsShapedType, rhsShapedType, resShapedType;
3485
3486 // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
3487 // Returns true iff it is a valid conv/pooling op.
3488 // If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
3489 // + yield) and rhs is not used) then it is the body of a pooling
3490 // If conv, check for single `mul` predecessor. The `mul` operands must be
3491 // block arguments or extension of block arguments.
3492 // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
3493 // must be block arguments or extension of block arguments.
3494 bool setOperKind(Operation *reduceOp) {
3495 int numBlockArguments =
        llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
3497 switch (numBlockArguments) {
3498 case 1: {
      // Will be convolution if the feeder is a MulOp.
      // Otherwise, check whether it can be pooling.
      auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                         llvm::IsaPred<BlockArgument>);
      Operation *feedOp = (*feedValIt).getDefiningOp();
      if (isCastOfBlockArgument(feedOp)) {
3505 oper = Pool;
3506 isPoolExt = true;
3507 poolExtOp = feedOp->getName().getIdentifier();
3508 } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
3509 llvm::all_of(feedOp->getOperands(), [](Value v) {
3510 if (isa<BlockArgument>(v))
3511 return true;
3512 if (Operation *op = v.getDefiningOp())
3513 return isCastOfBlockArgument(op);
3514 return false;
3515 }))) {
3516 return false;
3517 }
3518 return true;
3519 }
3520 case 2:
3521 // Must be pooling
3522 oper = Pool;
3523 isPoolExt = false;
3524 return true;
3525 default:
3526 return false;
3527 }
3528 }
3529};
3530} // namespace
3531
3532/// Helper function to vectorize a LinalgOp with convolution semantics.
3533// TODO: extend the generic vectorization to support windows and drop this.
3534static FailureOr<Operation *> vectorizeConvolution(
3535 RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
3536 ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
3537 // The ConvolutionOpInterface gives us guarantees of existence for
3538 // strides/dilations. However, we do not need to rely on those, we can
3539 // simply use them if present, otherwise use the default and let the generic
3540 // conv. matcher in the ConvGenerator succeed or fail.
3541 auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
3542 auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
3543 auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
3544 auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3545 Conv1DGenerator e(rewriter, op, stride, dilation);
3546 auto res = e.generateNonChanneledConv();
3547 if (succeeded(res))
3548 return res;
3549 res = e.generateNwcConv();
3550 if (succeeded(res))
3551 return res;
3552 res = e.generateNcwConv();
3553 if (succeeded(res))
3554 return res;
3555 res = e.generateNwcPooling();
3556 if (succeeded(res))
3557 return res;
3558 res = e.generateNcwPooling();
3559 if (succeeded(res))
3560 return res;
3561
3562 // Only depthwise 1D NWC convs are left - these can be vectorized using masks
3563 // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
3564 // masked/scalable) is the channel dim (i.e. the trailing dim).
3565 uint64_t vecChDimSize = ShapedType::kDynamic;
3566 bool vecChDimScalableFlag = false;
3567 if (!inputVecSizes.empty()) {
3568 // Only use the input vector size corresponding to the channel dim. Other
3569 // vector dims will be inferred from the Ops.
3570 assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
3571 isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
3572 "Not a 1D depthwise conv!");
3573 size_t chDimIdx =
3574 TypeSwitch<Operation *, size_t>(op)
3575 .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
3576 .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
3577
3578 vecChDimSize = inputVecSizes[chDimIdx];
3579 vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
3580 }
  return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
                               flatten1DDepthwiseConv);
3583}
3584
3585struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
3586 using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
3587
3588 LogicalResult matchAndRewrite(LinalgOp op,
3589 PatternRewriter &rewriter) const override {
3590 FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
    if (failed(resultOrFail))
3592 return failure();
3593 Operation *newOp = *resultOrFail;
3594 if (newOp->getNumResults() == 0) {
      rewriter.eraseOp(op.getOperation());
3596 return success();
3597 }
3598 assert(newOp->getNumResults() == 1 && "expected single result");
    rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
3600 return success();
3601 }
3602};
3603
3604void mlir::linalg::populateConvolutionVectorizationPatterns(
3605 RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
3607}
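
// Typical usage (illustrative sketch, not part of this file): the pattern set
// is populated and then driven to a fixed point, e.g.:
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateConvolutionVectorizationPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));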
3608
