//===- VectorTransforms.cpp - Conversion within the Vector dialect -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites as 1->N patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"

#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE "vector-to-vector"

using namespace mlir;
using namespace mlir::vector;

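/// Returns the values in `arrayAttr` as a vector of `IntType`.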
template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

// Helper to find an index in an affine map.
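// For example, for the map (d0, d1, d2) -> (d2, d0), getResultIndex(map, 0)
// returns 1 (d0 appears as result 1) and getResultIndex(map, 1) returns
// std::nullopt (d1 does not appear as a result).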
static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      return i;
  }
  return std::nullopt;
}

namespace {

/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
/// Ex:
/// ```
///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
///   %1 = vector.multi_reduction add, %0 [1]
///     : vector<8x32x16xf32> to vector<8x16xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d2)>],
///        iterator_types = ["parallel", "reduction", "parallel"],
///        kind = add} %arg0, %arg1, %cst_f0
///        : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
/// ```
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.getKind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.getSource().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    SmallVector<bool> reductionMask = reduceOp.getReductionMask();
    auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
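    // E.g. for the doc example above, reductionMask is [false, true, false]:
    // d0 and d2 stay parallel and keep their dim exprs, while d1 becomes the
    // reduction iterator, so the loop below builds dstMap (d0, d1, d2) ->
    // (d0, d2).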
    SmallVector<AffineExpr> exprs;
    SmallVector<vector::IteratorType> iteratorTypes;
    for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(vector::IteratorType::parallel);
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        iteratorTypes.push_back(vector::IteratorType::reduction);
      }
    }
    auto dstMap =
        AffineMap::get(/*dimCount=*/reductionMask.size(),
                       /*symbolCount=*/0, exprs, reduceOp.getContext());
    rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1),
        reduceOp.getAcc(),
        rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
        rewriter.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
              return IteratorTypeAttr::get(rewriter.getContext(), t);
            }))));
    return success();
  }
};

/// Merge LHS/RHS (A/B) TransposeOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.transpose %arg0, [2, 0, 1]
///     : vector<32x16x8xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///        iterator_types = ["parallel", "parallel", "reduction"],
///        kind = add} %0, %arg1, %cst_f0
///        : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///        iterator_types = ["parallel", "parallel", "reduction"],
///        kind = add} %arg0, %arg1, %cst_f0
///        : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractABTranspose final
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      AffineMap permutationMap = AffineMap::getPermutationMap(
          transposeOp.getPermutation(), contractOp.getContext());
      map = inversePermutation(permutationMap).compose(map);
      *operand = transposeOp.getVector();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};

/// Merges accumulator and result transposes into contract.
///
/// For example:
/// ```mlir
/// %accT = vector.transpose %acc, [0, 2, 1]
///   : vector<2x8x4xf32> to vector<2x4x8xf32>
/// %contract = vector.contract {
///   indexing_maps = [
///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
///     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
///   ],
///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
///   kind = #vector.kind<add>
/// } %lhs, %rhs, %accT
///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
/// %0 = vector.transpose %contract, [0, 2, 1]
///   : vector<2x4x8xf32> to vector<2x8x4xf32>
/// ```
/// Becomes:
/// ```mlir
/// %0 = vector.contract {
///   indexing_maps = [
///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
///     affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
///   ],
///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
///   kind = #vector.kind<add>
/// } %lhs, %rhs, %acc
///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x8x4xf32>
/// ```
struct CombineContractResultTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
                                PatternRewriter &rewriter) const override {
    auto contractOp =
        resTOp.getVector().getDefiningOp<vector::ContractionOp>();
    if (!contractOp || !contractOp->hasOneUse())
      return failure();

    auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
    if (!accTOp)
      return failure();

    MLIRContext *context = contractOp.getContext();
    auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
    AffineMap contractMap = maps.back();

    // Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
    // To index into A in contract, we need revert(f)(g(C)) -> A.
    auto accTMap =
        AffineMap::getPermutationMap(accTOp.getPermutation(), context);
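    // E.g. in the doc example above, the accumulator permutation [0, 2, 1]
    // gives accTMap (d0, d1, d2) -> (d0, d2, d1), which is its own inverse.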

    // Contract performs g(C) -> D. Result transpose performs h(D) -> E.
    // To index into E in contract, we need h(g(C)) -> E.
    auto resTMap =
        AffineMap::getPermutationMap(resTOp.getPermutation(), context);
    auto combinedResMap = resTMap.compose(contractMap);

    // The accumulator and result share the same indexing map, so the two
    // transposes can only be merged if they cancel out: combinedResMap must
    // equal inversePermutation(accTMap).compose(contractMap), which holds
    // exactly when resTMap is the inverse of accTMap.
    if (inversePermutation(accTMap) != resTMap)
      return failure();
    maps.back() = combinedResMap;

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};

/// Merge BroadcastOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///        iterator_types = ["parallel", "parallel", "reduction"],
///        kind = add} %0, %arg1, %cst_f0
///        : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///        iterator_types = ["parallel", "parallel", "reduction"],
///        kind = add} %arg0, %arg1, %cst_f0
///        : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
///
/// For masked vector.contract, the mask requires updating when a dimension is
/// dropped. In such cases, the dropped dimensions must correspond to the
/// mask's leading unit dimensions. More general cases (e.g. non-unit dims)
/// are not supported yet.
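/// When leading unit dims are dropped, the mask is adjusted with a
/// `vector.shape_cast`, e.g. from `vector<1x8x4xi1>` to `vector<8x4xi1>`
/// (an illustrative shape; see the mask handling below).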
FailureOr<Value> combineContractAndBroadcast(vector::ContractionOp contractOp,
                                             MaskingOpInterface maskingOp,
                                             PatternRewriter &rewriter) {
  SmallVector<AffineMap> maps =
      llvm::to_vector<4>(contractOp.getIndexingMapsArray());
  Value lhs = contractOp.getLhs();
  Value rhs = contractOp.getRhs();
  size_t index = 0;
  bool changed = false;
  for (Value *operand : {&lhs, &rhs}) {
    AffineMap &map = maps[index++];
    auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
    if (!broadcast)
      continue;
    // vector.contract can only take vectors as operands.
    auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
    if (!srcType ||
        srcType.getRank() == broadcast.getResultVectorType().getRank())
      continue;
    int64_t rankDiff =
        broadcast.getResultVectorType().getRank() - srcType.getRank();
    bool innerDimBroadcast = false;
    SmallVector<AffineExpr> originalDims;
    for (const auto &dim : llvm::enumerate(srcType.getShape())) {
      if (dim.value() !=
          broadcast.getResultVectorType().getDimSize(rankDiff + dim.index())) {
        innerDimBroadcast = true;
        break;
      }
      originalDims.push_back(
          rewriter.getAffineDimExpr(dim.index() + rankDiff));
    }
    // Contract doesn't support inner dimension broadcast. Once this is
    // relaxed we can remove this case.
    if (innerDimBroadcast)
      continue;

    // It would be incorrect to fold a broadcast onto a reduction dimension
    // of non-unit size.
    bool nonUnitDimReductionBroadcast = false;
    for (int64_t i = 0; i < rankDiff; ++i) {
      if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
          isReductionIterator(contractOp.getIteratorTypes()
                                  .getValue()[map.getDimPosition(i)])) {
        nonUnitDimReductionBroadcast = true;
        break;
      }
    }
    if (nonUnitDimReductionBroadcast)
      continue;

    AffineMap broadcastMap =
        AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
                       originalDims, contractOp.getContext());
    map = broadcastMap.compose(map);
    *operand = broadcast.getSource();
    changed = true;
  }

  if (!changed)
    return failure();

  // Determine which dims are unused, now that the maps have been composed
  // with the broadcast maps.
  llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
  // Compress unused dims.
  for (auto &m : maps)
    m = compressDims(m, unusedDimsBitVector);
  // Compute the combined iterators.
  SmallVector<Attribute> iterators;
  for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
    if (!unusedDimsBitVector.test(i))
      iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
  }

  // Check whether any of the unused dims is non-unit, e.g.:
  //  * vector.broadcast %arg0 : vector<8x4xi32> to vector<2x8x4xi32>
  // This is only required when collapsing a mask. If there is no mask, skip.
  VectorType oldMaskType;
  bool isAnyUnusedDimNonUnit = false;
  if (maskingOp) {
    oldMaskType = cast<VectorType>(maskingOp.getMask().getType());
    for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
      if (unusedDimsBitVector.test(i) && oldMaskType.getShape()[i] != 1) {
        isAnyUnusedDimNonUnit = true;
        break;
      }
    }
  }

  // Check that compressing unused dims isn't removing all reduction dimension
  // pairs. For example, if the vector.contract had only one reduction
  // iterator and that was a unit-dimension created by a broadcast,
  // then we should bail here, otherwise we would create a contract without
  // a reduction dimension pair.
  bool hasReductionIteratorApplyingOnBothSides = false;
  for (unsigned i = 0; i < iterators.size(); ++i) {
    if (!isReductionIterator(iterators[i]))
      continue;
    if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
      hasReductionIteratorApplyingOnBothSides = true;
      break;
    }
  }
  if (!hasReductionIteratorApplyingOnBothSides)
    return failure();

  // If the compressed maps have a dimension that is not used by either LHS or
  // RHS then the ContractionOp verifier would fail.
  if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
    return failure();

  Operation *newOp = rewriter.create<vector::ContractionOp>(
      contractOp.getLoc(), lhs, rhs, contractOp.getAcc(),
      rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));

  // Handle the mask.
  if (maskingOp) {
    if (isAnyUnusedDimNonUnit)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Cannot drop non-unit mask dim.");
    assert(unusedDimsBitVector.size() ==
               static_cast<size_t>(oldMaskType.getRank()) &&
           "The mask rank is incorrect!");

    // If a dimension has been dropped, update the mask accordingly.
    // Otherwise, keep it as is.
    Value mask = maskingOp.getMask();
    if (unusedDimsBitVector.count() != 0) {
      // At this point, two assumptions are made:
      //  * The unused dimensions are the leading mask dimensions
      //    (vector.contract does not support inner dim broadcasting).
      //  * The unused dimensions are all unit.
      // These conditions are effectively verified in the blocks preceding
      // this one.
      auto newShape =
          oldMaskType.getShape().drop_front(unusedDimsBitVector.count());
      auto newShapeScalableDims =
          oldMaskType.getScalableDims().drop_front(
              unusedDimsBitVector.count());
      VectorType maskOpType = VectorType::get(newShape, rewriter.getI1Type(),
                                              newShapeScalableDims);
      mask = rewriter
                 .create<vector::ShapeCastOp>(contractOp.getLoc(), maskOpType,
                                              maskingOp.getMask())
                 .getResult();
    }

    newOp = mlir::vector::maskOperation(rewriter, newOp, mask);
  }
  return newOp->getResult(0);
}

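/// Rewrite pattern wrapping `combineContractAndBroadcast` above; it handles
/// both plain `vector.contract` ops and ones wrapped in a `vector.mask`.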
struct CombineContractBroadcastMask
    : public MaskableOpRewritePattern<vector::ContractionOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
                            MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    return combineContractAndBroadcast(contractOp, maskingOp, rewriter);
  }
};

/// Reorders cast(broadcast) to broadcast(cast). This moves broadcast ops and
/// contraction ops closer to each other, which kicks in the
/// CombineContractBroadcast pattern when casting ops sit between these
/// operations.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
/// ```
/// Gets converted to:
/// ```
///   %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
///   %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
/// ```
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
      castResTy = vecTy.clone(castResTy);
    auto *castOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        bcastOp.getSource(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
    return success();
  }
};

/// Reorders elementwise(transpose) to transpose(elementwise). This moves
/// transpose ops and contraction ops closer to each other, which kicks in the
/// CombineContractABTranspose pattern when elementwise ops sit between these
/// operations. Ex:
/// ```
///   %at = vector.transpose %a, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
///   %bt = vector.transpose %b, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
///   %r = arith.addf %at, %bt : vector<2x4xf32>
/// ```
/// Gets converted to:
/// ```
///   %0 = arith.addf %a, %b : vector<4x2xf32>
///   %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
/// ```
struct ReorderElementwiseOpsOnTranspose final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1 || op->getNumRegions() != 0)
      return failure();

    // Make sure all operands are transpose/constant ops and collect their
    // transposition maps.
    SmallVector<ArrayRef<int64_t>> transposeMaps;
    transposeMaps.reserve(op->getNumOperands());
    // Record the initial type before transposition. We'll use its shape later.
    // Any type will do here as we will check all transpose maps are the same.
    VectorType srcType;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        transposeMaps.push_back(transposeOp.getPermutation());
        srcType = transposeOp.getSourceVectorType();
      } else if (!matchPattern(operand, m_Constant())) {
        return failure();
      }
    }
    if (transposeMaps.empty())
      return failure();
    // This is an elementwise op, so all transposed operands should have the
    // same type. We need to additionally check that all transposes use the
    // same map.
    if (!llvm::all_equal(transposeMaps))
      return rewriter.notifyMatchFailure(op, "different transpose map");

    SmallVector<Value> srcValues;
    srcValues.reserve(op->getNumOperands());

    // If there are constant operands, we need to insert inverse transposes for
    // them. Calculate the inverse order first.
    auto order = transposeMaps.front();
    SmallVector<int64_t> invOrder(order.size());
    for (int i = 0, e = order.size(); i < e; ++i)
      invOrder[order[i]] = i;

    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        srcValues.push_back(transposeOp.getVector());
      } else {
        // This is a constant. Create a reverse transpose op for it.
        auto vectorType = srcType.clone(
            cast<VectorType>(operand.getType()).getElementType());
        srcValues.push_back(rewriter.create<vector::TransposeOp>(
            operand.getLoc(), vectorType, operand, invOrder));
      }
    }

    auto vectorType = srcType.clone(
        cast<VectorType>(op->getResultTypes()[0]).getElementType());
    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        vectorType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        op, op->getResultTypes()[0], elementwiseOp->getResult(0),
        transposeMaps.front());
    return success();
  }
};

// Returns the values in `arrayAttr` as an integer vector.
static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(
      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
                      [](IntegerAttr attr) { return attr.getInt(); }));
}

// Shuffles vector.bitcast op after vector.extract op.
//
// This transforms IR like:
//   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
//   %1 = vector.extract %0[3] : f16 from vector<8xf16>
// Into:
//   %0 = vector.extract %src[1] : f32 from vector<4xf32>
//   %1 = vector.bitcast %0 : vector<1xf32> to vector<2xf16>
//   %2 = vector.extract %1[1] : f16 from vector<2xf16>
struct BubbleDownVectorBitCastForExtract
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Only support extracting scalars for now.
    if (extractOp.getSourceVectorType().getRank() != 1)
      return failure();

    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // Fail to match if we only have one element in the cast op source.
    // This is to avoid infinite loop given that this pattern can generate
    // such cases.
    if (castSrcType.getNumElements() == 1)
      return failure();

    // Only support casting to a larger number of elements for now.
    // E.g., vector<4xf32> -> vector<8xf16>.
    if (castSrcType.getNumElements() > castDstType.getNumElements())
      return failure();

    unsigned expandRatio =
        castDstType.getNumElements() / castSrcType.getNumElements();
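    // E.g. for vector<4xf32> -> vector<8xf16>, expandRatio is 2: f16 element
    // i lives in f32 element i / 2, at sub-position i % 2 within it.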

    // Get the first element of the mixed position as integer.
    auto mixedPos = extractOp.getMixedPosition();
    if (mixedPos.size() > 0 && !isa<Attribute>(mixedPos[0]))
      return failure();
    uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();

    // Get the single scalar (as a vector) in the source value that packs the
    // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>.
    Location loc = extractOp.getLoc();
    Value packedValue = rewriter.create<vector::ExtractOp>(
        loc, castOp.getSource(), index / expandRatio);
    Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, packedVecType, rewriter.getZeroAttr(packedVecType));
    packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero,
                                                    /*position=*/0);

    // Cast it to a vector with the desired scalar's type.
    // E.g. f32 -> vector<2xf16>
    VectorType packedType =
        VectorType::get({expandRatio}, castDstType.getElementType());
    Value castedValue =
        rewriter.create<vector::BitCastOp>(loc, packedType, packedValue);

    // Finally extract the desired scalar.
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
                                                   index % expandRatio);
    return success();
  }
};

// Shuffles vector.bitcast op after vector.extract_strided_slice op.
//
// This transforms IR like:
//   %cast = vector.bitcast %arg0 : vector<4xf32> to vector<8xf16>
//   %0 = vector.extract_strided_slice %cast {
//          offsets = [4], sizes = [4], strides = [1]
//        } : vector<8xf16> to vector<4xf16>
// Into:
//   %0 = vector.extract_strided_slice %src {
//          offsets = [2], sizes = [2], strides = [1]
//        } : vector<4xf32> to vector<2xf32>
//   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
struct BubbleDownBitCastForStridedSliceExtract
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to more elements for now; other cases to be implemented.
    if (castSrcLastDim > castDstLastDim)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = extractOp.getSourceVectorType().getRank();
    assert(castDstLastDim % castSrcLastDim == 0);
    int64_t expandRatio = castDstLastDim / castSrcLastDim;

    // If we have fewer offsets than the rank, then implicitly we are
    // selecting the full range for the last bitcasted dimension; other
    // dimensions aren't affected. Otherwise, we need to scale down the last
    // dimension's offset given we are extracting from fewer elements now.
    ArrayAttr newOffsets = extractOp.getOffsets();
    if (newOffsets.size() == rank) {
      SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
      if (offsets.back() % expandRatio != 0)
        return failure();
      offsets.back() = offsets.back() / expandRatio;
      newOffsets = rewriter.getI64ArrayAttr(offsets);
    }

    // Similarly for sizes.
    ArrayAttr newSizes = extractOp.getSizes();
    if (newSizes.size() == rank) {
      SmallVector<int64_t> sizes = getIntValueVector(newSizes);
      if (sizes.back() % expandRatio != 0)
        return failure();
      sizes.back() = sizes.back() / expandRatio;
      newSizes = rewriter.getI64ArrayAttr(sizes);
    }

    SmallVector<int64_t> dims =
        llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
    dims.back() = dims.back() / expandRatio;
    VectorType newExtractType =
        VectorType::get(dims, castSrcType.getElementType());

    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
        extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
        newSizes, extractOp.getStrides());

    rewriter.replaceOpWithNewOp<vector::BitCastOp>(
        extractOp, extractOp.getType(), newExtractOp);

    return success();
  }
};

// Shuffles vector.bitcast op before vector.insert op.
//
// This transforms IR like:
//   %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4>
//   %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
// Into:
//   %0 = vector.bitcast %val : vector<32xi4> to vector<16xi8>
//   %1 = vector.bitcast %dst : vector<8x32xi4> to vector<8x16xi8>
//   %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8>
//
struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();

    // 0-D and scalable vectors are not supported yet.
    if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
        castDstType.isScalable())
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
    int64_t ratio;
    if (isNumElemsShrink) {
      assert(castSrcLastDim % castDstLastDim == 0);
      ratio = castSrcLastDim / castDstLastDim;
    } else {
      assert(castDstLastDim % castSrcLastDim == 0);
      ratio = castDstLastDim / castSrcLastDim;
    }
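    // E.g. for the doc example above (vector<8x32xi4> to vector<8x16xi8>),
    // the last dims are 32 and 16, so the element count shrinks by ratio 2.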

    auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
    if (!insertOp)
      return failure();

    // Only vector sources are supported for now.
    auto insertSrcType = dyn_cast<VectorType>(insertOp.getValueToStoreType());
    if (!insertSrcType)
      return failure();

    // Bitcast the source.
    SmallVector<int64_t> srcDims(insertSrcType.getShape());
    srcDims.back() =
        isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());
    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getValueToStore());

    SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
    dstDims.back() =
        isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());

    // Bitcast the destination.
    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    // Generate the new insert.
    rewriter.replaceOpWithNewOp<vector::InsertOp>(
        bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
    return success();
  }
};

// Shuffles vector.bitcast op before vector.insert_strided_slice op.
//
// This transforms IR like:
//   %0 = vector.insert_strided_slice %src, %dst {
//          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
//   %1 = vector.bitcast %0 : vector<8xf16> to vector<4xf32>
// Into:
//   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
//   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
//   %2 = vector.insert_strided_slice %0, %1 {
//          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
struct BubbleUpBitCastForStridedSliceInsert
    : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    // Skip 0-D vectors, which cannot come from InsertStridedSliceOp.
    if (castSrcType.getRank() == 0)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be
    // implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;

    auto insertOp =
        bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
    if (!insertOp)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = insertOp.getSourceVectorType().getRank();
    // Require the insert op to have the same rank for the source and
    // destination vector; other cases to be implemented.
    if (rank != insertOp.getDestVectorType().getRank())
      return failure();

    // Requires that the shape of the insert op source is castable to dstType.
    unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
    unsigned destinationWidth =
        castDstType.getElementType().getIntOrFloatBitWidth();
    unsigned numElements = destinationWidth / sourceWidth;
    if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
      return failure();

    ArrayAttr newOffsets = insertOp.getOffsets();
    assert(newOffsets.size() == rank);
    SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
    if (offsets.back() % shrinkRatio != 0)
      return failure();
    offsets.back() = offsets.back() / shrinkRatio;
    newOffsets = rewriter.getI64ArrayAttr(offsets);

    SmallVector<int64_t> srcDims =
        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
    srcDims.back() = srcDims.back() / shrinkRatio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());

    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getValueToStore());

    SmallVector<int64_t> dstDims =
        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
    dstDims.back() = dstDims.back() / shrinkRatio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());

    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
        insertOp.getStrides());

    return success();
  }
};

// Breaks down a vector.bitcast op.
//
// This transforms IR like:
//   %1 = vector.bitcast %0 : vector<8xf16> to vector<4xf32>
// Into:
//   %cst = vector.splat %c0_f32 : vector<4xf32>
//   %1 = vector.extract_strided_slice %0 {
//          offsets = [0], sizes = [4], strides = [1]
//        } : vector<8xf16> to vector<4xf16>
//   %2 = vector.bitcast %1 : vector<4xf16> to vector<2xf32>
//   %4 = vector.insert_strided_slice %2, %cst {
//          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
//   %5 = vector.extract_strided_slice %0 {
//          offsets = [4], sizes = [4], strides = [1]
//        } : vector<8xf16> to vector<4xf16>
//   %6 = vector.bitcast %5 : vector<4xf16> to vector<2xf32>
//   %7 = vector.insert_strided_slice %6, %cst {
//          offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

public:
  BreakDownVectorBitCast(MLIRContext *context,
                         std::function<bool(vector::BitCastOp)> controlFn,
                         PatternBenefit benefit)
      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {

    if (controlFn && !controlFn(bitcastOp))
      return failure();

    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // This transformation builds on top of
    // vector.{extract|insert}_strided_slice, which do not support
    // extracting/inserting "scalable sub-vectors". Bail out.
    if (castSrcType.isScalable())
      return rewriter.notifyMatchFailure(bitcastOp,
                                         "Scalable vectors are not supported");

    // Only support the rank 1 case for now.
    if (castSrcType.getRank() != 1)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be
    // implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
    // Nothing to do if it is already bitcasting to a single element.
    if (castSrcLastDim == shrinkRatio)
      return failure();

    Location loc = bitcastOp.getLoc();
    Type elemType = castDstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, castDstType, zero);

    SmallVector<int64_t> sliceShape = {castDstLastDim};
    SmallVector<int64_t> strides = {1};
    VectorType newCastDstType =
        VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
                        castDstType.getElementType());

    for (int i = 0, e = shrinkRatio; i < e; ++i) {
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, bitcastOp.getSource(), ArrayRef<int64_t>{i * castDstLastDim},
          sliceShape, strides);
      Value bitcast =
          rewriter.create<BitCastOp>(loc, newCastDstType, extracted);
      res = rewriter.create<InsertStridedSliceOp>(
          loc, bitcast, res,
          ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
    }
    rewriter.replaceOp(bitcastOp, res);
    return success();
  }

private:
  std::function<bool(BitCastOp)> controlFn;
};

/// Reorders elementwise(broadcast/splat) to broadcast(elementwise).
///
/// Example:
/// ```
///   %a = vector.broadcast %arg1 : index to vector<1x4xindex>
///   %b = vector.broadcast %arg2 : index to vector<1x4xindex>
///   %r = arith.addi %a, %b : vector<1x4xindex>
/// ```
/// Gets converted to:
/// ```
///   %r = arith.addi %arg1, %arg2 : index
///   %b = vector.broadcast %r : index to vector<1x4xindex>
/// ```
///
/// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
/// ops.
struct ReorderElementwiseOpsOnBroadcast final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1)
      return failure();
    if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
      return failure();
    if (!OpTrait::hasElementwiseMappableTraits(op))
      return rewriter.notifyMatchFailure(
          op, "Op doesn't have ElementwiseMappableTraits");
    if (op->getNumOperands() == 0)
      return failure();
    if (op->getResults()[0].getType() != op->getOperand(0).getType())
      return rewriter.notifyMatchFailure(op,
                                         "result and operand type mismatch");
    if (isa<vector::FMAOp>(op)) {
      return rewriter.notifyMatchFailure(
          op,
          "Op only accepts vector types - not supported as broadcast source "
          "might be a scalar");
    }

    // Get the type of the lhs operand.
    auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
    if (!lhsBcastOrSplat ||
        !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
      return failure();
    auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();

    // Make sure that all operands are broadcast from identical types:
    //  * scalar (`vector.broadcast` + `vector.splat`), or
    //  * vector (`vector.broadcast`).
    // Otherwise the re-ordering wouldn't be safe.
    if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
          auto bcast = val.getDefiningOp<vector::BroadcastOp>();
          if (bcast)
            return (bcast.getOperand().getType() == lhsBcastOrSplatType);
          auto splat = val.getDefiningOp<vector::SplatOp>();
          if (splat)
            return (splat.getOperand().getType() == lhsBcastOrSplatType);
          return false;
        })) {
      return failure();
    }

    // Collect the source values before broadcasting.
    SmallVector<Value> srcValues;
    srcValues.reserve(op->getNumOperands());
    for (Value operand : op->getOperands()) {
      srcValues.push_back(operand.getDefiningOp()->getOperand(0));
    }

    // Create the "elementwise" Op.
    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        lhsBcastOrSplatType, op->getAttrs());

    // Replace the original Op with the elementwise Op.
    auto vectorType = op->getResultTypes()[0];
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, vectorType, elementwiseOp->getResults());

    return success();
  }
};

/// Pattern to rewrite ExtractOp(Elementwise) -> Elementwise(ExtractOp).
/// This may result in cleaner code when extracting a single value from a
/// multi-element vector, and also helps canonicalize 1-element vectors to
/// scalars.
///
/// Example:
/// ```
///   %0 = arith.addf %arg0, %arg1 : vector<4xf32>
///   %1 = vector.extract %0[1] : f32 from vector<4xf32>
/// ```
/// Gets converted to:
/// ```
///   %0 = vector.extract %arg0[1] : f32 from vector<4xf32>
///   %1 = vector.extract %arg1[1] : f32 from vector<4xf32>
///   %2 = arith.addf %0, %1 : f32
/// ```
class ExtractOpFromElementwise final
    : public OpRewritePattern<vector::ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    Operation *eltwise = op.getVector().getDefiningOp();

    // TODO: vector::FMAOp is not ElementwiseMappable even if it claims to be,
    // as it doesn't support scalars.
    if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
        isa<vector::FMAOp>(eltwise))
      return rewriter.notifyMatchFailure(op, "not an elementwise op");

    if (eltwise->getNumResults() != 1)
      return rewriter.notifyMatchFailure(op, "expected single result");

    if (!eltwise->hasOneUse())
      return rewriter.notifyMatchFailure(op, "expected single op use");

    if (!llvm::all_equal(eltwise->getOperandTypes()))
      return rewriter.notifyMatchFailure(op, "operand types are different");

    // Dynamic position can cause dominance issues, so conservatively fail for
    // now.
    if (!op.getDynamicPosition().empty())
      return rewriter.notifyMatchFailure(
          op, "dynamic position not yet implemented");

    Type dstType = op.getType();

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(eltwise);

    IRMapping mapping;
    Location loc = eltwise->getLoc();
    SmallVector<OpFoldResult> pos = op.getMixedPosition();
    for (Value arg : eltwise->getOperands()) {
      Value newArg = rewriter.create<vector::ExtractOp>(loc, arg, pos);
      mapping.map(arg, newArg);
    }

    Operation *newEltwise = rewriter.clone(*eltwise, mapping);
    newEltwise->getResult(0).setType(dstType);

    rewriter.replaceOp(op, newEltwise);
    rewriter.eraseOp(eltwise);
    return success();
  }
};

/// Check if the element type is suitable for vector.load/store sinking.
/// The element type must be index or a byte-aligned integer or floating-point
/// type.
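/// For example, i8, i32, f16 and index qualify, while i1 or i4 do not.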
static bool isSupportedMemSinkElementType(Type type) {
  if (isa<IndexType>(type))
    return true;

  return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
}

/// Pattern to rewrite `vector.extract(vector.load) -> vector/memref.load`.
/// Only index and byte-aligned integer and floating-point element types are
/// supported for now.
///
/// Example:
/// ```
///   %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
///   %1 = vector.extract %0[1] : f32 from vector<4xf32>
/// ```
/// Gets converted to:
/// ```
///   %c1 = arith.constant 1 : index
///   %0 = arith.addi %arg1, %c1 overflow<nsw> : index
///   %1 = memref.load %arg0[%0] : memref<?xf32>
/// ```
class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
    if (!loadOp)
      return rewriter.notifyMatchFailure(op, "expected a load op");

    // Checking for single use so we won't duplicate load ops.
    if (!loadOp->hasOneUse())
      return rewriter.notifyMatchFailure(op, "expected single op use");

    VectorType loadVecType = loadOp.getVectorType();
    if (loadVecType.isScalable())
      return rewriter.notifyMatchFailure(op,
                                         "scalable vectors are not supported");

    MemRefType memType = loadOp.getMemRefType();

    // Non-byte-aligned types are tricky and may require special handling,
    // ignore them for now.
    if (!isSupportedMemSinkElementType(memType.getElementType()))
      return rewriter.notifyMatchFailure(op, "unsupported element type");

    int64_t rankOffset = memType.getRank() - loadVecType.getRank();
    if (rankOffset < 0)
      return rewriter.notifyMatchFailure(op, "unsupported ranks combination");

    auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
    int64_t finalRank = 0;
    if (extractVecType)
      finalRank = extractVecType.getRank();

    SmallVector<Value> indices = loadOp.getIndices();
    SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
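    // E.g. for the doc example above, the extract position [1] gets folded
    // into the load index %arg1, turning the vector.load into a memref.load
    // at [%arg1 + 1].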

    // There may be memory stores between the load and the extract op, so we
    // need to make sure that the new load op is inserted at the same place as
    // the original load op.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(loadOp);
    Location loc = loadOp.getLoc();
    ArithIndexingBuilder idxBuilderf(rewriter, loc);
    for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
      OpFoldResult pos = extractPos[i - rankOffset];
      if (isZeroInteger(pos))
        continue;

      Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
      indices[i] = idxBuilderf.add(indices[i], offset);
    }

    Value base = loadOp.getBase();
    if (extractVecType) {
      rewriter.replaceOpWithNewOp<vector::LoadOp>(op, extractVecType, base,
                                                  indices);
    } else {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
    }
    // We checked for single use so we can safely erase the load op.
    rewriter.eraseOp(loadOp);
    return success();
  }
};

/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
///
/// Example:
/// ```
///   %0 = vector.splat %arg2 : vector<1xf32>
///   vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
/// ```
/// Gets converted to:
/// ```
///   memref.store %arg2, %arg0[%arg1] : memref<?xf32>
/// ```
class StoreOpFromSplatOrBroadcast final
    : public OpRewritePattern<vector::StoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp op,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = op.getVectorType();
    if (vecType.isScalable())
      return rewriter.notifyMatchFailure(op,
                                         "scalable vectors are not supported");

    if (isa<VectorType>(op.getMemRefType().getElementType()))
      return rewriter.notifyMatchFailure(
          op, "memrefs of vectors are not supported");

    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(
          op, "only 1-element vectors are supported");

    Operation *splat = op.getValueToStore().getDefiningOp();
    if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
      return rewriter.notifyMatchFailure(op,
                                         "neither a splat nor a broadcast");

    // Checking for single use so we can remove splat.
    if (!splat->hasOneUse())
      return rewriter.notifyMatchFailure(op, "expected single op use");

    Value source = splat->getOperand(0);
    Value base = op.getBase();
    ValueRange indices = op.getIndices();

    if (isa<VectorType>(source.getType())) {
      rewriter.replaceOpWithNewOp<vector::StoreOp>(op, source, base, indices);
    } else {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
    }
    rewriter.eraseOp(splat);
    return success();
  }
};

// Helper that returns a vector comparison that constructs a mask:
//   mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
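// For example, with n = 4, o = 1 and b = 3:
//   [0,1,2,3] + [1,1,1,1] = [1,2,3,4] < [3,3,3,3] = [1,1,0,0]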
//
// If `dim == 0` then the result will be a 0-D vector.
//
// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
//       much more compact, IR for this operation, but LLVM eventually
//       generates more elaborate instructions for this intrinsic since it
//       is very conservative on the boundary conditions.
static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
                                   bool force32BitVectorIndices, int64_t dim,
                                   Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // If we can assume all indices fit in 32-bit, we perform the vector
  // comparison in 32-bit to get a higher degree of SIMD parallelism.
  // Otherwise we perform the vector comparison using 64-bit indices.
  Type idxType =
      force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
  DenseIntElementsAttr indicesAttr;
  if (dim == 0 && force32BitVectorIndices) {
    indicesAttr = DenseIntElementsAttr::get(
        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
  } else if (dim == 0) {
    indicesAttr = DenseIntElementsAttr::get(
        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
  } else if (force32BitVectorIndices) {
    indicesAttr = rewriter.getI32VectorAttr(
        llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
  } else {
    indicesAttr = rewriter.getI64VectorAttr(
        llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
  }
  Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
  // Add in an offset if requested.
  if (off) {
    Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
    Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
    indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
  }
  // Construct the vector comparison.
  Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
  Value bounds =
      rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                        indices, bounds);
}

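/// Materializes the out-of-bounds handling of a 1-D vector transfer op as an
/// explicit in-bounds mask. A sketch of the effect (assuming a transfer_read
/// with a single potentially out-of-bounds dimension):
/// ```
///   %v = vector.transfer_read %src[%i], %pad : memref<?xf32>, vector<4xf32>
/// ```
/// in essence becomes:
/// ```
///   %d = memref.dim %src, %c0 : memref<?xf32>
///   %b = arith.subi %d, %i : index
///   %mask = vector.create_mask %b : vector<4xi1>
///   %v = vector.transfer_read %src[%i], %pad, %mask {in_bounds = [true]}
///       : memref<?xf32>, vector<4xf32>
/// ```
/// Any pre-existing mask is intersected with the in-bounds mask.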
template <typename ConcreteOp>
struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
public:
  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
                                   PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(ConcreteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp.hasOutOfBoundsDim())
      return failure();

    if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
      return failure();

    Location loc = xferOp->getLoc();
    VectorType vtp = xferOp.getVectorType();

    // Create the in-bounds mask with all elements between [0 .. dim - offset)
    // set and [dim - offset .. vector_length) unset.
    //
    // TODO: when the leaf transfer rank is k > 1, we need the last `k`
    //       dimensions here.
    unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
    Value off = xferOp.getIndices()[lastIndex];
    Value dim =
        vector::createOrFoldDimOp(rewriter, loc, xferOp.getBase(), lastIndex);
    Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
    Value mask = rewriter.create<vector::CreateMaskOp>(
        loc,
        VectorType::get(vtp.getShape(), rewriter.getI1Type(),
                        vtp.getScalableDims()),
        b);
    if (xferOp.getMask()) {
      // Intersect the in-bounds with the mask specified as an op parameter.
      mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
    }

    rewriter.modifyOpInPlace(xferOp, [&]() {
      xferOp.getMaskMutable().assign(mask);
      xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    });

    return success();
  }

private:
  const bool force32BitVectorIndices;
};

/// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
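/// E.g. a `vector.create_mask %b : vector<4xi1>` is rewritten to the
/// comparison built by `buildVectorComparison` above, essentially
/// `arith.cmpi slt, [0, 1, 2, 3], splat(%b)`.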
class VectorCreateMaskOpConversion
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        bool enableIndexOpt,
                                        PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    if (cast<VectorType>(dstType).isScalable())
      return failure();
    int64_t rank = dstType.getRank();
    if (rank > 1)
      return failure();
    rewriter.replaceOp(
        op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
                                  rank == 0 ? 0 : dstType.getDimSize(0),
                                  op.getOperand(0)));
    return success();
  }

private:
  const bool force32BitVectorIndices;
};

/// Returns true if all the `i1` elements of `constantOp` are set to `value`.
static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
  auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
  // TODO: Support non-dense constant.
  if (!denseAttr)
    return false;

  assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
  return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
}
1400
1401/// Folds a select operation between an all-true and all-false vector. For now,
1402/// only single element vectors (i.e., vector<1xi1>) are supported. That is:
1403///
1404/// %true = arith.constant dense<true> : vector<1xi1>
1405/// %false = arith.constant dense<false> : vector<1xi1>
1406/// %result = arith.select %cond, %true, %false : i1, vector<1xi1>
1407/// =>
1408/// %result = vector.broadcast %cond : i1 to vector<1xi1>
1409///
1410/// InstCombine seems to handle vectors with multiple elements but not the
1411/// single element ones.
1412struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
1413 using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1414
1415 LogicalResult matchAndRewrite(arith::SelectOp selectOp,
1416 PatternRewriter &rewriter) const override {
1417 auto vecType = dyn_cast<VectorType>(Val: selectOp.getType());
1418 if (!vecType || !vecType.getElementType().isInteger(width: 1))
1419 return failure();
1420
1421 // Only scalar conditions can be folded.
1422 Value cond = selectOp.getCondition();
1423 if (isa<VectorType>(Val: cond.getType()))
1424 return failure();
1425
1426 // TODO: Support n-D and scalable vectors.
1427 if (vecType.getRank() != 1 || vecType.isScalable())
1428 return failure();
1429
1430 // TODO: Support vectors with multiple elements.
1431 if (vecType.getShape()[0] != 1)
1432 return failure();
1433
1434 auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
1435 if (!trueConst || !allI1ConstantValuesSetTo(constantOp: trueConst, value: true))
1436 return failure();
1437
1438 auto falseConst =
1439 selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
1440 if (!falseConst || !allI1ConstantValuesSetTo(constantOp: falseConst, value: false))
1441 return failure();
1442
1443 // Replace select with its condition broadcasted to single element vector.
1444 auto elemType = rewriter.getIntegerType(width: vecType.getNumElements());
1445 auto bcastType = VectorType::get(/*shape=*/{1}, elementType: elemType);
1446 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op: selectOp, args&: bcastType, args&: cond);
1447 return success();
1448 }
1449};

/// Returns the number of dims that can be folded away from transfer ops. It
/// returns a failure if it cannot determine the number of dims to be folded.
///
/// Ex 1: returns "2" if `srcType` is memref<512x16x1x1xf32> and
///       `vectorType` is vector<16x16x1x1xf32>
///       (the two innermost dims can be dropped by memref.subview ops)
///
/// Ex 2: returns "1" if `srcType` is memref<512x16x1x1xf32> with
///       [8192, 16, 8, 1] strides and `vectorType` is vector<16x16x1x1xf32>
///       (only the innermost unit dim of `srcType` can be dropped)
///
/// Ex 3: returns "0" if `srcType` is memref<512x16x1x1xf32> and
///       `vectorType` is vector<16x16x1x[1]xf32>
///       (the innermost dim in `vectorType` is not a unit dim; it's a
///       "scalable unit" dim)
static FailureOr<size_t>
getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
  SmallVector<int64_t> srcStrides;
  int64_t srcOffset;
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
    return failure();

  auto isUnitDim = [](VectorType type, int dim) {
    return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
  };

  // According to vector.transfer_read/write semantics, the vector can be a
  // slice. Thus, we have to offset the check index with `rankDiff` in
  // `srcStrides` and source dim sizes.
  size_t result = 0;
  int rankDiff = srcType.getRank() - vectorType.getRank();
  for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
    // Check that the inner dim size is 1 for both the memref type and the
    // vector slice. It can be folded only if they are 1 and the stride is 1.
    int dim = vectorType.getRank() - i - 1;
    if (srcStrides[dim + rankDiff] != 1 ||
        srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
      break;
    result++;
  }
  return result;
}

/// Drop innermost contiguous unit dimensions from transfer_read operand.
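/// E.g. (illustrative, mirroring the transfer_write example further below):
///
///    %0 = vector.transfer_read %arg0[%c0, %arg1, %c0, %c0, %c0], %pad
///      {in_bounds = [true, true, true, true, true]}
///      : memref<1x512x16x1x1xf32>, vector<1x16x16x1x1xf32>
///
/// is rewritten into a rank-reduced memref.subview, a transfer_read of
/// vector<1x16x16xf32> from that subview, and a vector.shape_cast back to
/// vector<1x16x16x1x1xf32>.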
class DropInnerMostUnitDimsTransferRead
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();

    // TODO: support mask.
    if (readOp.getMask())
      return failure();

    auto srcType = dyn_cast<MemRefType>(readOp.getBase().getType());
    if (!srcType)
      return failure();

    if (!readOp.getPermutationMap().isMinorIdentity())
      return failure();

    auto targetType = readOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
      return failure();

    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
      return failure();

    auto inBounds = readOp.getInBoundsValues();
    auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
    if (llvm::is_contained(droppedInBounds, false))
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType(),
                        targetType.getScalableDims().drop_back(dimsToDrop));

    auto loc = readOp.getLoc();
    SmallVector<OpFoldResult> sizes =
        memref::getMixedSizes(rewriter, loc, readOp.getBase());
    SmallVector<OpFoldResult> offsets(srcType.getRank(),
                                      rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(srcType.getRank(),
                                      rewriter.getIndexAttr(1));
    MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
        srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
        strides);
    ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
        readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
        loc, resultMemrefType, readOp.getBase(), offsets, sizes, strides);
    auto permMap = getTransferMinorIdentityMap(
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
    Value result = rewriter.create<vector::TransferReadOp>(
        loc, resultTargetVecType, rankedReducedView,
        readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        readOp.getPadding(),
        // TODO: support mask.
        /*mask=*/Value(), inBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
                                                     result);
    return success();
  }
};

/// Drop innermost contiguous unit dimensions from transfer_write operand.
/// E.g.,
///    vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
///      {in_bounds = [true, true, true, true, true]}
///      : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
///
/// will be replaced with
///
///    %subview = memref.subview %arg0
///      [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
///      : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
///    %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
///      to vector<1x16x16xf32>
///    vector.transfer_write %0, %subview[%c0, %arg2, %c0]
///      {in_bounds = [true, true, true]}
///      : vector<1x16x16xf32>, memref<1x512x16xf32>
///
/// Note, this pattern will not collapse "scalable unit" dims (i.e. `[1]`).
class DropInnerMostUnitDimsTransferWrite
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    // TODO: support mask.
    if (writeOp.getMask())
      return failure();

    auto srcType = dyn_cast<MemRefType>(writeOp.getBase().getType());
    if (!srcType)
      return failure();

    if (!writeOp.getPermutationMap().isMinorIdentity())
      return failure();

    auto targetType = writeOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
      return failure();

    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
      return failure();

    auto inBounds = writeOp.getInBoundsValues();
    auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
    if (llvm::is_contained(droppedInBounds, false))
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType(),
                        targetType.getScalableDims().drop_back(dimsToDrop));

    Location loc = writeOp.getLoc();
    SmallVector<OpFoldResult> sizes =
        memref::getMixedSizes(rewriter, loc, writeOp.getBase());
    SmallVector<OpFoldResult> offsets(srcType.getRank(),
                                      rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(srcType.getRank(),
                                      rewriter.getIndexAttr(1));
    MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
        srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
        strides);
    ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
        writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));

    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
        loc, resultMemrefType, writeOp.getBase(), offsets, sizes, strides);
    auto permMap = getTransferMinorIdentityMap(
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);

    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, resultTargetVecType, writeOp.getVector());
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        writeOp, shapeCast, rankedReducedView,
        writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        // TODO: support mask.
        /*mask=*/Value(), inBoundsAttr);
    return success();
  }
};

/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
/// semantics to a contraction suitable for MMT (matrix-matrix multiplication
/// with the RHS transposed) lowering.
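///
/// For example (illustrative; see the map matching below), a plain row-major
/// matmul with indexing maps
///   [(m, k), (k, n), (m, n)]
/// has its RHS transposed so that the contraction reaches the canonical "TNT"
/// form
///   [(m, k), (n, k), (m, n)];
/// other operand orders are handled by transposing and/or swapping operands.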
struct CanonicalizeContractMatmulToMMT final
    : OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
                                  FilterConstraintType constraint)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(filter(op)))
      return failure();

    Location loc = op.getLoc();
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
    Value res = op.getAcc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [&](MapList m) {
      return AffineMap::inferFromExprList(m, op.getContext());
    };
    AffineExpr m;
    AffineExpr n;
    AffineExpr k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    if (iteratorTypes.size() != 3 ||
        !vector::isParallelIterator(iteratorTypes[0]) ||
        !vector::isParallelIterator(iteratorTypes[1]) ||
        !vector::isReductionIterator(iteratorTypes[2]))
      return rewriter.notifyMatchFailure(op, "contraction is not a gemm");

    // The canonical form is "TNT" = A row-major, B col-major, C row-major.
    const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
    if (maps == canonicalForm)
      return rewriter.notifyMatchFailure(op, "already in the canonical form");

    // Create a vector transpose making sure to emit zero/sign-extend at the
    // end.
    auto createTranspose = [&rewriter, loc](Value mat) -> Value {
      if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
        Value trans =
            rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
        VectorType newType =
            cast<VectorType>(trans.getType())
                .clone(cast<VectorType>(mat.getType()).getElementType());
        return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
      }
      if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
        Value trans =
            rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm);
        VectorType newType =
            VectorType::get(cast<VectorType>(trans.getType()).getShape(),
                            cast<VectorType>(mat.getType()).getElementType());
        return rewriter.create<arith::ExtUIOp>(loc, newType, trans);
      }
      return rewriter.create<vector::TransposeOp>(loc, mat, perm);
    };

    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return rewriter.notifyMatchFailure(op, "unhandled contraction form");
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
        op.getIteratorTypes());
    return success();
  }

private:
  FilterConstraintType filter;
};

/// Pattern to fold arithmetic extensions on floating point data types into
/// vector contraction operations. linalg.matmul introduces arithmetic
/// extensions on its operands. Please see the MLIR snippet below for more
/// details.
/// ```mlir
///   "linalg.matmul"(%lhs, %rhs, %acc) ({
///     ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
///       %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
///       %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
///       %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
///       %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
///       "linalg.yield"(%acc) : (f32) -> ()
///   })
/// ```
/// This restricts the native usage of mixed precision NVIDIA Ampere Tensor
/// Cores, i.e., `mma.sync.*.f32.f16.f16.f32` and
/// `mma.sync.*.f32.bf16.bf16.f32`. This pattern folds the arithmetic
/// extensions into the vector contraction and enables the usage of native
/// mixed precision Tensor Core instructions.
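///
/// Illustrative before/after at the vector level (operand types are
/// assumptions):
/// ```mlir
///   %lhs_f32 = arith.extf %lhs : vector<4x4xf16> to vector<4x4xf32>
///   %rhs_f32 = arith.extf %rhs : vector<4x4xf16> to vector<4x4xf32>
///   %res = vector.contract {...} %lhs_f32, %rhs_f32, %acc
/// ==>
///   %res = vector.contract {...} %lhs, %rhs, %acc
/// ```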
template <typename ExtOp>
struct FoldArithExtIntoContractionOp
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {

    auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
    auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();

    if (!lhsDefOp || !rhsDefOp) {
      return rewriter.notifyMatchFailure(
          contractOp, "no defining op on contract operands");
    }

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
        contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
        contractOp.getIteratorTypesAttr());

    return success();
  }
};

/// Pattern to fold chained reduction to a series of vector additions and a
/// final reduction. This form should require fewer subgroup operations.
///
/// ```mlir
/// %a = vector.reduction <add> %x, %acc
/// %b = vector.reduction <add> %y, %a
/// ==>
/// %a = arith.addf %x, %y
/// %b = vector.reduction <add> %a, %acc
/// ```
struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Handle other combining kinds.
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();

    // Accumulator is optional.
    Value acc = op.getAcc();
    if (!acc)
      return failure();

    if (!acc.getType().isIntOrFloat())
      return failure();

    auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
    if (!parentReduction)
      return failure();

    Location loc = op.getLoc();
    Value vAdd;
    if (isa<IntegerType>(acc.getType())) {
      vAdd = rewriter.createOrFold<arith::AddIOp>(
          loc, parentReduction.getVector(), op.getVector());
    } else {
      vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(),
                                            op.getVector());
    }
    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
                                                     parentReduction.getAcc());
    return success();
  }
};

// Helper function dropping unit non-scalable dimensions from a VectorType
// while keeping at least one dimension, to avoid generating 0-D vectors.
// Scalable unit dimensions are not dropped. Folding such dimensions would
// require "shifting" the scalable flag onto some other fixed-width dim (e.g.
// vector<[1]x4xf32> -> vector<[4]xf32>). This could be implemented in the
// future.
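// E.g. vector<1x4x1x[4]xf32> -> vector<4x[4]xf32>; if every dim is a
// non-scalable unit dim, e.g. vector<1x1xf32>, the result is vector<1xf32>
// (illustrative shapes).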
static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
  auto inVecShape = inVecTy.getShape();
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dim, isScalable] :
       llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
    if (dim == 1 && !isScalable)
      continue;

    newShape.push_back(dim);
    newScalableDims.push_back(isScalable);
  }
  // All dims have been dropped, return vector<1xeType>.
  if (newShape.empty()) {
    newShape.push_back(1);
    newScalableDims.push_back(false);
  }

  return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
}

/// For vectors with at least one unit dim, replaces:
///   elementwise(a, b)
/// with:
///   sc_a = shape_cast(a)
///   sc_b = shape_cast(b)
///   res = elementwise(sc_a, sc_b)
///   return shape_cast(res)
/// The newly inserted shape_cast Ops fold (before the elementwise Op) and then
/// restore (after the elementwise Op) the unit dim. Vectors `a` and `b` are
/// required to be rank > 1.
///
/// Ex:
///   %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
///   %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
///
/// gets converted to:
///
///   %B_row_sc = vector.shape_cast %B_row
///     : vector<1x[4]xf32> to vector<[4]xf32>
///   %A_row_sc = vector.shape_cast %A_row
///     : vector<1x[4]xf32> to vector<[4]xf32>
///   %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
///   %cast_new = vector.shape_cast %mul
///     : vector<[4]xf32> to vector<1x[4]xf32>
///   %cast = vector.shape_cast %cast_new
///     : vector<1x[4]xf32> to vector<[4]xf32>
///
/// Patterns for folding shape_casts should instantly eliminate `%cast_new`
/// and `%cast`.
struct DropUnitDimFromElementwiseOps final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1 || op->getNumRegions() != 0)
      return failure();

    auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
    if (!resultVectorType)
      return failure();

    // Check the operand pre-conditions. For `Elementwise` ops all operands are
    // guaranteed to have identical shapes (with some exceptions such as
    // `arith.select`) and it suffices to only check one of them.
    auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
    if (!sourceVectorType)
      return failure();
    if (sourceVectorType.getRank() < 2)
      return failure();

    SmallVector<Value> newOperands;
    auto loc = op->getLoc();
    for (auto operand : op->getOperands()) {
      auto opVectorType = cast<VectorType>(operand.getType());
      auto newVType = dropNonScalableUnitDimFromType(opVectorType);
      if (newVType == opVectorType)
        return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");

      auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
      newOperands.push_back(opSC);
    }

    VectorType newResultVectorType =
        dropNonScalableUnitDimFromType(resultVectorType);
    // Create an updated elementwise Op without unit dim.
    Operation *elementwiseOp =
        rewriter.create(loc, op->getName().getIdentifier(), newOperands,
                        newResultVectorType, op->getAttrs());

    // Restore the unit dim by applying vector.shape_cast to the result.
    rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
                                             elementwiseOp->getResult(0));

    return success();
  }
};

/// A pattern to drop unit dims from vector.transpose.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %transpose = vector.transpose %vector, [3, 0, 1, 2]
///   : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
/// ```
///
/// AFTER:
/// ```mlir
/// %dropDims = vector.shape_cast %vector
///   : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
/// %transpose = vector.transpose %dropDims, [1, 0]
///   : vector<4x[4]xf32> to vector<[4]x4xf32>
/// %restoreDims = vector.shape_cast %transpose
///   : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
/// ```
struct DropUnitDimsFromTransposeOp final
    : OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    VectorType sourceType = op.getSourceVectorType();
    VectorType sourceTypeWithoutUnitDims =
        dropNonScalableUnitDimFromType(sourceType);

    if (sourceType == sourceTypeWithoutUnitDims)
      return failure();

    // Construct a map from dimIdx -> number of dims dropped before dimIdx.
    auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
    SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
    int64_t droppedDims = 0;
    for (auto [i, dim] : llvm::enumerate(sourceDims)) {
      droppedDimsBefore[i] = droppedDims;
      if (dim == std::make_tuple(1, false))
        ++droppedDims;
    }

    // Drop unit dims from transpose permutation.
    ArrayRef<int64_t> perm = op.getPermutation();
    SmallVector<int64_t> newPerm;
    for (int64_t idx : perm) {
      if (sourceDims[idx] == std::make_tuple(1, false))
        continue;
      newPerm.push_back(idx - droppedDimsBefore[idx]);
    }

    // Fixup for `newPerm`. The `sourceTypeWithoutUnitDims` could be vector<1xT>
    // type when all the dimensions are unit dimensions. In this case, the
    // newPerm should be [0].
    if (newPerm.empty()) {
      newPerm.push_back(0);
    }

    Location loc = op.getLoc();
    // Drop the unit dims via shape_cast.
    auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(
        loc, sourceTypeWithoutUnitDims, op.getVector());
    // Create the new transpose.
    auto transposeWithoutUnitDims =
        rewriter.create<vector::TransposeOp>(loc, dropDimsShapeCast, newPerm);
    // Restore the unit dims via shape cast.
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        op, op.getResultVectorType(), transposeWithoutUnitDims);

    return success();
  }
};

/// A pattern to drop unit dims from the iter_args of an scf.for.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %res = scf.for ... iter_args(%iter = %init) -> vector<[4]x1x1x4xf32> {
///   ...
///   scf.yield %
/// }
/// ```
///
/// AFTER:
/// ```mlir
/// %drop = vector.shape_cast %init
///   : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
/// %new_loop = scf.for ... iter_args(%iter = %drop) -> vector<[4]x4xf32> {
///   %new_iter = vector.shape_cast %iter
///     : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
///   ...
/// }
/// %res = vector.shape_cast %new_loop
///   : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
/// ```
struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::ForOp forOp,
                                PatternRewriter &rewriter) const override {
    /// Find the first iter_arg with droppable unit dims. Further applications
    /// of this pattern will apply to later arguments.
    for (OpOperand &operand : forOp.getInitArgsMutable()) {
      auto vectorType = dyn_cast<VectorType>(operand.get().getType());
      if (!vectorType)
        continue;

      VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
      if (vectorType == newVectorType)
        continue;

      // Create a new ForOp with that iter operand replaced.
      auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) {
        return b.create<vector::ShapeCastOp>(loc, type, source);
      };

      Value replacement =
          castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
      rewriter.replaceOp(forOp,
                         replaceAndCastForOpIterArg(rewriter, forOp, operand,
                                                    replacement, castFn));
      return success();
    }
    return failure();
  }
};

/// Pattern to eliminate redundant zero-constants added to reduction operands.
/// It's enough for there to be one initial zero value, so we can eliminate the
/// extra ones that feed into `vector.reduction <add>`. These get created by
/// the `ChainedReduction` pattern.
///
/// ```mlir
/// %a = arith.addf %x, %zero
/// %b = arith.addf %a, %y
/// %c = vector.reduction <add> %b, %acc
/// ==>
/// %b = arith.addf %x, %y
/// %c = vector.reduction <add> %b, %acc
/// ```
struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Handle other reduction kinds and their identity values.
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();

    Type elemType = op.getSourceVectorType().getElementType();
    // The integer case should be handled by `arith.addi` folders, only check
    // for floats here.
    if (!isa<FloatType>(elemType))
      return failure();

    auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
    if (!vAdd)
      return failure();
    auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
    if (!addLhs)
      return failure();

    if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat()))
      return failure();

    auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(),
                                                 addLhs.getLhs(),
                                                 vAdd.getRhs());
    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd,
                                                     op.getAcc());
    return success();
  }
};

/// Example:
/// ```
/// %a = vector.reduction <add> %x : vector<2xf32> into f32
/// ```
/// is transformed into:
/// ```
/// %y = vector.extract %x[0] : f32 from vector<2xf32>
/// %z = vector.extract %x[1] : f32 from vector<2xf32>
/// %a = arith.addf %y, %z : f32
/// ```
struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
  BreakDownVectorReduction(MLIRContext *context,
                           unsigned maxNumElementsToExtract,
                           PatternBenefit benefit)
      : OpRewritePattern(context, benefit),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    VectorType type = op.getSourceVectorType();
    if (type.isScalable() || op.isMasked())
      return failure();
    assert(type.getRank() == 1 && "Expected a 1-d vector");

    int64_t numElems = type.getNumElements();
    if (numElems > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("has too many vector elements ({0}) to break down "
                            "(max allowed: {1})",
                            numElems, maxNumElementsToExtract));
    }

    Location loc = op.getLoc();
    SmallVector<Value> extracted(numElems, nullptr);
    for (auto [idx, extractedElem] : llvm::enumerate(extracted))
      extractedElem = rewriter.create<vector::ExtractOp>(
          loc, op.getVector(), static_cast<int64_t>(idx));

    Value res = extracted.front();
    for (auto extractedElem : llvm::drop_begin(extracted))
      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
                                       extractedElem, op.getFastmathAttr());
    if (Value acc = op.getAcc())
      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
                                       op.getFastmathAttr());

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxNumElementsToExtract = 0;
};

/// Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A,
/// B)`.
/// Example:
///   %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
///   %lhsT = vector.transpose %lhsBcast, [1, 0]
///     : vector<4x4xi32> to vector<4x4xi32>
///   %rhsBcast = vector.broadcast %rhs : vector<4xi32> to vector<4x4xi32>
///   %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
///
/// Becomes:
///
///   %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
///
/// Supports only 1D-to-2D broadcasts. The following cases are not supported:
///   %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
///   %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
///   %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
template <typename MulOpType>
struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
  using OpRewritePattern<MulOpType>::OpRewritePattern;
  // Returns whether a vector.broadcast matches the requirements for an
  // outerproduct pattern, i.e., a 1D-to-2D broadcast without broadcast unit
  // dimensions.
  bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
    // Fail if it is not a 1D-to-2D broadcast, to avoid generating
    // shape_casts/broadcasts which do not belong in this pattern.
    if (!broadcastOp.computeBroadcastedUnitDims().empty())
      return false;
    // Avoid broadcasts like f32 or vector<f32> -> ResType.
    auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    return srcType && srcType.getRank() != 2;
  }

  LogicalResult matchAndRewrite(MulOpType mulOp,
                                PatternRewriter &rewriter) const override {
    auto resType = llvm::dyn_cast<VectorType>(mulOp.getResult().getType());
    if (!resType)
      return failure();
    if (resType.getRank() != 2)
      return failure();
    /// If operandA can be written as tr(broadcast(A)) and operandB as
    /// broadcast(B), where the broadcasts are 1D-to-2D, create and return
    /// vector.outerproduct(A, B). Returns failure() otherwise.
    auto matchOuterProduct =
        [&](Value operandA,
            Value operandB) -> FailureOr<vector::OuterProductOp> {
      auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
      if (!transposedLhs)
        return failure();
      // Fail unless this is a true 2-D matrix transpose.
      ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
      if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
        return failure();

      auto broadcastedLhs =
          transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
        return failure();

      auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
      if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
        return failure();

      return rewriter.create<vector::OuterProductOp>(
          mulOp->getLoc(), resType, broadcastedLhs.getSource(),
          broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
    };

    Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
    auto maybeOuterP = matchOuterProduct(lhs, rhs);
    // Handle commutativity; the transposed op is the outerproduct LHS.
    if (failed(maybeOuterP))
      maybeOuterP = matchOuterProduct(rhs, lhs);
    if (failed(maybeOuterP))
      return failure();
    rewriter.replaceOp(mulOp, maybeOuterP->getResult());
    return success();
  }
};

} // namespace

void mlir::vector::populateFoldArithExtensionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>,
               FoldArithExtIntoContractionOp<arith::ExtSIOp>>(
      patterns.getContext());
}
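
// A minimal usage sketch for these populate* entry points (illustrative;
// assumes a pass built on the greedy rewrite driver from
// mlir/Transforms/GreedyPatternRewriteDriver.h):
//
//   RewritePatternSet patterns(&getContext());
//   vector::populateFoldArithExtensionPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                           std::move(patterns))))
//     signalPassFailure();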

void mlir::vector::populateVectorMaskMaterializationPatterns(
    RewritePatternSet &patterns, bool force32BitVectorIndices,
    PatternBenefit benefit) {
  patterns.add<VectorCreateMaskOpConversion,
               MaterializeTransferMask<vector::TransferReadOp>,
               MaterializeTransferMask<vector::TransferWriteOp>>(
      patterns.getContext(), force32BitVectorIndices, benefit);
  patterns.add<FoldI1Select>(patterns.getContext(), benefit);
}

void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
               DropUnitDimsFromTransposeOp>(patterns.getContext(), benefit);
}

void mlir::vector::populateBubbleVectorBitCastOpPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<BubbleDownVectorBitCastForExtract,
               BubbleDownBitCastForStridedSliceExtract,
               BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
    RewritePatternSet &patterns,
    std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
  patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
                                       std::move(controlFn), benefit);
}

void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
    RewritePatternSet &patterns,
    std::function<LogicalResult(vector::ContractionOp)> constraint,
    PatternBenefit benefit) {
  patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
                                                std::move(constraint));
}

void mlir::vector::populateVectorReductionToContractPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<MultiReduceToContract, CombineContractBroadcastMask,
               CombineContractABTranspose, CombineContractResultTranspose>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateDropInnerMostUnitDimsXferOpPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<DropInnerMostUnitDimsTransferRead,
               DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
                                                   benefit);
}

void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit) {
  patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
               ReorderElementwiseOpsOnBroadcast, ExtractOpFromElementwise>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
                                                    PatternBenefit benefit) {
  // TODO: Consider converting these patterns to canonicalizations.
  patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateChainedVectorReductionFoldingPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ChainedReduction>(patterns.getContext(), benefit);
  patterns.add<ReduceRedundantZero>(patterns.getContext(),
                                    PatternBenefit(benefit.getBenefit() + 1));
}

void mlir::vector::populateBreakDownVectorReductionPatterns(
    RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
    PatternBenefit benefit) {
  patterns.add<BreakDownVectorReduction>(patterns.getContext(),
                                         maxNumElementsToExtract, benefit);
}

void mlir::vector::populateElementwiseToVectorOpsPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>,
               FoldArithToVectorOuterProduct<arith::MulIOp>>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// TableGen'd enum attribute definitions
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"
