//===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites as 1->N patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"

#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE "vector-to-vector"

using namespace mlir;
using namespace mlir::vector;

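/// Returns the integer values held by `arrayAttr`'s IntegerAttr elements,
/// cast to `IntType`. For example, an ArrayAttr holding [2, 4, 8] yields
/// SmallVector<IntType>{2, 4, 8}.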
template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

// Helper to find an index in an affine map.
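// For example, given the map (d0, d1, d2) -> (d2, d0), getResultIndex(map, 0)
// returns 1 (d0 appears as the second result), and getResultIndex(map, 1)
// returns std::nullopt.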
static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      return i;
  }
  return std::nullopt;
}

namespace {

/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
/// Ex:
/// ```
///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
///   %1 = vector.multi_reduction add, %0 [1]
///     : vector<8x32x16xf32> to vector<8x16xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d2)>],
///     iterator_types = ["parallel", "reduction", "parallel"],
///     kind = add} %arg0, %arg1, %cst_f0
///     : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
/// ```
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.getKind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.getSource().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    SmallVector<bool> reductionMask = reduceOp.getReductionMask();
    auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
    SmallVector<AffineExpr> exprs;
    SmallVector<vector::IteratorType> iteratorTypes;
    for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(vector::IteratorType::parallel);
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        iteratorTypes.push_back(vector::IteratorType::reduction);
      }
    }
    auto dstMap =
        AffineMap::get(/*dimCount=*/reductionMask.size(),
                       /*symbolCount=*/0, exprs, reduceOp.getContext());
    rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
        rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
        rewriter.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
              return IteratorTypeAttr::get(rewriter.getContext(), t);
            }))));
    return success();
  }
};

/// Merge LHS/RHS (A/B) TransposeOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.transpose %arg0, [2, 0, 1]
///     : vector<32x16x8xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///     iterator_types = ["parallel", "parallel", "reduction"],
///     kind = add} %0, %arg1, %cst_f0
///     : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///     iterator_types = ["parallel", "parallel", "reduction"],
///     kind = add} %arg0, %arg1, %cst_f0
///     : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractABTranspose final
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      AffineMap permutationMap = AffineMap::getPermutationMap(
          transposeOp.getPermutation(), contractOp.getContext());
      map = inversePermutation(permutationMap).compose(map);
      *operand = transposeOp.getVector();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};

/// Merges accumulator and result transposes into contract.
///
/// For example:
/// ```mlir
/// %accT = vector.transpose %acc, [0, 2, 1]
///   : vector<2x8x4xf32> to vector<2x4x8xf32>
/// %contract = vector.contract {
///   indexing_maps = [
///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
///     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
///   ],
///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
///   kind = #vector.kind<add>
/// } %lhs, %rhs, %accT
///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
/// %0 = vector.transpose %contract, [0, 2, 1]
///   : vector<2x4x8xf32> to vector<2x8x4xf32>
/// ```
/// Becomes:
/// ```mlir
/// %0 = vector.contract {
///   indexing_maps = [
///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
///     affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
///   ],
///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
///   kind = #vector.kind<add>
/// } %lhs, %rhs, %acc
///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x8x4xf32>
/// ```
struct CombineContractResultTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
                                PatternRewriter &rewriter) const override {
    auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
    if (!contractOp || !contractOp->hasOneUse())
      return failure();

    auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
    if (!accTOp)
      return failure();

    MLIRContext *context = contractOp.getContext();
    auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
    AffineMap contractMap = maps.back();

    // Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
    // To index into A in contract, we need revert(f)(g(C)) -> A.
    auto accTMap =
        AffineMap::getPermutationMap(accTOp.getPermutation(), context);

    // Contract performs g(C) -> D. Result transpose performs h(D) -> E.
    // To index into E in contract, we need h(g(C)) -> E.
    auto resTMap =
        AffineMap::getPermutationMap(resTOp.getPermutation(), context);
    auto combinedResMap = resTMap.compose(contractMap);

    // The accumulator and result share the same indexing map. So they should
    // be the same to be able to merge. This means combinedResMap must equal
    // inversePermutation(accTMap).compose(contractMap), i.e., resTMap must be
    // the inverse of accTMap.
    if (inversePermutation(accTMap) != resTMap)
      return failure();
    maps.back() = combinedResMap;

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};

/// Merge BroadcastOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///     iterator_types = ["parallel", "parallel", "reduction"],
///     kind = add} %0, %arg1, %cst_f0
///     : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///     iterator_types = ["parallel", "parallel", "reduction"],
///     kind = add} %arg0, %arg1, %cst_f0
///     : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
///
/// For masked vector.contract, the mask requires updating when a dimension is
/// dropped. In such cases, the dropped dimensions must correspond to the
/// mask's leading unit dimensions. More general cases (e.g. non-unit dims)
/// are currently not supported.
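/// For example (an illustrative sketch; the exact types are assumed for the
/// sake of the example): when the broadcast only adds a single leading unit
/// dimension, a mask of type vector<1x8x16xi1> is collapsed to
/// vector<8x16xi1> with a vector.shape_cast before re-masking the new
/// contraction.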
FailureOr<Value> combineContractAndBroadcast(vector::ContractionOp contractOp,
                                             MaskingOpInterface maskingOp,
                                             PatternRewriter &rewriter) {
  SmallVector<AffineMap> maps =
      llvm::to_vector<4>(contractOp.getIndexingMapsArray());
  Value lhs = contractOp.getLhs();
  Value rhs = contractOp.getRhs();
  size_t index = 0;
  bool changed = false;
  for (Value *operand : {&lhs, &rhs}) {
    AffineMap &map = maps[index++];
    auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
    if (!broadcast)
      continue;
    // ContractionOp can only take vectors as operands.
    auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
    if (!srcType ||
        srcType.getRank() == broadcast.getResultVectorType().getRank())
      continue;
    int64_t rankDiff =
        broadcast.getResultVectorType().getRank() - srcType.getRank();
    bool innerDimBroadcast = false;
    SmallVector<AffineExpr> originalDims;
    for (const auto &dim : llvm::enumerate(srcType.getShape())) {
      if (dim.value() !=
          broadcast.getResultVectorType().getDimSize(rankDiff + dim.index())) {
        innerDimBroadcast = true;
        break;
      }
      originalDims.push_back(rewriter.getAffineDimExpr(dim.index() + rankDiff));
    }
    // Contract doesn't support inner dimension broadcast. Once this is
    // relaxed we can remove this case.
    if (innerDimBroadcast)
      continue;

    // It would be incorrect to fold a broadcast onto a reduction dimension
    // of non-unit size.
    bool nonUnitDimReductionBroadcast = false;
    for (int64_t i = 0; i < rankDiff; ++i) {
      if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
          isReductionIterator(contractOp.getIteratorTypes()
                                  .getValue()[map.getDimPosition(i)])) {
        nonUnitDimReductionBroadcast = true;
        break;
      }
    }
    if (nonUnitDimReductionBroadcast)
      continue;

    AffineMap broadcastMap =
        AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
                       originalDims, contractOp.getContext());
    map = broadcastMap.compose(map);
    *operand = broadcast.getSource();
    changed = true;
  }

  if (!changed)
    return failure();

  // Determine which dims are unused, now that the maps have been composed
  // with the broadcast maps.
  llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
  // Compress unused dims.
  for (auto &m : maps)
    m = compressDims(m, unusedDimsBitVector);
  // Compute the combined iterators.
  SmallVector<Attribute> iterators;
  for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
    if (!unusedDimsBitVector.test(i))
      iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
  }

  // Check whether any of the unused dims is non-unit, e.g.:
  //  * vector.broadcast %arg0 : vector<8x4xi32> to vector<2x8x4xi32>
  // This is only required when collapsing a mask. If there is no mask, skip.
  VectorType oldMaskType;
  bool isAnyUnusedDimNonUnit = false;
  if (maskingOp) {
    oldMaskType = cast<VectorType>(maskingOp.getMask().getType());
    for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
      if (unusedDimsBitVector.test(i) && oldMaskType.getShape()[i] != 1) {
        isAnyUnusedDimNonUnit = true;
        break;
      }
    }
  }

  // Check that compressing unused dims isn't removing all reduction dimension
  // pairs. For example, if the vector.contract had only one reduction
  // iterator and that was a unit-dimension created by a broadcast,
  // then we should bail here, otherwise we would create a contract without
  // a reduction dimension pair.
  bool hasReductionIteratorApplyingOnBothSides = false;
  for (unsigned i = 0; i < iterators.size(); ++i) {
    if (!isReductionIterator(iterators[i]))
      continue;
    if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
      hasReductionIteratorApplyingOnBothSides = true;
      break;
    }
  }
  if (!hasReductionIteratorApplyingOnBothSides)
    return failure();

  // If the compressed maps have a dimension that is not used by either LHS or
  // RHS then the ContractionOp verifier would fail.
  if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
    return failure();

  Operation *newOp = rewriter.create<vector::ContractionOp>(
      contractOp.getLoc(), lhs, rhs, contractOp.getAcc(),
      rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));

  // Handle the mask.
  if (maskingOp) {
    if (isAnyUnusedDimNonUnit)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Cannot drop non-unit mask dim.");
    assert(unusedDimsBitVector.size() ==
               static_cast<size_t>(oldMaskType.getRank()) &&
           "The mask rank is incorrect!");

    // If a dimension has been dropped, update the mask accordingly. Otherwise,
    // keep it as is.
    Value mask = maskingOp.getMask();
    if (unusedDimsBitVector.count() != 0) {
      // At this point, two assumptions are made:
      //  * The unused dimensions are the leading mask dimensions
      //    (vector.contract does not support inner dim broadcasting).
      //  * The unused dimensions are all unit.
      // These conditions are effectively verified in the blocks preceding
      // this one.
      auto newShape =
          oldMaskType.getShape().drop_front(unusedDimsBitVector.count());
      auto newShapeScalableDims =
          oldMaskType.getScalableDims().drop_front(unusedDimsBitVector.count());
      VectorType maskOpType =
          VectorType::get(newShape, rewriter.getI1Type(), newShapeScalableDims);
      mask = rewriter
                 .create<vector::ShapeCastOp>(contractOp.getLoc(), maskOpType,
                                              maskingOp.getMask())
                 .getResult();
    }

    newOp = mlir::vector::maskOperation(rewriter, newOp, mask);
  }
  return newOp->getResult(0);
}

struct CombineContractBroadcastMask
    : public MaskableOpRewritePattern<vector::ContractionOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
                            MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    return combineContractAndBroadcast(contractOp, maskingOp, rewriter);
  }
};

/// Reorders cast(broadcast) to broadcast(cast). This moves broadcast ops
/// closer to contraction ops, so that the CombineContractBroadcast pattern
/// can kick in when cast ops sit between these operations.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
/// ```
/// Gets converted to:
/// ```
///   %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
///   %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
/// ```
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
      castResTy = vecTy.clone(castResTy);
    auto *castOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        bcastOp.getSource(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
    return success();
  }
};

/// Reorders elementwise(transpose) to transpose(elementwise). This moves
/// transpose ops closer to contraction ops, so that the
/// CombineContractABTranspose pattern can kick in when elementwise ops sit
/// between these operations. Ex:
/// ```
/// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
/// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
/// %r = arith.addf %at, %bt : vector<2x4xf32>
/// ```
/// Gets converted to:
/// ```
/// %0 = arith.addf %a, %b : vector<4x2xf32>
/// %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
/// ```
struct ReorderElementwiseOpsOnTranspose final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1 || op->getNumRegions() != 0)
      return failure();

    // Make sure all operands are transpose/constant ops and collect their
    // transposition maps.
    SmallVector<ArrayRef<int64_t>> transposeMaps;
    transposeMaps.reserve(op->getNumOperands());
    // Record the initial type before transposition. We'll use its shape later.
    // Any type will do here as we will check all transpose maps are the same.
    VectorType srcType;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        transposeMaps.push_back(transposeOp.getPermutation());
        srcType = transposeOp.getSourceVectorType();
      } else if (!matchPattern(operand, m_Constant())) {
        return failure();
      }
    }
    if (transposeMaps.empty())
      return failure();
    // This is an elementwise op, so all transposed operands should have the
    // same type. We need to additionally check that all transposes use the
    // same map.
    if (!llvm::all_equal(transposeMaps))
      return rewriter.notifyMatchFailure(op, "different transpose map");

    SmallVector<Value> srcValues;
    srcValues.reserve(op->getNumOperands());

    // If there are constant operands, we need to insert inverse transposes for
    // them. Calculate the inverse order first.
    auto order = transposeMaps.front();
    SmallVector<int64_t> invOrder(order.size());
    for (int i = 0, e = order.size(); i < e; ++i)
      invOrder[order[i]] = i;

    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        srcValues.push_back(transposeOp.getVector());
      } else {
        // This is a constant. Create an inverse transpose op for it.
        auto vectorType =
            srcType.clone(cast<VectorType>(operand.getType()).getElementType());
        srcValues.push_back(rewriter.create<vector::TransposeOp>(
            operand.getLoc(), vectorType, operand, invOrder));
      }
    }

    auto vectorType = srcType.clone(
        cast<VectorType>(op->getResultTypes()[0]).getElementType());
    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        vectorType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        op, op->getResultTypes()[0], elementwiseOp->getResult(0),
        transposeMaps.front());
    return success();
  }
};

// Returns the values in `arrayAttr` as an integer vector.
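// For example, an ArrayAttr holding the IntegerAttrs [0, 2, 1] yields
// SmallVector<int64_t>{0, 2, 1}.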
static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(
      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
                      [](IntegerAttr attr) { return attr.getInt(); }));
}

// Shuffles vector.bitcast op after vector.extract op.
//
// This transforms IR like:
//   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
//   %1 = vector.extract %0[3] : f16 from vector<8xf16>
// Into:
//   %0 = vector.extract %src[1] : f32 from vector<4xf32>
//   %1 = vector.bitcast %0 : vector<1xf32> to vector<2xf16>
//   %2 = vector.extract %1[1] : f16 from vector<2xf16>
struct BubbleDownVectorBitCastForExtract
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Only support extracting scalars for now.
    if (extractOp.getSourceVectorType().getRank() != 1)
      return failure();

    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // Fail to match if we only have one element in the cast op source.
    // This is to avoid infinite loop given that this pattern can generate
    // such cases.
    if (castSrcType.getNumElements() == 1)
      return failure();

    // Only support casting to a larger number of elements for now.
    // E.g., vector<4xf32> -> vector<8xf16>.
    if (castSrcType.getNumElements() > castDstType.getNumElements())
      return failure();

    unsigned expandRatio =
        castDstType.getNumElements() / castSrcType.getNumElements();

    // Get the first element of the mixed position as integer.
    auto mixedPos = extractOp.getMixedPosition();
    if (mixedPos.size() > 0 && !isa<Attribute>(mixedPos[0]))
      return failure();
    uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();

    // Get the single scalar (as a vector) in the source value that packs the
    // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
    Location loc = extractOp.getLoc();
    Value packedValue = rewriter.create<vector::ExtractOp>(
        loc, castOp.getSource(), index / expandRatio);
    Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, packedVecType, rewriter.getZeroAttr(packedVecType));
    packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero,
                                                    /*position=*/0);

    // Cast it to a vector with the desired scalar's type.
    // E.g. f32 -> vector<2xf16>
    VectorType packedType =
        VectorType::get({expandRatio}, castDstType.getElementType());
    Value castedValue =
        rewriter.create<vector::BitCastOp>(loc, packedType, packedValue);

    // Finally extract the desired scalar.
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
                                                   index % expandRatio);
    return success();
  }
};

// Shuffles vector.bitcast op after vector.extract_strided_slice op.
//
// This transforms IR like:
//   %cast = vector.bitcast %arg0 : vector<4xf32> to vector<8xf16>
//   %0 = vector.extract_strided_slice %cast {
//          offsets = [4], sizes = [4], strides = [1]
//        } : vector<8xf16> to vector<4xf16>
// Into:
//   %0 = vector.extract_strided_slice %arg0 {
//          offsets = [2], sizes = [2], strides = [1]
//        } : vector<4xf32> to vector<2xf32>
//   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
struct BubbleDownBitCastForStridedSliceExtract
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to more elements for now; other cases to be implemented.
    if (castSrcLastDim > castDstLastDim)
      return failure();

    // Only accept all one strides for now.
    if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = extractOp.getSourceVectorType().getRank();
    assert(castDstLastDim % castSrcLastDim == 0);
    int64_t expandRatio = castDstLastDim / castSrcLastDim;

    // If we have fewer offsets than the rank, then implicitly we are selecting
    // the full range for the last bitcasted dimension; other dimensions aren't
    // affected. Otherwise, we need to scale down the last dimension's offset
    // given we are extracting from fewer elements now.
    ArrayAttr newOffsets = extractOp.getOffsets();
    if (newOffsets.size() == rank) {
      SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
      if (offsets.back() % expandRatio != 0)
        return failure();
      offsets.back() = offsets.back() / expandRatio;
      newOffsets = rewriter.getI64ArrayAttr(offsets);
    }

    // Similarly for sizes.
    ArrayAttr newSizes = extractOp.getSizes();
    if (newSizes.size() == rank) {
      SmallVector<int64_t> sizes = getIntValueVector(newSizes);
      if (sizes.back() % expandRatio != 0)
        return failure();
      sizes.back() = sizes.back() / expandRatio;
      newSizes = rewriter.getI64ArrayAttr(sizes);
    }

    SmallVector<int64_t> dims =
        llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
    dims.back() = dims.back() / expandRatio;
    VectorType newExtractType =
        VectorType::get(dims, castSrcType.getElementType());

    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
        extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
        newSizes, extractOp.getStrides());

    rewriter.replaceOpWithNewOp<vector::BitCastOp>(
        extractOp, extractOp.getType(), newExtractOp);

    return success();
  }
};

// Shuffles vector.bitcast op before vector.insert op.
//
// This transforms IR like:
//   %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4>
//   %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
// Into:
//   %0 = vector.bitcast %val : vector<32xi4> to vector<16xi8>
//   %1 = vector.bitcast %dst : vector<8x32xi4> to vector<8x16xi8>
//   %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8>
//
struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();

    // 0-D and scalable vectors are not supported yet.
    if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
        castDstType.isScalable())
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
    int64_t ratio;
    if (isNumElemsShrink) {
      assert(castSrcLastDim % castDstLastDim == 0);
      ratio = castSrcLastDim / castDstLastDim;
    } else {
      assert(castDstLastDim % castSrcLastDim == 0);
      ratio = castDstLastDim / castSrcLastDim;
    }

    auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
    if (!insertOp)
      return failure();

    // Only vector sources are supported for now.
    auto insertSrcType = dyn_cast<VectorType>(insertOp.getValueToStoreType());
    if (!insertSrcType)
      return failure();

    // Bitcast the source.
    SmallVector<int64_t> srcDims(insertSrcType.getShape());
    srcDims.back() =
        isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());
    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getValueToStore());

    SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
    dstDims.back() =
        isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());

    // Bitcast the destination.
    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    // Generate new insert.
    rewriter.replaceOpWithNewOp<vector::InsertOp>(
        bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
    return success();
  }
};

// Shuffles vector.bitcast op before vector.insert_strided_slice op.
//
// This transforms IR like:
//   %0 = vector.insert_strided_slice %src, %dst {
//          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
//   %1 = vector.bitcast %0 : vector<8xf16> to vector<4xf32>
// Into:
//   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
//   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
//   %2 = vector.insert_strided_slice %0, %1 {
//          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
struct BubbleUpBitCastForStridedSliceInsert
    : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    // Skip 0-D vectors, which cannot come from an InsertStridedSliceOp.
    if (castSrcType.getRank() == 0)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be
    // implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;

    auto insertOp =
        bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
    if (!insertOp)
      return failure();

    // Only accept all one strides for now.
    if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = insertOp.getSourceVectorType().getRank();
    // Require insert op to have the same rank for the source and destination
    // vector; other cases to be implemented.
    if (rank != insertOp.getDestVectorType().getRank())
      return failure();

    // Requires that shape of insert op src is castable to dstType.
    unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
    unsigned destinationWidth =
        castDstType.getElementType().getIntOrFloatBitWidth();
    unsigned numElements = destinationWidth / sourceWidth;
    if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
      return failure();

    ArrayAttr newOffsets = insertOp.getOffsets();
    assert(newOffsets.size() == rank);
    SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
    if (offsets.back() % shrinkRatio != 0)
      return failure();
    offsets.back() = offsets.back() / shrinkRatio;
    newOffsets = rewriter.getI64ArrayAttr(offsets);

    SmallVector<int64_t> srcDims =
        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
    srcDims.back() = srcDims.back() / shrinkRatio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());

    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getValueToStore());

    SmallVector<int64_t> dstDims =
        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
    dstDims.back() = dstDims.back() / shrinkRatio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());

    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
        insertOp.getStrides());

    return success();
  }
};

// Breaks down vector.bitcast op.
//
// This transforms IR like:
//   %1 = vector.bitcast %0 : vector<8xf16> to vector<4xf32>
// Into:
//   %cst = vector.splat %c0_f32 : vector<4xf32>
//   %1 = vector.extract_strided_slice %0 {
//          offsets = [0], sizes = [4], strides = [1]
//        } : vector<8xf16> to vector<4xf16>
//   %2 = vector.bitcast %1 : vector<4xf16> to vector<2xf32>
//   %4 = vector.insert_strided_slice %2, %cst {
//          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
//   %5 = vector.extract_strided_slice %0 {
//          offsets = [4], sizes = [4], strides = [1]
//        } : vector<8xf16> to vector<4xf16>
//   %6 = vector.bitcast %5 : vector<4xf16> to vector<2xf32>
//   %7 = vector.insert_strided_slice %6, %4 {
//          offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
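//
// The pattern takes an optional control function to restrict which bitcast
// ops get broken down. An illustrative (hypothetical) control function that
// only breaks down casts from sufficiently large vectors could look like:
//   [](vector::BitCastOp op) {
//     return op.getSourceVectorType().getNumElements() > 4;
//   }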
struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

public:
  BreakDownVectorBitCast(MLIRContext *context,
                         std::function<bool(vector::BitCastOp)> controlFn,
                         PatternBenefit benefit)
      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {

    if (controlFn && !controlFn(bitcastOp))
      return failure();

    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // This transformation builds on top of
    // vector.{extract|insert}_strided_slice, which do not support
    // extracting/inserting "scalable sub-vectors". Bail out.
    if (castSrcType.isScalable())
      return rewriter.notifyMatchFailure(bitcastOp,
                                         "Scalable vectors are not supported");

    // Only support rank 1 case for now.
    if (castSrcType.getRank() != 1)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be
    // implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
    // Nothing to do if it is already bitcasting to a single element.
    if (castSrcLastDim == shrinkRatio)
      return failure();

    Location loc = bitcastOp.getLoc();
    Type elemType = castDstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, castDstType, zero);

    SmallVector<int64_t> sliceShape = {castDstLastDim};
    SmallVector<int64_t> strides = {1};
    VectorType newCastDstType =
        VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
                        castDstType.getElementType());

    for (int i = 0, e = shrinkRatio; i < e; ++i) {
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, bitcastOp.getSource(), ArrayRef<int64_t>{i * castDstLastDim},
          sliceShape, strides);
      Value bitcast =
          rewriter.create<BitCastOp>(loc, newCastDstType, extracted);
      res = rewriter.create<InsertStridedSliceOp>(
          loc, bitcast, res,
          ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
    }
    rewriter.replaceOp(bitcastOp, res);
    return success();
  }

private:
  std::function<bool(BitCastOp)> controlFn;
};

/// Reorders elementwise(broadcast/splat) to broadcast(elementwise).
///
/// Example:
/// ```
/// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
/// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
/// %r = arith.addi %a, %b : vector<1x4xindex>
/// ```
/// Gets converted to:
/// ```
/// %r = arith.addi %arg1, %arg2 : index
/// %b = vector.broadcast %r : index to vector<1x4xindex>
/// ```
///
/// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
/// ops.
struct ReorderElementwiseOpsOnBroadcast final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1)
      return failure();
    if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
      return failure();
    if (!OpTrait::hasElementwiseMappableTraits(op))
      return rewriter.notifyMatchFailure(
          op, "Op doesn't have ElementwiseMappableTraits");
    if (op->getNumOperands() == 0)
      return failure();
    if (op->getResults()[0].getType() != op->getOperand(0).getType())
      return rewriter.notifyMatchFailure(op,
                                         "result and operand type mismatch");
    if (isa<vector::FMAOp>(op)) {
      return rewriter.notifyMatchFailure(
          op,
          "Op only accepts vector types - not supported as broadcast source "
          "might be a scalar");
    }

    // Get the type of the lhs operand.
    auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
    if (!lhsBcastOrSplat ||
        !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
      return failure();
    auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();

    // Make sure that all operands are broadcast from identical types:
    //  * scalar (`vector.broadcast` + `vector.splat`), or
    //  * vector (`vector.broadcast`).
    // Otherwise the re-ordering wouldn't be safe.
    if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
          auto bcast = val.getDefiningOp<vector::BroadcastOp>();
          if (bcast)
            return (bcast.getOperand().getType() == lhsBcastOrSplatType);
          auto splat = val.getDefiningOp<vector::SplatOp>();
          if (splat)
            return (splat.getOperand().getType() == lhsBcastOrSplatType);
          return false;
        })) {
      return failure();
    }

    // Collect the source values before broadcasting.
    SmallVector<Value> srcValues;
    srcValues.reserve(op->getNumOperands());
    for (Value operand : op->getOperands()) {
      srcValues.push_back(operand.getDefiningOp()->getOperand(0));
    }

    // Create the "elementwise" Op.
    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        lhsBcastOrSplatType, op->getAttrs());

    // Replace the original Op with the elementwise Op.
    auto vectorType = op->getResultTypes()[0];
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, vectorType, elementwiseOp->getResults());

    return success();
  }
};

/// Pattern to rewrite a ExtractOp(Elementwise) -> Elementwise(ExtractOp).
/// This may result in cleaner code when extracting a single value from a
/// multi-element vector, and it also helps canonicalize 1-element vectors to
/// scalars.
///
/// Example:
/// ```
///  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
///  %1 = vector.extract %0[1] : f32 from vector<4xf32>
/// ```
/// Gets converted to:
/// ```
///  %0 = vector.extract %arg0[1] : f32 from vector<4xf32>
///  %1 = vector.extract %arg1[1] : f32 from vector<4xf32>
///  %2 = arith.addf %0, %1 : f32
/// ```
class ExtractOpFromElementwise final
    : public OpRewritePattern<vector::ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    Operation *eltwise = op.getVector().getDefiningOp();

    // TODO: vector::FMAOp is not an ElementwiseMappable even if it claims to
    // be, as it doesn't support scalars.
    if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
        isa<vector::FMAOp>(eltwise))
      return rewriter.notifyMatchFailure(op, "not an elementwise op");

    if (eltwise->getNumResults() != 1)
      return rewriter.notifyMatchFailure(op, "expected single result");

    if (!eltwise->hasOneUse())
      return rewriter.notifyMatchFailure(op, "expected single op use");

    if (!llvm::all_equal(eltwise->getOperandTypes()))
      return rewriter.notifyMatchFailure(op, "operand types are different");

    Type dstType = op.getType();

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(eltwise);

    IRMapping mapping;
    Location loc = eltwise->getLoc();
    SmallVector<OpFoldResult> pos = op.getMixedPosition();
    for (Value arg : eltwise->getOperands()) {
      Value newArg = rewriter.create<vector::ExtractOp>(loc, arg, pos);
      mapping.map(arg, newArg);
    }

    Operation *newEltwise = rewriter.clone(*eltwise, mapping);
    newEltwise->getResult(0).setType(dstType);

    rewriter.replaceOp(op, newEltwise);
    rewriter.eraseOp(eltwise);
    return success();
  }
};

/// Check if the element type is suitable for vector.load/store sinking.
/// Element type must be index or a byte-aligned integer or floating-point
/// type.
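/// For example, index, i8, i32, f16, and f32 qualify, whereas i1 and i4 do
/// not (their bit width is not a multiple of 8).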
static bool isSupportedMemSinkElementType(Type type) {
  if (isa<IndexType>(type))
    return true;

  return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
}

/// Pattern to rewrite `vector.extract(vector.load) -> vector/memref.load`.
/// Only index and byte-aligned integer and floating-point element types are
/// supported for now.
///
/// Example:
/// ```
///  vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
///  vector.extract %0[1] : f32 from vector<4xf32>
/// ```
/// Gets converted to:
/// ```
///  %c1 = arith.constant 1 : index
///  %0 = arith.addi %arg1, %c1 overflow<nsw> : index
///  %1 = memref.load %arg0[%0] : memref<?xf32>
/// ```
class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
    if (!loadOp)
      return rewriter.notifyMatchFailure(op, "expected a load op");

    // Checking for single use so we won't duplicate load ops.
    if (!loadOp->hasOneUse())
      return rewriter.notifyMatchFailure(op, "expected single op use");

    VectorType loadVecType = loadOp.getVectorType();
    if (loadVecType.isScalable())
      return rewriter.notifyMatchFailure(op,
                                         "scalable vectors are not supported");

    MemRefType memType = loadOp.getMemRefType();

    // Non-byte-aligned types are tricky and may require special handling,
    // ignore them for now.
    if (!isSupportedMemSinkElementType(memType.getElementType()))
      return rewriter.notifyMatchFailure(op, "unsupported element type");

    int64_t rankOffset = memType.getRank() - loadVecType.getRank();
    if (rankOffset < 0)
      return rewriter.notifyMatchFailure(op, "unsupported ranks combination");

    auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
    int64_t finalRank = 0;
    if (extractVecType)
      finalRank = extractVecType.getRank();

    SmallVector<Value> indices = loadOp.getIndices();
    SmallVector<OpFoldResult> extractPos = op.getMixedPosition();

    // There may be memory stores between the load and the extract op, so we
    // need to make sure that the new load op is inserted at the same place as
    // the original load op.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(loadOp);
    Location loc = loadOp.getLoc();
    ArithIndexingBuilder idxBuilderf(rewriter, loc);
    for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
      OpFoldResult pos = extractPos[i - rankOffset];
      if (isZeroInteger(pos))
        continue;

      Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
      indices[i] = idxBuilderf.add(indices[i], offset);
    }

    Value base = loadOp.getBase();
    if (extractVecType) {
      rewriter.replaceOpWithNewOp<vector::LoadOp>(op, extractVecType, base,
                                                  indices);
    } else {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
    }
    // We checked for single use so we can safely erase the load op.
    rewriter.eraseOp(loadOp);
    return success();
  }
};

/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
///
/// Example:
/// ```
///  %0 = vector.splat %arg2 : vector<1xf32>
///  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
/// ```
/// Gets converted to:
/// ```
///  memref.store %arg2, %arg0[%arg1] : memref<?xf32>
/// ```
class StoreOpFromSplatOrBroadcast final
    : public OpRewritePattern<vector::StoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp op,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = op.getVectorType();
    if (vecType.isScalable())
      return rewriter.notifyMatchFailure(op,
                                         "scalable vectors are not supported");

    if (isa<VectorType>(op.getMemRefType().getElementType()))
      return rewriter.notifyMatchFailure(
          op, "memrefs of vectors are not supported");

    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(
          op, "only 1-element vectors are supported");

    Operation *splat = op.getValueToStore().getDefiningOp();
    if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
      return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");

    // Checking for single use so we can remove splat.
    if (!splat->hasOneUse())
      return rewriter.notifyMatchFailure(op, "expected single op use");

    Value source = splat->getOperand(0);
    Value base = op.getBase();
    ValueRange indices = op.getIndices();

    if (isa<VectorType>(source.getType())) {
      rewriter.replaceOpWithNewOp<vector::StoreOp>(op, source, base, indices);
    } else {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
    }
    rewriter.eraseOp(splat);
    return success();
  }
};

// Helper that returns a vector comparison that constructs a mask:
//     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
// If `dim == 0` then the result will be a 0-D vector.
//
// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
//       much more compact, IR for this operation, but LLVM eventually
//       generates more elaborate instructions for this intrinsic since it
//       is very conservative on the boundary conditions.
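//
// For example, with n = 4, offset o = 1, and bound b = 3:
//     [0,1,2,3] + [1,1,1,1] = [1,2,3,4] < [3,3,3,3]  ==>  [1,1,0,0]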
static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
                                   bool force32BitVectorIndices, int64_t dim,
                                   Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // If we can assume all indices fit in 32-bit, we perform the vector
  // comparison in 32-bit to get a higher degree of SIMD parallelism.
  // Otherwise we perform the vector comparison using 64-bit indices.
  Type idxType =
      force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
  DenseIntElementsAttr indicesAttr;
  if (dim == 0 && force32BitVectorIndices) {
    indicesAttr = DenseIntElementsAttr::get(
        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
  } else if (dim == 0) {
    indicesAttr = DenseIntElementsAttr::get(
        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
  } else if (force32BitVectorIndices) {
    indicesAttr = rewriter.getI32VectorAttr(
        llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
  } else {
    indicesAttr = rewriter.getI64VectorAttr(
        llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
  }
  Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
  // Add in an offset if requested.
  if (off) {
    Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
    Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
    indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
  }
  // Construct the vector comparison.
  Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
  Value bounds =
      rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
                                        bounds);
}

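/// Materializes the implicit out-of-bounds behaviour of a rank-1
/// vector.transfer_read/write as an explicit mask and marks the transfer as
/// in-bounds. A rough sketch of the effect (SSA names and types are
/// illustrative, not taken from an actual test):
/// ```
///   %v = vector.transfer_read %src[%i], %pad : memref<?xf32>, vector<8xf32>
/// ```
/// becomes, approximately:
/// ```
///   %d = memref.dim %src, %c0 : memref<?xf32>
///   %b = arith.subi %d, %i : index
///   %m = vector.create_mask %b : vector<8xi1>
///   %v = vector.transfer_read %src[%i], %pad, %m {in_bounds = [true]}
///       : memref<?xf32>, vector<8xf32>
/// ```
/// If the transfer already carries a mask, the new mask is AND-ed with it.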
template <typename ConcreteOp>
struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
public:
  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
                                   PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(ConcreteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp.hasOutOfBoundsDim())
      return failure();

    if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
      return failure();

    Location loc = xferOp->getLoc();
    VectorType vtp = xferOp.getVectorType();

    // Create the in-bounds mask with all elements between [0 .. dim - offset)
    // set and [dim - offset .. vector_length) unset.
    //
    // TODO: when the leaf transfer rank is k > 1, we need the last `k`
    //       dimensions here.
    unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
    Value off = xferOp.getIndices()[lastIndex];
    Value dim =
        vector::createOrFoldDimOp(rewriter, loc, xferOp.getBase(), lastIndex);
    Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
    Value mask = rewriter.create<vector::CreateMaskOp>(
        loc,
        VectorType::get(vtp.getShape(), rewriter.getI1Type(),
                        vtp.getScalableDims()),
        b);
    if (xferOp.getMask()) {
      // Intersect the in-bounds with the mask specified as an op parameter.
      mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
    }

    rewriter.modifyOpInPlace(xferOp, [&]() {
      xferOp.getMaskMutable().assign(mask);
      xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    });

    return success();
  }

private:
  const bool force32BitVectorIndices;
};

/// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
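/// A sketch of the lowering (illustrative): `vector.create_mask %n :
/// vector<4xi1>` becomes a comparison of the constant vector [0, 1, 2, 3]
/// against a splat of %n, i.e. `arith.cmpi slt, [0,1,2,3], splat(%n)`.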
1356class VectorCreateMaskOpConversion
1357 : public OpRewritePattern<vector::CreateMaskOp> {
1358public:
1359 explicit VectorCreateMaskOpConversion(MLIRContext *context,
1360 bool enableIndexOpt,
1361 PatternBenefit benefit = 1)
1362 : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
1363 force32BitVectorIndices(enableIndexOpt) {}
1364
1365 LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1366 PatternRewriter &rewriter) const override {
1367 auto dstType = op.getType();
1368 if (cast<VectorType>(dstType).isScalable())
1369 return failure();
1370 int64_t rank = dstType.getRank();
1371 if (rank > 1)
1372 return failure();
1373 rewriter.replaceOp(
1374 op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
1375 rank == 0 ? 0 : dstType.getDimSize(0),
1376 op.getOperand(0)));
1377 return success();
1378 }
1379
1380private:
1381 const bool force32BitVectorIndices;
1382};
1383
1384/// Returns true if all the `i1` elements of `constantOp` are set to `value`.
1385static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
1386 auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
1387 // TODO: Support non-dense constant.
1388 if (!denseAttr)
1389 return false;
1390
1391 assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
1392 return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
1393}
1394
1395/// Folds a select operation between an all-true and all-false vector. For now,
1396/// only single element vectors (i.e., vector<1xi1>) are supported. That is:
1397///
1398/// %true = arith.constant dense<true> : vector<1xi1>
1399/// %false = arith.constant dense<false> : vector<1xi1>
1400/// %result = arith.select %cond, %true, %false : i1, vector<1xi1>
1401/// =>
1402/// %result = vector.broadcast %cond : i1 to vector<1xi1>
1403///
1404/// InstCombine seems to handle vectors with multiple elements but not the
1405/// single element ones.
1406struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
1407 using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1408
1409 LogicalResult matchAndRewrite(arith::SelectOp selectOp,
1410 PatternRewriter &rewriter) const override {
1411 auto vecType = dyn_cast<VectorType>(selectOp.getType());
1412 if (!vecType || !vecType.getElementType().isInteger(1))
1413 return failure();
1414
1415 // Only scalar conditions can be folded.
1416 Value cond = selectOp.getCondition();
1417 if (isa<VectorType>(Val: cond.getType()))
1418 return failure();
1419
1420 // TODO: Support n-D and scalable vectors.
1421 if (vecType.getRank() != 1 || vecType.isScalable())
1422 return failure();
1423
1424 // TODO: Support vectors with multiple elements.
1425 if (vecType.getShape()[0] != 1)
1426 return failure();
1427
1428 auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
1429 if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
1430 return failure();
1431
1432 auto falseConst =
1433 selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
1434 if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
1435 return failure();
1436
1437 // Replace select with its condition broadcasted to single element vector.
1438 auto elemType = rewriter.getIntegerType(vecType.getNumElements());
1439 auto bcastType = VectorType::get(/*shape=*/{1}, elemType);
1440 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(selectOp, bcastType, cond);
1441 return success();
1442 }
1443};
1444
/// Returns the number of dims that can be folded away from transfer ops. It
/// returns a failure if it cannot determine the number of dims to be folded.
///
/// Ex 1: returns "2" if `srcType` is memref<512x16x1x1xf32> and
///       `vectorType` is vector<16x16x1x1xf32>
///       (the two innermost dims can be dropped by memref.subview ops)
///
/// Ex 2: returns "1" if `srcType` is memref<512x16x1x1xf32> with
///       [8192, 16, 8, 1] strides and `vectorType` is vector<16x16x1x1xf32>
///       (only the innermost unit dim of `srcType` can be dropped)
///
/// Ex 3: returns "0" if `srcType` is memref<512x16x1x1xf32> and
///       `vectorType` is vector<16x16x1x[1]xf32>
///       (the innermost dim in `vectorType` is not a unit dim; it is a
///       "scalable unit" dim)
1460static FailureOr<size_t>
1461getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
1462 SmallVector<int64_t> srcStrides;
1463 int64_t srcOffset;
1464 if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
1465 return failure();
1466
1467 auto isUnitDim = [](VectorType type, int dim) {
1468 return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
1469 };
1470
1471 // According to vector.transfer_read/write semantics, the vector can be a
1472 // slice. Thus, we have to offset the check index with `rankDiff` in
1473 // `srcStrides` and source dim sizes.
1474 size_t result = 0;
1475 int rankDiff = srcType.getRank() - vectorType.getRank();
1476 for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
1477 // Check that the inner dim size is 1 for both memref type and vector slice.
1478 // It can be folded only if they are 1 and the stride is 1.
1479 int dim = vectorType.getRank() - i - 1;
1480 if (srcStrides[dim + rankDiff] != 1 ||
1481 srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
1482 break;
1483 result++;
1484 }
1485 return result;
1486}
1487
/// Drop innermost contiguous unit dimensions from the transfer_read operand.
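/// E.g. (an illustrative sketch, mirroring the transfer_write example further
/// below; value names and shapes are hypothetical):
///
///   %1 = vector.transfer_read %arg0[%c0, %arg1, %c0, %c0, %c0], %pad
///     {in_bounds = [true, true, true, true, true]}
///     : memref<1x512x16x1x1xf32>, vector<1x16x16x1x1xf32>
///
/// would be rewritten to a transfer_read on a rank-reduced memref.subview,
/// followed by a vector.shape_cast that restores the original vector type:
///
///   %subview = memref.subview %arg0
///     [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
///     : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
///   %0 = vector.transfer_read %subview[%c0, %arg1, %c0], %pad
///     {in_bounds = [true, true, true]}
///     : memref<1x512x16xf32>, vector<1x16x16xf32>
///   %1 = vector.shape_cast %0
///     : vector<1x16x16xf32> to vector<1x16x16x1x1xf32>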
1489class DropInnerMostUnitDimsTransferRead
1490 : public OpRewritePattern<vector::TransferReadOp> {
1491 using OpRewritePattern::OpRewritePattern;
1492
1493 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
1494 PatternRewriter &rewriter) const override {
1495 // TODO: support 0-d corner case.
1496 if (readOp.getTransferRank() == 0)
1497 return failure();
1498
1499 // TODO: support mask.
1500 if (readOp.getMask())
1501 return failure();
1502
1503 auto srcType = dyn_cast<MemRefType>(readOp.getBase().getType());
1504 if (!srcType)
1505 return failure();
1506
1507 if (!readOp.getPermutationMap().isMinorIdentity())
1508 return failure();
1509
1510 auto targetType = readOp.getVectorType();
1511 if (targetType.getRank() <= 1)
1512 return failure();
1513
1514 FailureOr<size_t> maybeDimsToDrop =
1515 getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
1517 return failure();
1518
1519 size_t dimsToDrop = maybeDimsToDrop.value();
1520 if (dimsToDrop == 0)
1521 return failure();
1522
1523 auto inBounds = readOp.getInBoundsValues();
    auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
1525 if (llvm::is_contained(droppedInBounds, false))
1526 return failure();
1527
1528 auto resultTargetVecType =
1529 VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1530 targetType.getElementType(),
1531 targetType.getScalableDims().drop_back(dimsToDrop));
1532
1533 auto loc = readOp.getLoc();
1534 SmallVector<OpFoldResult> sizes =
        memref::getMixedSizes(rewriter, loc, readOp.getBase());
1536 SmallVector<OpFoldResult> offsets(srcType.getRank(),
1537 rewriter.getIndexAttr(0));
1538 SmallVector<OpFoldResult> strides(srcType.getRank(),
1539 rewriter.getIndexAttr(1));
1540 MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
1541 srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1542 strides);
1543 ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
1544 readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
1545 Value rankedReducedView = rewriter.create<memref::SubViewOp>(
1546 loc, resultMemrefType, readOp.getBase(), offsets, sizes, strides);
1547 auto permMap = getTransferMinorIdentityMap(
1548 cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
1549 Value result = rewriter.create<vector::TransferReadOp>(
1550 loc, resultTargetVecType, rankedReducedView,
1551 readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
1552 readOp.getPadding(),
1553 // TODO: support mask.
1554 /*mask=*/Value(), inBoundsAttr);
1555 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
1556 result);
1557 return success();
1558 }
1559};
1560
/// Drop innermost contiguous unit dimensions from the transfer_write operand.
1562/// E.g.,
1563/// vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
1564/// {in_bounds = [true, true, true, true, true]}
1565/// : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
1566///
1567/// will be replaced with
1568///
1569/// %subview = memref.subview %arg0
1570/// [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
1571/// : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
1572/// %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
1573/// to vector<1x16x16xf32>
1574/// vector.transfer_write %0, %subview[%c0, %arg2, %c0]
1575/// {in_bounds = [true, true, true]}
1576/// : vector<1x16x16xf32>, memref<1x512x16xf32>
1577///
1578/// Note, this pattern will not collapse "scalable unit" dims (i.e. `[1]`).
1579class DropInnerMostUnitDimsTransferWrite
1580 : public OpRewritePattern<vector::TransferWriteOp> {
1581 using OpRewritePattern::OpRewritePattern;
1582
1583 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
1584 PatternRewriter &rewriter) const override {
1585 // TODO: support 0-d corner case.
1586 if (writeOp.getTransferRank() == 0)
1587 return failure();
1588
1589 // TODO: support mask.
1590 if (writeOp.getMask())
1591 return failure();
1592
1593 auto srcType = dyn_cast<MemRefType>(writeOp.getBase().getType());
1594 if (!srcType)
1595 return failure();
1596
1597 if (!writeOp.getPermutationMap().isMinorIdentity())
1598 return failure();
1599
1600 auto targetType = writeOp.getVectorType();
1601 if (targetType.getRank() <= 1)
1602 return failure();
1603
1604 FailureOr<size_t> maybeDimsToDrop =
1605 getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
1607 return failure();
1608
1609 size_t dimsToDrop = maybeDimsToDrop.value();
1610 if (dimsToDrop == 0)
1611 return failure();
1612
1613 auto inBounds = writeOp.getInBoundsValues();
    auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
1615 if (llvm::is_contained(droppedInBounds, false))
1616 return failure();
1617
1618 auto resultTargetVecType =
1619 VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1620 targetType.getElementType(),
1621 targetType.getScalableDims().drop_back(dimsToDrop));
1622
1623 Location loc = writeOp.getLoc();
1624 SmallVector<OpFoldResult> sizes =
        memref::getMixedSizes(rewriter, loc, writeOp.getBase());
1626 SmallVector<OpFoldResult> offsets(srcType.getRank(),
1627 rewriter.getIndexAttr(0));
1628 SmallVector<OpFoldResult> strides(srcType.getRank(),
1629 rewriter.getIndexAttr(1));
1630 MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
1631 srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1632 strides);
1633 ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
1634 writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
1635
1636 Value rankedReducedView = rewriter.create<memref::SubViewOp>(
1637 loc, resultMemrefType, writeOp.getBase(), offsets, sizes, strides);
1638 auto permMap = getTransferMinorIdentityMap(
1639 cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
1640
1641 auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
1642 loc, resultTargetVecType, writeOp.getVector());
1643 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
1644 writeOp, shapeCast, rankedReducedView,
1645 writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
1646 // TODO: support mask.
1647 /*mask=*/Value(), inBoundsAttr);
1648 return success();
1649 }
1650};
1651
1652/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
1653/// semantics to a contraction suitable for MMT (matrix matrix multiplication
1654/// with the RHS transposed) lowering.
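///
/// For example (an illustrative sketch; the shapes are hypothetical), a plain
/// row-major matmul with indexing maps (m, k), (k, n), (m, n) is rewritten so
/// that the RHS is transposed and the contraction uses the canonical
/// (m, k), (n, k), (m, n) ("TNT") form:
///
///   %rhs_t = vector.transpose %rhs, [1, 0]
///     : vector<8x16xf32> to vector<16x8xf32>
///   %res = vector.contract ... %lhs, %rhs_t, %acc // maps: (m,k), (n,k), (m,n)
///
/// If an operand is produced by arith.extsi/arith.extui, the transpose is
/// emitted on the narrow type and the extension is re-applied afterwards.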
1655struct CanonicalizeContractMatmulToMMT final
1656 : OpRewritePattern<vector::ContractionOp> {
1657 using OpRewritePattern::OpRewritePattern;
1658
1659 using FilterConstraintType =
1660 std::function<LogicalResult(vector::ContractionOp op)>;
1661
1662 CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
1663 FilterConstraintType constraint)
1664 : OpRewritePattern<vector::ContractionOp>(context, benefit),
1665 filter(std::move(constraint)) {}
1666
1667 LogicalResult matchAndRewrite(vector::ContractionOp op,
1668 PatternRewriter &rewriter) const override {
1669 if (failed(filter(op)))
1670 return failure();
1671
1672 Location loc = op.getLoc();
1673 Value lhs = op.getLhs();
1674 Value rhs = op.getRhs();
1675 Value res = op.getAcc();
1676
1677 // Set up the parallel/reduction structure in right form.
1678 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1679 auto infer = [&](MapList m) {
1680 return AffineMap::inferFromExprList(m, op.getContext());
1681 };
1682 AffineExpr m;
1683 AffineExpr n;
1684 AffineExpr k;
    bindDims(rewriter.getContext(), m, n, k);
1686 static constexpr std::array<int64_t, 2> perm = {1, 0};
1687 auto iteratorTypes = op.getIteratorTypes().getValue();
1688 SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1689 if (iteratorTypes.size() != 3 ||
        !vector::isParallelIterator(iteratorTypes[0]) ||
        !vector::isParallelIterator(iteratorTypes[1]) ||
        !vector::isReductionIterator(iteratorTypes[2]))
1693 return rewriter.notifyMatchFailure(op, "contraction is not a gemm");
1694
1695 // The canonical form is "TNT" = A row-major, B col-major, C row-major.
1696 const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
1697 if (maps == canonicalForm)
1698 return rewriter.notifyMatchFailure(op, "already in the canonical form");
1699
1700 // Create a vector transpose making sure to emit zero/sign-extend at the
1701 // end.
1702 auto createTranspose = [&rewriter, loc](Value mat) -> Value {
1703 if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
1704 Value trans =
1705 rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
1706 VectorType newType =
1707 cast<VectorType>(trans.getType())
1708 .clone(cast<VectorType>(mat.getType()).getElementType());
1709 return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
1710 }
1711 if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
1712 Value trans =
1713 rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm);
1714 VectorType newType =
1715 VectorType::get(cast<VectorType>(trans.getType()).getShape(),
1716 cast<VectorType>(mat.getType()).getElementType());
1717 return rewriter.create<arith::ExtUIOp>(loc, newType, trans);
1718 }
1719 return rewriter.create<vector::TransposeOp>(loc, mat, perm);
1720 };
1721
1722 if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1723 rhs = createTranspose(rhs);
1724 } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1725 lhs = createTranspose(lhs);
1726 } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1727 rhs = createTranspose(rhs);
1728 lhs = createTranspose(lhs);
1729 } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
1731 rhs = createTranspose(rhs);
1732 lhs = createTranspose(lhs);
1733 } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
1735 rhs = createTranspose(rhs);
1736 } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
1738 lhs = createTranspose(lhs);
1739 } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
1741 } else {
1742 return rewriter.notifyMatchFailure(op, "unhandled contraction form");
1743 }
1744 rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
1746 op.getIteratorTypes());
1747 return success();
1748 };
1749
1750private:
1751 FilterConstraintType filter;
1752};
1753
1754/// Pattern to fold arithmetic extensions on floating point data types into
1755/// vector contraction operations. linalg.matmul introduces arithmetic
/// extensions on its operands. See the MLIR snippet below for more details.
1757/// ```mlir
1758/// "linalg.matmul"(%lhs, %rhs, %acc) ({
1759/// ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
1760/// %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
1761/// %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
1762/// %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
1763/// %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
1764/// "linalg.yield"(%acc) : (f32) -> ()
1765/// })
1766/// ```
1767/// This restricts the native usage of mixed precision NVIDIA Ampere Tensor
/// Cores, i.e., `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
1769/// This pattern folds the arithmetic extensions into the vector contraction and
1770/// enables the usage of native mixed precision Tensor Core instructions.
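///
/// At the vector level, the fold itself looks roughly as follows (an
/// illustrative sketch; the shapes and indexing maps are hypothetical):
/// ```mlir
/// %lhs_f32 = arith.extf %lhs : vector<4x8xf16> to vector<4x8xf32>
/// %rhs_f32 = arith.extf %rhs : vector<8x16xf16> to vector<8x16xf32>
/// %res = vector.contract {...} %lhs_f32, %rhs_f32, %acc
///   : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>
/// ==>
/// %res = vector.contract {...} %lhs, %rhs, %acc
///   : vector<4x8xf16>, vector<8x16xf16> into vector<4x16xf32>
/// ```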
1771template <typename ExtOp>
1772struct FoldArithExtIntoContractionOp
1773 : public OpRewritePattern<vector::ContractionOp> {
1774 using OpRewritePattern::OpRewritePattern;
1775
1776 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
1777 PatternRewriter &rewriter) const override {
1778
1779 auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
1780 auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();
1781
1782 if (!lhsDefOp || !rhsDefOp) {
1783 return rewriter.notifyMatchFailure(contractOp,
1784 "no defining op on contract operands");
1785 }
1786
1787 rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1788 contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
1789 contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
1790 contractOp.getIteratorTypesAttr());
1791
1792 return success();
1793 }
1794};
1795
1796/// Pattern to fold chained reduction to a series of vector additions and a
1797/// final reduction. This form should require fewer subgroup operations.
1798///
1799/// ```mlir
1800/// %a = vector.reduction <add> %x, %acc
1801/// %b = vector.reduction <add> %y, %a
1802/// ==>
1803/// %a = arith.addf %x, %y
1804/// %b = vector.reduction <add> %a, %acc
1805/// ```
1806struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
1807 using OpRewritePattern::OpRewritePattern;
1808
1809 LogicalResult matchAndRewrite(vector::ReductionOp op,
1810 PatternRewriter &rewriter) const override {
1811 // TODO: Handle other combining kinds.
1812 if (op.getKind() != vector::CombiningKind::ADD)
1813 return failure();
1814
1815 // Accumulator is optional.
1816 Value acc = op.getAcc();
1817 if (!acc)
1818 return failure();
1819
1820 if (!acc.getType().isIntOrFloat())
1821 return failure();
1822
1823 auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
1824 if (!parentReduction)
1825 return failure();
1826
1827 Location loc = op.getLoc();
1828 Value vAdd;
    if (isa<IntegerType>(acc.getType())) {
1830 vAdd = rewriter.createOrFold<arith::AddIOp>(
1831 loc, parentReduction.getVector(), op.getVector());
1832 } else {
1833 vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(),
1834 op.getVector());
1835 }
1836 rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
1837 parentReduction.getAcc());
1838 return success();
1839 }
1840};
1841
// Helper function dropping non-scalable unit dimensions from a VectorType,
// keeping at least 1 dimension to avoid generating 0-D vectors. Scalable unit
1844// dimensions are not dropped. Folding such dimensions would require "shifting"
1845// the scalable flag onto some other fixed-width dim (e.g. vector<[1]x4xf32> ->
1846// vector<[4]xf32>). This could be implemented in the future.
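// For example (illustrative):
//   vector<1x4x1x[4]xf32> -> vector<4x[4]xf32>  (fixed unit dims dropped)
//   vector<[1]x1x4xf32>   -> vector<[1]x4xf32>  (scalable unit dim kept)
//   vector<1x1xf32>       -> vector<1xf32>      (keep one dim, avoid 0-D)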
1847static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
1848 auto inVecShape = inVecTy.getShape();
1849 SmallVector<int64_t> newShape;
1850 SmallVector<bool> newScalableDims;
1851 for (auto [dim, isScalable] :
1852 llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
1853 if (dim == 1 && !isScalable)
1854 continue;
1855
1856 newShape.push_back(dim);
1857 newScalableDims.push_back(isScalable);
1858 }
1859 // All dims have been dropped, return vector<1xeType>.
1860 if (newShape.empty()) {
    newShape.push_back(1);
    newScalableDims.push_back(false);
1863 }
1864
1865 return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
1866}
1867
1868/// For vectors with at least one unit dim, replaces:
1869/// elementwise(a, b)
1870/// with:
1871/// sc_a = shape_cast(a)
1872/// sc_b = shape_cast(b)
1873/// res = elementwise(sc_a, sc_b)
1874/// return shape_cast(res)
1875/// The newly inserted shape_cast Ops fold (before elementwise Op) and then
1876/// restore (after elementwise Op) the unit dim. Vectors `a` and `b` are
1877/// required to be rank > 1.
1878///
1879/// Ex:
1880/// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
1881/// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
1882///
1883/// gets converted to:
1884///
1885/// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
1886/// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
1887/// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
1888/// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
1889/// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
1890///
1891/// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
1892/// `%cast`.
1893struct DropUnitDimFromElementwiseOps final
1894 : public OpTraitRewritePattern<OpTrait::Elementwise> {
1895 using OpTraitRewritePattern::OpTraitRewritePattern;
1896 LogicalResult matchAndRewrite(Operation *op,
1897 PatternRewriter &rewriter) const override {
1898 if (op->getNumResults() != 1 || op->getNumRegions() != 0)
1899 return failure();
1900
    auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
1902 if (!resultVectorType)
1903 return failure();
1904
1905 // Check the operand pre-conditions. For `Elementwise` ops all operands are
1906 // guaranteed to have identical shapes (with some exceptions such as
1907 // `arith.select`) and it suffices to only check one of them.
    auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
1909 if (!sourceVectorType)
1910 return failure();
1911 if (sourceVectorType.getRank() < 2)
1912 return failure();
1913
1914 SmallVector<Value> newOperands;
1915 auto loc = op->getLoc();
1916 for (auto operand : op->getOperands()) {
1917 auto opVectorType = cast<VectorType>(operand.getType());
1918 auto newVType = dropNonScalableUnitDimFromType(opVectorType);
1919 if (newVType == opVectorType)
        return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
1921
1922 auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
      newOperands.push_back(opSC);
1924 }
1925
1926 VectorType newResultVectorType =
1927 dropNonScalableUnitDimFromType(resultVectorType);
1928 // Create an updated elementwise Op without unit dim.
1929 Operation *elementwiseOp =
1930 rewriter.create(loc, op->getName().getIdentifier(), newOperands,
1931 newResultVectorType, op->getAttrs());
1932
1933 // Restore the unit dim by applying vector.shape_cast to the result.
1934 rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
1935 elementwiseOp->getResult(0));
1936
1937 return success();
1938 }
1939};
1940
1941/// A pattern to drop unit dims from vector.transpose.
1942///
1943/// Example:
1944///
1945/// BEFORE:
1946/// ```mlir
1947/// %transpose = vector.transpose %vector, [3, 0, 1, 2]
1948/// : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
1949/// ```
1950///
1951/// AFTER:
1952/// ```mlir
1953/// %dropDims = vector.shape_cast %vector
1954/// : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
1955/// %transpose = vector.transpose %0, [1, 0]
1956/// : vector<4x[4]xf32> to vector<[4]x4xf32>
1957/// %restoreDims = vector.shape_cast %transpose
1958/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
1959/// ```
1960struct DropUnitDimsFromTransposeOp final
1961 : OpRewritePattern<vector::TransposeOp> {
1962 using OpRewritePattern::OpRewritePattern;
1963
1964 LogicalResult matchAndRewrite(vector::TransposeOp op,
1965 PatternRewriter &rewriter) const override {
1966 VectorType sourceType = op.getSourceVectorType();
1967 VectorType sourceTypeWithoutUnitDims =
1968 dropNonScalableUnitDimFromType(sourceType);
1969
1970 if (sourceType == sourceTypeWithoutUnitDims)
1971 return failure();
1972
1973 // Construct a map from dimIdx -> number of dims dropped before dimIdx.
1974 auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
1975 SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
1976 int64_t droppedDims = 0;
1977 for (auto [i, dim] : llvm::enumerate(sourceDims)) {
1978 droppedDimsBefore[i] = droppedDims;
1979 if (dim == std::make_tuple(1, false))
1980 ++droppedDims;
1981 }
1982
1983 // Drop unit dims from transpose permutation.
1984 ArrayRef<int64_t> perm = op.getPermutation();
1985 SmallVector<int64_t> newPerm;
1986 for (int64_t idx : perm) {
1987 if (sourceDims[idx] == std::make_tuple(1, false))
1988 continue;
1989 newPerm.push_back(idx - droppedDimsBefore[idx]);
1990 }
1991
    // Fixup for `newPerm`. The `sourceTypeWithoutUnitDims` could be vector<1xT>
    // when all the source dimensions are unit dimensions. In this case, the
    // new permutation should be [0].
1995 if (newPerm.empty()) {
      newPerm.push_back(0);
1997 }
1998
1999 Location loc = op.getLoc();
2000 // Drop the unit dims via shape_cast.
2001 auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(
2002 loc, sourceTypeWithoutUnitDims, op.getVector());
2003 // Create the new transpose.
2004 auto transposeWithoutUnitDims =
2005 rewriter.create<vector::TransposeOp>(loc, dropDimsShapeCast, newPerm);
2006 // Restore the unit dims via shape cast.
2007 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
2008 op, op.getResultVectorType(), transposeWithoutUnitDims);
2009
2010 return success();
2011 }
2012};
2013
2014/// A pattern to drop unit dims from the iter_args of an scf.for.
2015///
2016/// Example:
2017///
2018/// BEFORE:
2019/// ```mlir
2020/// %res = scf.for ... iter_args(%iter = %init) -> vector<[4]x1x1x4xf32> {
2021/// ...
2022/// scf.yield %
2023/// }
2024/// ```
2025///
2026/// AFTER:
2027/// ```mlir
2028/// %drop = vector.shape_cast %init
///    : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
2030/// %new_loop = scf.for ... iter_args(%iter = %drop) -> vector<[4]x4xf32> {
2031/// %new_iter = vector.shape_cast %iter
2032/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
2033/// ...
2034/// }
2035/// %res = vector.shape_cast %new_loop
2036/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
2037/// ```
2038struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
2039 using OpRewritePattern::OpRewritePattern;
2040
2041 LogicalResult matchAndRewrite(scf::ForOp forOp,
2042 PatternRewriter &rewriter) const override {
2043 /// Find the first iter_arg with droppable unit dims. Further applications
2044 /// of this pattern will apply to later arguments.
2045 for (OpOperand &operand : forOp.getInitArgsMutable()) {
2046 auto vectorType = dyn_cast<VectorType>(operand.get().getType());
2047 if (!vectorType)
2048 continue;
2049
2050 VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
2051 if (vectorType == newVectorType)
2052 continue;
2053
2054 // Create a new ForOp with that iter operand replaced.
2055 auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) {
2056 return b.create<vector::ShapeCastOp>(loc, type, source);
2057 };
2058
2059 Value replacement =
2060 castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
2061 rewriter.replaceOp(forOp,
2062 replaceAndCastForOpIterArg(rewriter, forOp, operand,
2063 replacement, castFn));
2064 return success();
2065 }
2066 return failure();
2067 }
2068};
2069
2070/// Pattern to eliminate redundant zero-constants added to reduction operands.
2071/// It's enough for there to be one initial zero value, so we can eliminate the
2072/// extra ones that feed into `vector.reduction <add>`. These get created by the
2073/// `ChainedReduction` pattern.
2074///
2075/// ```mlir
2076/// %a = arith.addf %x, %zero
2077/// %b = arith.addf %a, %y
2078/// %c = vector.reduction <add> %b, %acc
2079/// ==>
2080/// %b = arith.addf %a, %y
2081/// %c = vector.reduction <add> %b, %acc
2082/// ```
2083struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
2084 using OpRewritePattern::OpRewritePattern;
2085
2086 LogicalResult matchAndRewrite(vector::ReductionOp op,
2087 PatternRewriter &rewriter) const override {
2088 // TODO: Handle other reduction kinds and their identity values.
2089 if (op.getKind() != vector::CombiningKind::ADD)
2090 return failure();
2091
2092 Type elemType = op.getSourceVectorType().getElementType();
2093 // The integer case should be handled by `arith.addi` folders, only check
2094 // for floats here.
    if (!isa<FloatType>(elemType))
2096 return failure();
2097
2098 auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
2099 if (!vAdd)
2100 return failure();
2101 auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
2102 if (!addLhs)
2103 return failure();
2104
2105 if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat()))
2106 return failure();
2107
2108 auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(), addLhs.getLhs(),
2109 vAdd.getRhs());
2110 rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd,
2111 op.getAcc());
2112 return success();
2113 }
2114};
2115
2116/// Example:
2117/// ```
2118/// %a = vector.reduction <add> %x : vector<2xf32> into f32
2119/// ```
2120/// is transformed into:
2121/// ```
2122/// %y = vector.extract %x[0] : f32 from vector<2xf32>
2123/// %z = vector.extract %x[1] : f32 from vector<2xf32>
2124/// %a = arith.addf %y, %z : f32
2125/// ```
2126struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
2127 BreakDownVectorReduction(MLIRContext *context,
2128 unsigned maxNumElementsToExtract,
2129 PatternBenefit benefit)
2130 : OpRewritePattern(context, benefit),
2131 maxNumElementsToExtract(maxNumElementsToExtract) {}
2132
2133 LogicalResult matchAndRewrite(vector::ReductionOp op,
2134 PatternRewriter &rewriter) const override {
2135 VectorType type = op.getSourceVectorType();
2136 if (type.isScalable() || op.isMasked())
2137 return failure();
2138 assert(type.getRank() == 1 && "Expected a 1-d vector");
2139
2140 int64_t numElems = type.getNumElements();
2141 if (numElems > maxNumElementsToExtract) {
2142 return rewriter.notifyMatchFailure(
          op, llvm::formatv("has too many vector elements ({0}) to break down "
                            "(max allowed: {1})",
                            numElems, maxNumElementsToExtract));
2146 }
2147
2148 Location loc = op.getLoc();
2149 SmallVector<Value> extracted(numElems, nullptr);
2150 for (auto [idx, extractedElem] : llvm::enumerate(extracted))
2151 extractedElem = rewriter.create<vector::ExtractOp>(
2152 loc, op.getVector(), static_cast<int64_t>(idx));
2153
2154 Value res = extracted.front();
2155 for (auto extractedElem : llvm::drop_begin(extracted))
2156 res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
2157 extractedElem, op.getFastmathAttr());
2158 if (Value acc = op.getAcc())
2159 res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
2160 op.getFastmathAttr());
2161
2162 rewriter.replaceOp(op, res);
2163 return success();
2164 }
2165
2166private:
2167 unsigned maxNumElementsToExtract = 0;
2168};
2169
2170/// Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A,
2171/// B)`.
/// Example:
///   %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
///   %lhsT = vector.transpose %lhsBcast, [1, 0]
///     : vector<4x4xi32> to vector<4x4xi32>
///   %rhsBcast = vector.broadcast %rhs : vector<4xi32> to vector<4x4xi32>
///   %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
///
/// Becomes:
2179///
2180/// %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
2181///
2182/// Supports only 1D-to-2D broadcasts. The following cases are not supported.
2183/// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
2184/// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
2185/// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
2186template <typename MulOpType>
2187struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
2188 using OpRewritePattern<MulOpType>::OpRewritePattern;
  // Returns whether a vector.broadcast matches the requirements for an
  // outerproduct pattern, i.e., a 1D-to-2D broadcastOp without broadcasted
  // unit dimensions.
2191 bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
    // Fail if it is not a 1D-to-2D broadcast, to avoid generating
    // shape_casts/broadcasts that do not belong in this pattern.
2194 if (!broadcastOp.computeBroadcastedUnitDims().empty())
2195 return false;
2196 // Avoid broadcast like f32 or vector<f32> -> ResType
2197 auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
2198 return srcType && srcType.getRank() != 2;
2199 }
2200
2201 LogicalResult matchAndRewrite(MulOpType mulOp,
2202 PatternRewriter &rewriter) const override {
    auto resType = llvm::dyn_cast<VectorType>(mulOp.getResult().getType());
2204 if (!resType)
2205 return failure();
2206 if (resType.getRank() != 2)
2207 return failure();
2208 /// If operandA can be written as tr(broadcast(A)) and operandB as
2209 /// broadcast(B) where broadcasts are 1D-to-2D, create and return
2210 /// vector.outerproduct(A, B). Returns failure() otherwise.
2211 auto matchOuterProduct =
2212 [&](Value operandA,
2213 Value operandB) -> FailureOr<vector::OuterProductOp> {
2214 auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
2215 if (!transposedLhs)
2216 return failure();
2217 // Fail unless this is a true 2-D matrix transpose.
2218 ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
2219 if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
2220 return failure();
2221
2222 auto broadcastedLhs =
2223 transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
2224 if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
2225 return failure();
2226
2227 auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
2228 if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
2229 return failure();
2230
2231 return rewriter.create<vector::OuterProductOp>(
2232 mulOp->getLoc(), resType, broadcastedLhs.getSource(),
2233 broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
2234 };
2235
2236 Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
2237 auto maybeOuterP = matchOuterProduct(lhs, rhs);
2238 // Handle commutativity, the transposed op is the outerproduct LHS.
2239 if (failed(maybeOuterP))
2240 maybeOuterP = matchOuterProduct(rhs, lhs);
2241 if (failed(maybeOuterP))
2242 return failure();
2243 rewriter.replaceOp(mulOp, maybeOuterP->getResult());
2244 return success();
2245 }
2246};
2247
2248} // namespace
2249
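// The `populate*` entry points below only register rewrite patterns; clients
// are expected to gather them into a RewritePatternSet and run them with a
// pattern rewrite driver. A minimal, illustrative sketch (the exact greedy
// driver entry point name depends on the MLIR version in use):
//
//   RewritePatternSet patterns(op->getContext());
//   vector::populateVectorMaskMaterializationPatterns(
//       patterns, /*force32BitVectorIndices=*/false);
//   vector::populateSinkVectorOpsPatterns(patterns);
//   (void)applyPatternsGreedily(op, std::move(patterns));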
2250void mlir::vector::populateFoldArithExtensionPatterns(
2251 RewritePatternSet &patterns) {
2252 patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>,
2253 FoldArithExtIntoContractionOp<arith::ExtSIOp>>(
2254 patterns.getContext());
2255}
2256
2257void mlir::vector::populateVectorMaskMaterializationPatterns(
2258 RewritePatternSet &patterns, bool force32BitVectorIndices,
2259 PatternBenefit benefit) {
2260 patterns.add<VectorCreateMaskOpConversion,
2261 MaterializeTransferMask<vector::TransferReadOp>,
2262 MaterializeTransferMask<vector::TransferWriteOp>>(
      patterns.getContext(), force32BitVectorIndices, benefit);
  patterns.add<FoldI1Select>(patterns.getContext(), benefit);
2265}
2266
2267void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
2268 RewritePatternSet &patterns, PatternBenefit benefit) {
2269 // TODO: Consider either:
2270 // * including DropInnerMostUnitDimsTransferRead and
2271 // DropInnerMostUnitDimsTransferWrite, or
2272 // * better naming to distinguish this and
2273 // populateVectorTransferCollapseInnerMostContiguousDimsPatterns.
2274 patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
               DropUnitDimsFromTransposeOp>(patterns.getContext(), benefit);
2276}
2277
2278void mlir::vector::populateBubbleVectorBitCastOpPatterns(
2279 RewritePatternSet &patterns, PatternBenefit benefit) {
2280 patterns.add<BubbleDownVectorBitCastForExtract,
2281 BubbleDownBitCastForStridedSliceExtract,
2282 BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
      patterns.getContext(), benefit);
2284}
2285
2286void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
2287 RewritePatternSet &patterns,
2288 std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
2289 patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
2290 std::move(controlFn), benefit);
2291}
2292
2293void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
2294 RewritePatternSet &patterns,
2295 std::function<LogicalResult(vector::ContractionOp)> constraint,
2296 PatternBenefit benefit) {
  patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
                                                std::move(constraint));
2299}
2300
2301void mlir::vector::populateVectorReductionToContractPatterns(
2302 RewritePatternSet &patterns, PatternBenefit benefit) {
2303 patterns.add<MultiReduceToContract, CombineContractBroadcastMask,
2304 CombineContractABTranspose, CombineContractResultTranspose>(
      patterns.getContext(), benefit);
2306}
2307
2308void mlir::vector::
2309 populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
2310 RewritePatternSet &patterns, PatternBenefit benefit) {
2311 patterns.add<DropInnerMostUnitDimsTransferRead,
               DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
                                                   benefit);
2314}
2315
2316void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
2317 PatternBenefit benefit) {
2318 patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
2319 ReorderElementwiseOpsOnBroadcast, ExtractOpFromElementwise>(
      patterns.getContext(), benefit);
2321}
2322
2323void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
2324 PatternBenefit benefit) {
2325 // TODO: Consider converting these patterns to canonicalizations.
2326 patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>(
      patterns.getContext(), benefit);
2328}
2329
2330void mlir::vector::populateChainedVectorReductionFoldingPatterns(
2331 RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ChainedReduction>(patterns.getContext(), benefit);
  patterns.add<ReduceRedundantZero>(patterns.getContext(),
                                    PatternBenefit(benefit.getBenefit() + 1));
2335}
2336
2337void mlir::vector::populateBreakDownVectorReductionPatterns(
2338 RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
2339 PatternBenefit benefit) {
  patterns.add<BreakDownVectorReduction>(patterns.getContext(),
                                         maxNumElementsToExtract, benefit);
2342}
2343
2344void mlir::vector::populateElementwiseToVectorOpsPatterns(
2345 RewritePatternSet &patterns) {
2346 patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>,
2347 FoldArithToVectorOuterProduct<arith::MulIOp>>(
2348 patterns.getContext());
2349}
2350
2351//===----------------------------------------------------------------------===//
2352// TableGen'd enum attribute definitions
2353//===----------------------------------------------------------------------===//
2354
2355#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"
2356
