//===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites as 1->N patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"

#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/LogicalResult.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-to-vector"

using namespace mlir;
using namespace mlir::vector;

template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

// Helper to find an index in an affine map.
static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      return i;
  }
  return std::nullopt;
}
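// For illustration: given the map (d0, d1, d2) -> (d2, d0), getResultIndex
// returns 1 for index 0 (d0 appears as the second result) and std::nullopt
// for index 1 (d1 does not appear among the results).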

namespace {

/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
//
// Example:
//
//  The following MLIR with cancelling ShapeCastOps:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
//   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
//   %3 = user %2 : vector<5x4x2xf32>
//
//  Should canonicalize to the following:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = user %0 : vector<5x4x2xf32>
//
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    // Check if 'shapeCastOp' has vector source/result type.
    auto sourceVectorType =
        dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
    auto resultVectorType =
        dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
    if (!sourceVectorType || !resultVectorType)
      return failure();

    // Check if shape cast op source operand is also a shape cast op.
    auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
        shapeCastOp.getSource().getDefiningOp());
    if (!sourceShapeCastOp)
      return failure();
    auto operandSourceVectorType =
        cast<VectorType>(sourceShapeCastOp.getSource().getType());
    auto operandResultVectorType = sourceShapeCastOp.getType();

    // Check if shape cast operations invert each other.
    if (operandSourceVectorType != resultVectorType ||
        operandResultVectorType != sourceVectorType)
      return failure();

    rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
    return success();
  }
};
/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
/// Ex:
/// ```
///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
///   %1 = vector.multi_reduction add, %0 [2]
///     : vector<8x32x16xf32> to vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///     iterator_types = ["parallel", "parallel", "reduction"],
///     kind = add} %arg0, %arg1, %cst_f0
///     : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.getKind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.getSource().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    SmallVector<bool> reductionMask = reduceOp.getReductionMask();
    auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
    SmallVector<AffineExpr> exprs;
    SmallVector<vector::IteratorType> iteratorTypes;
    for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(vector::IteratorType::parallel);
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        iteratorTypes.push_back(vector::IteratorType::reduction);
      }
    }
    auto dstMap =
        AffineMap::get(/*dimCount=*/reductionMask.size(),
                       /*symbolCount=*/0, exprs, reduceOp.getContext());
    rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
        rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
        rewriter.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
              return IteratorTypeAttr::get(rewriter.getContext(), t);
            }))));
    return success();
  }
};

/// Merge LHS/RHS (A/B) TransposeOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.transpose %arg0, [2, 0, 1]
///     : vector<32x16x8xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///     iterator_types = ["parallel", "parallel", "reduction"],
///     kind = add} %0, %arg1, %cst_f0
///     : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///     iterator_types = ["parallel", "parallel", "reduction"],
///     kind = add} %arg0, %arg1, %cst_f0
///     : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractABTranspose final
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      AffineMap permutationMap = AffineMap::getPermutationMap(
          transposeOp.getPermutation(), contractOp.getContext());
      map = inversePermutation(permutationMap).compose(map);
      *operand = transposeOp.getVector();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};

/// Merges accumulator and result transposes into contract.
///
/// For example:
/// ```mlir
/// %accT = vector.transpose %acc, [0, 2, 1]
///   : vector<2x8x4xf32> to vector<2x4x8xf32>
/// %contract = vector.contract {
///   indexing_maps = [
///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
///     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
///   ],
///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
///   kind = #vector.kind<add>
/// } %lhs, %rhs, %accT
///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
/// %0 = vector.transpose %contract, [0, 2, 1]
///   : vector<2x4x8xf32> to vector<2x8x4xf32>
/// ```
/// Becomes:
/// ```mlir
/// %0 = vector.contract {
///   indexing_maps = [
///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
///     affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
///   ],
///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
///   kind = #vector.kind<add>
/// } %lhs, %rhs, %acc
///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x8x4xf32>
/// ```
struct CombineContractResultTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
                                PatternRewriter &rewriter) const override {
    auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
    if (!contractOp || !contractOp->hasOneUse())
      return failure();

    auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
    if (!accTOp)
      return failure();

    MLIRContext *context = contractOp.getContext();
    auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
    AffineMap contractMap = maps.back();

    // Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
    // To index into A in contract, we need revert(f)(g(C)) -> A.
    auto accTMap =
        AffineMap::getPermutationMap(accTOp.getPermutation(), context);

    // Contract performs g(C) -> D. Result transpose performs h(D) -> E.
    // To index into E in contract, we need h(g(C)) -> E.
    auto resTMap =
        AffineMap::getPermutationMap(resTOp.getPermutation(), context);
    auto combinedResMap = resTMap.compose(contractMap);

    // The accumulator and result share the same indexing map, so they must
    // agree for the merge to be valid. That is, combinedResMap must equal
    // inversePermutation(accTMap).compose(contractMap), i.e. the result
    // transpose must be the inverse of the accumulator transpose.
    if (inversePermutation(accTMap) != resTMap)
      return failure();
    maps.back() = combinedResMap;

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};

/// Merge BroadcastOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///     iterator_types = ["parallel", "parallel", "reduction"],
///     kind = add} %0, %arg1, %cst_f0
///     : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///     iterator_types = ["parallel", "parallel", "reduction"],
///     kind = add} %arg0, %arg1, %cst_f0
///     : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractBroadcast
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
      if (!broadcast)
        continue;
      // ContractionOp can only take vectors as operands.
      auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
      if (!srcType ||
          srcType.getRank() == broadcast.getResultVectorType().getRank())
        continue;
      int64_t rankDiff =
          broadcast.getResultVectorType().getRank() - srcType.getRank();
      bool innerDimBroadcast = false;
      SmallVector<AffineExpr> originalDims;
      for (const auto &dim : llvm::enumerate(srcType.getShape())) {
        if (dim.value() != broadcast.getResultVectorType().getDimSize(
                               rankDiff + dim.index())) {
          innerDimBroadcast = true;
          break;
        }
        originalDims.push_back(
            rewriter.getAffineDimExpr(dim.index() + rankDiff));
      }
      // Contract doesn't support inner dimension broadcast. Once this is
      // relaxed we can remove this case.
      if (innerDimBroadcast)
        continue;

      // It would be incorrect to fold a broadcast onto a reduction dimension
      // of non-unit size.
      bool nonUnitDimReductionBroadcast = false;
      for (int64_t i = 0; i < rankDiff; ++i) {
        if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
            isReductionIterator(contractOp.getIteratorTypes()
                                    .getValue()[map.getDimPosition(i)])) {
          nonUnitDimReductionBroadcast = true;
          break;
        }
      }
      if (nonUnitDimReductionBroadcast)
        continue;

      AffineMap broadcastMap =
          AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
                         originalDims, contractOp.getContext());
      map = broadcastMap.compose(map);
      *operand = broadcast.getSource();
      changed = true;
    }

    if (!changed)
      return failure();

    // Determine which dims are unused, now that the maps have been composed
    // with the broadcast maps.
    llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
    // Compress unused dims.
    for (auto &m : maps)
      m = compressDims(m, unusedDimsBitVector);
    // Compute the combined iterators.
    SmallVector<Attribute> iterators;
    for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
      if (!unusedDimsBitVector.test(i))
        iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
    }
    // Check that compressing unused dims isn't removing all reduction
    // dimension pairs. For example, if the vector.contract had only one
    // reduction iterator and that was a unit-dimension created by a broadcast,
    // then we should bail here, otherwise we would create a contract without
    // a reduction dimension pair.
    bool hasReductionIteratorApplyingOnBothSides = false;
    for (unsigned i = 0; i < iterators.size(); ++i) {
      if (!isReductionIterator(iterators[i]))
        continue;
      if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
        hasReductionIteratorApplyingOnBothSides = true;
        break;
      }
    }
    if (!hasReductionIteratorApplyingOnBothSides)
      return failure();

    // If the compressed maps have a dimension that is not used by either LHS
    // or RHS then the ContractionOp verifier would fail.
    if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
    return success();
  }
};

/// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
/// contraction ops closer, which kicks in CombineContractBroadcast pattern when
/// casting ops are around these operations.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
/// ```
/// Gets converted to:
/// ```
///   %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
///   %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
/// ```
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
      castResTy = vecTy.clone(castResTy);
    auto *castOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        bcastOp.getSource(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
    return success();
  }
};

/// Reorders elementwise(transpose) to transpose(elementwise). This makes
/// transpose ops and contraction ops closer, which kicks in
/// CombineContractABTranspose pattern when elementwise ops are between these
/// operations. Ex:
/// ```
/// %at = vector.transpose %a, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
/// %bt = vector.transpose %b, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
/// %r = arith.addf %at, %bt : vector<2x4xf32>
/// ```
/// Gets converted to:
/// ```
/// %0 = arith.addf %a, %b : vector<4x2xf32>
/// %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
/// ```
struct ReorderElementwiseOpsOnTranspose final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1 || op->getNumRegions() != 0)
      return failure();

    // Make sure all operands are transpose/constant ops and collect their
    // transposition maps.
    SmallVector<ArrayRef<int64_t>> transposeMaps;
    transposeMaps.reserve(op->getNumOperands());
    // Record the initial type before transposition. We'll use its shape later.
    // Any type will do here as we will check all transpose maps are the same.
    VectorType srcType;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        transposeMaps.push_back(transposeOp.getPermutation());
        srcType = transposeOp.getSourceVectorType();
      } else if (!matchPattern(operand, m_Constant())) {
        return failure();
      }
    }
    if (transposeMaps.empty())
      return failure();
    // This is an elementwise op, so all transposed operands should have the
    // same type. We need to additionally check that all transposes use the
    // same map.
    if (!llvm::all_equal(transposeMaps))
      return rewriter.notifyMatchFailure(op, "different transpose map");

    SmallVector<Value> srcValues;
    srcValues.reserve(op->getNumOperands());

    // If there are constant operands, we need to insert inverse transposes for
    // them. Calculate the inverse order first.
    auto order = transposeMaps.front();
    SmallVector<int64_t> invOrder(order.size());
    for (int i = 0, e = order.size(); i < e; ++i)
      invOrder[order[i]] = i;

    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        srcValues.push_back(transposeOp.getVector());
      } else {
        // This is a constant. Create a reverse transpose op for it.
        auto vectorType =
            srcType.clone(cast<VectorType>(operand.getType()).getElementType());
        srcValues.push_back(rewriter.create<vector::TransposeOp>(
            operand.getLoc(), vectorType, operand, invOrder));
      }
    }

    auto vectorType = srcType.clone(
        cast<VectorType>(op->getResultTypes()[0]).getElementType());
    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        vectorType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        op, op->getResultTypes()[0], elementwiseOp->getResult(0),
        transposeMaps.front());
    return success();
  }
};

// Returns the values in `arrayAttr` as an integer vector.
static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(
      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
                      [](IntegerAttr attr) { return attr.getInt(); }));
}

// Shuffles vector.bitcast op after vector.extract op.
//
// This transforms IR like:
//   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
//   %1 = vector.extract %0[3] : f16 from vector<8xf16>
// Into:
//   %0 = vector.extract %src[1] : f32 from vector<4xf32>
//   %1 = vector.bitcast %0 : vector<1xf32> to vector<2xf16>
//   %2 = vector.extract %1[1] : f16 from vector<2xf16>
struct BubbleDownVectorBitCastForExtract
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Only support extracting scalars for now.
    if (extractOp.getSourceVectorType().getRank() != 1)
      return failure();

    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // Fail to match if we only have one element in the cast op source.
    // This is to avoid infinite loop given that this pattern can generate
    // such cases.
    if (castSrcType.getNumElements() == 1)
      return failure();

    // Only support casting to a larger number of elements for now.
    // E.g., vector<4xf32> -> vector<8xf16>.
    if (castSrcType.getNumElements() > castDstType.getNumElements())
      return failure();

    unsigned expandRatio =
        castDstType.getNumElements() / castSrcType.getNumElements();

    auto getFirstIntValue = [](ArrayRef<OpFoldResult> values) -> uint64_t {
      assert(values[0].is<Attribute>() && "Unexpected non-constant index");
      return cast<IntegerAttr>(values[0].get<Attribute>()).getInt();
    };

    uint64_t index = getFirstIntValue(extractOp.getMixedPosition());

    // Get the single scalar (as a vector) in the source value that packs the
    // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>.
    Location loc = extractOp.getLoc();
    Value packedValue = rewriter.create<vector::ExtractOp>(
        loc, castOp.getSource(), index / expandRatio);
    Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, packedVecType, rewriter.getZeroAttr(packedVecType));
    packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero,
                                                    /*position=*/0);

    // Cast it to a vector with the desired scalar's type.
    // E.g. f32 -> vector<2xf16>
    VectorType packedType =
        VectorType::get({expandRatio}, castDstType.getElementType());
    Value castedValue =
        rewriter.create<vector::BitCastOp>(loc, packedType, packedValue);

    // Finally extract the desired scalar.
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
                                                   index % expandRatio);
    return success();
  }
};

// Shuffles vector.bitcast op after vector.extract_strided_slice op.
//
// This transforms IR like:
//   %cast = vector.bitcast %arg0 : vector<4xf32> to vector<8xf16>
//   %0 = vector.extract_strided_slice %cast {
//          offsets = [4], sizes = [4], strides = [1]
//        } : vector<8xf16> to vector<4xf16>
// Into:
//   %0 = vector.extract_strided_slice %arg0 {
//          offsets = [2], sizes = [2], strides = [1]
//        } : vector<4xf32> to vector<2xf32>
//   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
struct BubbleDownBitCastForStridedSliceExtract
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to more elements for now; other cases to be implemented.
    if (castSrcLastDim > castDstLastDim)
      return failure();

    // Only accept all one strides for now.
    if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = extractOp.getSourceVectorType().getRank();
    assert(castDstLastDim % castSrcLastDim == 0);
    int64_t expandRatio = castDstLastDim / castSrcLastDim;

    // If we have fewer offsets than the rank, then implicitly we are selecting
    // the full range for the last bitcasted dimension; other dimensions aren't
    // affected. Otherwise, we need to scale down the last dimension's offset
    // given we are extracting from fewer elements now.
    ArrayAttr newOffsets = extractOp.getOffsets();
    if (newOffsets.size() == rank) {
      SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
      if (offsets.back() % expandRatio != 0)
        return failure();
      offsets.back() = offsets.back() / expandRatio;
      newOffsets = rewriter.getI64ArrayAttr(offsets);
    }

    // Similarly for sizes.
    ArrayAttr newSizes = extractOp.getSizes();
    if (newSizes.size() == rank) {
      SmallVector<int64_t> sizes = getIntValueVector(newSizes);
      if (sizes.back() % expandRatio != 0)
        return failure();
      sizes.back() = sizes.back() / expandRatio;
      newSizes = rewriter.getI64ArrayAttr(sizes);
    }

    SmallVector<int64_t> dims =
        llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
    dims.back() = dims.back() / expandRatio;
    VectorType newExtractType =
        VectorType::get(dims, castSrcType.getElementType());

    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
        extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
        newSizes, extractOp.getStrides());

    rewriter.replaceOpWithNewOp<vector::BitCastOp>(
        extractOp, extractOp.getType(), newExtractOp);

    return success();
  }
};

// Shuffles vector.bitcast op before vector.insert op.
//
// This transforms IR like:
//   %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4>
//   %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
// Into:
//   %0 = vector.bitcast %val : vector<32xi4> to vector<16xi8>
//   %1 = vector.bitcast %dst : vector<8x32xi4> to vector<8x16xi8>
//   %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8>
//
struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();

    // 0-D and scalable vectors are not supported yet.
    if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
        castDstType.isScalable())
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
    int64_t ratio;
    if (isNumElemsShrink) {
      assert(castSrcLastDim % castDstLastDim == 0);
      ratio = castSrcLastDim / castDstLastDim;
    } else {
      assert(castDstLastDim % castSrcLastDim == 0);
      ratio = castDstLastDim / castSrcLastDim;
    }

    auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
    if (!insertOp)
      return failure();

    // Only vector sources are supported for now.
    auto insertSrcType = dyn_cast<VectorType>(insertOp.getSourceType());
    if (!insertSrcType)
      return failure();

    // Bitcast the source.
    SmallVector<int64_t> srcDims(insertSrcType.getShape());
    srcDims.back() =
        isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());
    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());

    SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
    dstDims.back() =
        isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());

    // Bitcast the destination.
    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    // Generate new insert.
    rewriter.replaceOpWithNewOp<vector::InsertOp>(
        bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
    return success();
  }
};

// Shuffles vector.bitcast op before vector.insert_strided_slice op.
//
// This transforms IR like:
//   %0 = vector.insert_strided_slice %src, %dst {
//          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
//   %1 = vector.bitcast %0 : vector<8xf16> to vector<4xf32>
// Into:
//   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
//   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
//   %2 = vector.insert_strided_slice %0, %1 {
//          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
struct BubbleUpBitCastForStridedSliceInsert
    : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    // Skip 0-D vectors, which cannot come from an InsertStridedSliceOp.
    if (castSrcType.getRank() == 0)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to less elements for now; other cases to be implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;

    auto insertOp =
        bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
    if (!insertOp)
      return failure();

    // Only accept all one strides for now.
    if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = insertOp.getSourceVectorType().getRank();
    // Require insert op to have the same rank for the source and destination
    // vector; other cases to be implemented.
    if (rank != insertOp.getDestVectorType().getRank())
      return failure();

    // Requires that shape of insert op src is castable to dstType.
    unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
    unsigned destinationWidth =
        castDstType.getElementType().getIntOrFloatBitWidth();
    unsigned numElements = destinationWidth / sourceWidth;
    if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
      return failure();

    ArrayAttr newOffsets = insertOp.getOffsets();
    assert(newOffsets.size() == rank);
    SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
    if (offsets.back() % shrinkRatio != 0)
      return failure();
    offsets.back() = offsets.back() / shrinkRatio;
    newOffsets = rewriter.getI64ArrayAttr(offsets);

    SmallVector<int64_t> srcDims =
        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
    srcDims.back() = srcDims.back() / shrinkRatio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());

    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.getSource());

    SmallVector<int64_t> dstDims =
        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
    dstDims.back() = dstDims.back() / shrinkRatio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());

    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
        insertOp.getStrides());

    return success();
  }
};

// Breaks down vector.bitcast op.
//
// This transforms IR like:
//   %1 = vector.bitcast %0 : vector<8xf16> to vector<4xf32>
// Into:
//   %cst = vector.splat %c0_f32 : vector<4xf32>
//   %1 = vector.extract_strided_slice %0 {
//          offsets = [0], sizes = [4], strides = [1]
//        } : vector<8xf16> to vector<4xf16>
//   %2 = vector.bitcast %1 : vector<4xf16> to vector<2xf32>
//   %4 = vector.insert_strided_slice %2, %cst {
//          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
//   %5 = vector.extract_strided_slice %0 {
//          offsets = [4], sizes = [4], strides = [1]
//        } : vector<8xf16> to vector<4xf16>
//   %6 = vector.bitcast %5 : vector<4xf16> to vector<2xf32>
//   %7 = vector.insert_strided_slice %6, %cst {
//          offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

public:
  BreakDownVectorBitCast(MLIRContext *context,
                         std::function<bool(vector::BitCastOp)> controlFn,
                         PatternBenefit benefit)
      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {

    if (controlFn && !controlFn(bitcastOp))
      return failure();

    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // Only support rank 1 case for now.
    if (castSrcType.getRank() != 1)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to less elements for now; other cases to be implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
    // Nothing to do if it is already bitcasting to a single element.
    if (castSrcLastDim == shrinkRatio)
      return failure();

    Location loc = bitcastOp.getLoc();
    Type elemType = castDstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, castDstType, zero);

    SmallVector<int64_t> sliceShape{castDstLastDim};
    SmallVector<int64_t> strides{1};
    VectorType newCastDstType =
        VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
                        castDstType.getElementType());

    for (int i = 0, e = shrinkRatio; i < e; ++i) {
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, bitcastOp.getSource(), ArrayRef<int64_t>{i * castDstLastDim},
          sliceShape, strides);
      Value bitcast =
          rewriter.create<BitCastOp>(loc, newCastDstType, extracted);
      res = rewriter.create<InsertStridedSliceOp>(
          loc, bitcast, res,
          ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
    }
    rewriter.replaceOp(bitcastOp, res);
    return success();
  }

private:
  std::function<bool(BitCastOp)> controlFn;
};

/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
/// ```
/// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
/// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
/// %r = arith.addi %a, %b : vector<1x4xindex>
/// ```
/// Gets converted to:
/// ```
/// %r = arith.addi %arg1, %arg2 : index
/// %b = vector.broadcast %r : index to vector<1x4xindex>
/// ```
///
/// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
/// ops.
struct ReorderElementwiseOpsOnBroadcast final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1)
      return failure();
    if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
      return failure();
    if (!OpTrait::hasElementwiseMappableTraits(op))
      return failure();
    if (op->getNumOperands() == 0 ||
        op->getResults()[0].getType() != op->getOperand(0).getType()) {
      return failure();
    }
    // Avoid operations that only accept vector types, since broadcast
    // source might be scalar types.
    if (isa<vector::FMAOp>(op)) {
      return failure();
    }

    // Get the type of the lhs operand.
    auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
    if (!lhsBcastOrSplat ||
        !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
      return failure();
    auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();

    // Make sure that all operands are broadcast from identical types:
    //  * scalar (`vector.broadcast` + `vector.splat`), or
    //  * vector (`vector.broadcast`).
    // Otherwise the re-ordering wouldn't be safe.
    if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
          auto bcast = val.getDefiningOp<vector::BroadcastOp>();
          if (bcast)
            return (bcast.getOperand().getType() == lhsBcastOrSplatType);
          auto splat = val.getDefiningOp<vector::SplatOp>();
          if (splat)
            return (splat.getOperand().getType() == lhsBcastOrSplatType);
          return false;
        })) {
      return failure();
    }

    // Collect the source values before broadcasting.
    SmallVector<Value> srcValues;
    srcValues.reserve(op->getNumOperands());
    for (Value operand : op->getOperands()) {
      srcValues.push_back(operand.getDefiningOp()->getOperand(0));
    }

    // Create the "elementwise" Op.
    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        lhsBcastOrSplatType, op->getAttrs());

    // Replace the original Op with the elementwise Op.
    auto vectorType = op->getResultTypes()[0];
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, vectorType, elementwiseOp->getResults());

    return success();
  }
};

// Helper that returns a vector comparison that constructs a mask:
//     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
// If `dim == 0` then the result will be a 0-D vector.
//
// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
//       much more compact, IR for this operation, but LLVM eventually
//       generates more elaborate instructions for this intrinsic since it
//       is very conservative on the boundary conditions.
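//
// For instance (illustration only), with n = 4, offset o = 2, and bound b = 5,
// the mask is [0,1,2,3] + [2,2,2,2] < [5,5,5,5] = [1,1,1,0], i.e. the first
// three lanes are active.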
static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
                                   bool force32BitVectorIndices, int64_t dim,
                                   Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // If we can assume all indices fit in 32-bit, we perform the vector
  // comparison in 32-bit to get a higher degree of SIMD parallelism.
  // Otherwise we perform the vector comparison using 64-bit indices.
  Type idxType =
      force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
  DenseIntElementsAttr indicesAttr;
  if (dim == 0 && force32BitVectorIndices) {
    indicesAttr = DenseIntElementsAttr::get(
        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
  } else if (dim == 0) {
    indicesAttr = DenseIntElementsAttr::get(
        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
  } else if (force32BitVectorIndices) {
    indicesAttr = rewriter.getI32VectorAttr(
        llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
  } else {
    indicesAttr = rewriter.getI64VectorAttr(
        llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
  }
  Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
  // Add in an offset if requested.
  if (off) {
    Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
    Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
    indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
  }
  // Construct the vector comparison.
  Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
  Value bounds =
      rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
                                        bounds);
}

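/// Materializes an explicit mask for a transfer op that has an out-of-bounds
/// dimension, so the resulting transfer can be marked in-bounds. A minimal
/// 1-D sketch of the intent (illustrative only; the value names and the exact
/// IR depend on the input and on `createOrFoldDimOp`):
/// ```
///   %v = vector.transfer_read %src[%i], %pad
///     : memref<?xf32>, vector<16xf32>
/// ```
/// becomes, roughly:
/// ```
///   %d = memref.dim %src, %c0 : memref<?xf32>
///   %b = arith.subi %d, %i : index
///   %mask = vector.create_mask %b : vector<16xi1>
///   %v = vector.transfer_read %src[%i], %pad, %mask {in_bounds = [true]}
///     : memref<?xf32>, vector<16xf32>
/// ```
/// If the transfer already carries a mask, the new mask is AND-ed with it.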
template <typename ConcreteOp>
struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
public:
  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
                                   PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(ConcreteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp.hasOutOfBoundsDim())
      return failure();

    if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
      return failure();

    Location loc = xferOp->getLoc();
    VectorType vtp = xferOp.getVectorType();

    // Create the in-bounds mask with all elements between [0 .. dim - offset)
    // set and [dim - offset .. vector_length) unset.
    //
    // TODO: when the leaf transfer rank is k > 1, we need the last `k`
    //       dimensions here.
    unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
    Value off = xferOp.getIndices()[lastIndex];
    Value dim =
        vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
    Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
    Value mask = rewriter.create<vector::CreateMaskOp>(
        loc,
        VectorType::get(vtp.getShape(), rewriter.getI1Type(),
                        vtp.getScalableDims()),
        b);
    if (xferOp.getMask()) {
      // Intersect the in-bounds with the mask specified as an op parameter.
      mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
    }

    rewriter.modifyOpInPlace(xferOp, [&]() {
      xferOp.getMaskMutable().assign(mask);
      xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    });

    return success();
  }

private:
  const bool force32BitVectorIndices;
};

/// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
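/// A sketch of the rewrite (illustrative only; assumes 32-bit indices are
/// enabled and a fixed vector length of 4):
/// ```
///   %m = vector.create_mask %b : vector<4xi1>
/// ```
/// becomes, roughly:
/// ```
///   %cst = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
///   %b32 = arith.index_cast %b : index to i32
///   %bs  = vector.splat %b32 : vector<4xi32>
///   %m   = arith.cmpi slt, %cst, %bs : vector<4xi32>
/// ```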
class VectorCreateMaskOpConversion
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        bool enableIndexOpt,
                                        PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    if (cast<VectorType>(dstType).isScalable())
      return failure();
    int64_t rank = dstType.getRank();
    if (rank > 1)
      return failure();
    rewriter.replaceOp(
        op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
                                  rank == 0 ? 0 : dstType.getDimSize(0),
                                  op.getOperand(0)));
    return success();
  }

private:
  const bool force32BitVectorIndices;
};

/// Returns true if all the `i1` elements of `constantOp` are set to `value`.
static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
  auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
  // TODO: Support non-dense constant.
  if (!denseAttr)
    return false;

  assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
  return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
}

/// Folds a select operation between an all-true and all-false vector. For now,
/// only single element vectors (i.e., vector<1xi1>) are supported. That is:
///
///   %true = arith.constant dense<true> : vector<1xi1>
///   %false = arith.constant dense<false> : vector<1xi1>
///   %result = arith.select %cond, %true, %false : i1, vector<1xi1>
///   =>
///   %result = vector.broadcast %cond : i1 to vector<1xi1>
///
/// InstCombine seems to handle vectors with multiple elements but not the
/// single element ones.
struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
  using OpRewritePattern<arith::SelectOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = dyn_cast<VectorType>(selectOp.getType());
    if (!vecType || !vecType.getElementType().isInteger(1))
      return failure();

    // Only scalar conditions can be folded.
    Value cond = selectOp.getCondition();
    if (isa<VectorType>(cond.getType()))
      return failure();

    // TODO: Support n-D and scalable vectors.
    if (vecType.getRank() != 1 || vecType.isScalable())
      return failure();

    // TODO: Support vectors with multiple elements.
    if (vecType.getShape()[0] != 1)
      return failure();

    auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
    if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
      return failure();

    auto falseConst =
        selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
    if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
      return failure();

    // Replace select with its condition broadcasted to single element vector.
    auto elemType = rewriter.getIntegerType(vecType.getNumElements());
    auto bcastType = VectorType::get(/*shape=*/{1}, elemType);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(selectOp, bcastType, cond);
    return success();
  }
};

/// Returns the number of dims that can be folded away from transfer ops. It
/// returns a failure if it cannot determine the number of dims to be folded.
/// Example 1: it returns "2" if `srcType` is memref<512x16x1x1xf32> and
/// `vectorType` is vector<16x16x1x1xf32>, because the two innermost dims
/// can be dropped by memref.subview ops.
/// Example 2: it returns "1" if `srcType` is the same memref type with
/// [8192, 16, 8, 1] strides.
static FailureOr<size_t>
getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
  SmallVector<int64_t> srcStrides;
  int64_t srcOffset;
  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
    return failure();

  // According to vector.transfer_read/write semantics, the vector can be a
  // slice. Thus, we have to offset the check index with `rankDiff` in
  // `srcStrides` and source dim sizes.
  size_t result = 0;
  int rankDiff = srcType.getRank() - vectorType.getRank();
  for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
    // Check that the inner dim size is 1 for both the memref type and the
    // vector slice. It can be folded only if they are 1 and the stride is 1.
    int dim = vectorType.getRank() - i - 1;
    if (srcStrides[dim + rankDiff] != 1 ||
        srcType.getDimSize(dim + rankDiff) != 1 ||
        vectorType.getDimSize(dim) != 1)
      break;
    result++;
  }
  return result;
}

/// Drop innermost contiguous unit dimensions from transfer_read operand.
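/// E.g. (a sketch that mirrors the transfer_write example below; %pad stands
/// for the padding operand, and the exact subview and indices depend on the
/// input IR):
///
///    %0 = vector.transfer_read %arg0[%c0, %arg2, %c0, %c0, %c0], %pad
///      {in_bounds = [true, true, true, true, true]}
///      : memref<1x512x16x1x1xf32>, vector<1x16x16x1x1xf32>
///
/// is rewritten to:
///
///    %subview = memref.subview %arg0
///      [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
///      : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
///    %1 = vector.transfer_read %subview[%c0, %arg2, %c0], %pad
///      {in_bounds = [true, true, true]}
///      : memref<1x512x16xf32>, vector<1x16x16xf32>
///    %0 = vector.shape_cast %1 : vector<1x16x16xf32>
///      to vector<1x16x16x1x1xf32>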
class DropInnerMostUnitDimsTransferRead
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();

    // TODO: support mask.
    if (readOp.getMask())
      return failure();

    auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
    if (!srcType)
      return failure();

    if (!readOp.getPermutationMap().isMinorIdentity())
      return failure();

    auto targetType = readOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
      return failure();

    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType());

    auto loc = readOp.getLoc();
    SmallVector<OpFoldResult> sizes =
        memref::getMixedSizes(rewriter, loc, readOp.getSource());
    SmallVector<OpFoldResult> offsets(srcType.getRank(),
                                      rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(srcType.getRank(),
                                      rewriter.getIndexAttr(1));
    auto resultMemrefType =
        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
            srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
            strides));
    ArrayAttr inBoundsAttr =
        readOp.getInBounds()
            ? rewriter.getArrayAttr(
                  readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
            : ArrayAttr();
    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
        loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
    auto permMap = getTransferMinorIdentityMap(
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
    Value result = rewriter.create<vector::TransferReadOp>(
        loc, resultTargetVecType, rankedReducedView,
        readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        readOp.getPadding(),
        // TODO: support mask.
        /*mask=*/Value(), inBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
                                                     result);
    return success();
  }
};

/// Drop innermost contiguous unit dimensions from transfer_write operand.
/// E.g.,
///    vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
///      {in_bounds = [true, true, true, true, true]}
///      : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
///
/// will be replaced with
///
///    %subview = memref.subview %arg0
///      [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
///      : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
///    %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
///      to vector<1x16x16xf32>
///    vector.transfer_write %0, %subview[%c0, %arg2, %c0]
///      {in_bounds = [true, true, true]}
///      : vector<1x16x16xf32>, memref<1x512x16xf32>
class DropInnerMostUnitDimsTransferWrite
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    // TODO: support mask.
    if (writeOp.getMask())
      return failure();

    auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
    if (!srcType)
      return failure();

    if (!writeOp.getPermutationMap().isMinorIdentity())
      return failure();

    auto targetType = writeOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    FailureOr<size_t> maybeDimsToDrop =
        getTransferFoldableInnerUnitDims(srcType, targetType);
    if (failed(maybeDimsToDrop))
      return failure();

    size_t dimsToDrop = maybeDimsToDrop.value();
    if (dimsToDrop == 0)
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType());

    Location loc = writeOp.getLoc();
    SmallVector<OpFoldResult> sizes =
        memref::getMixedSizes(rewriter, loc, writeOp.getSource());
    SmallVector<OpFoldResult> offsets(srcType.getRank(),
                                      rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(srcType.getRank(),
                                      rewriter.getIndexAttr(1));
    auto resultMemrefType =
        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
            srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
            strides));
    ArrayAttr inBoundsAttr =
        writeOp.getInBounds()
            ? rewriter.getArrayAttr(
                  writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
            : ArrayAttr();

    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
        loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides);
    auto permMap = getTransferMinorIdentityMap(
        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);

    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, resultTargetVecType, writeOp.getVector());
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        writeOp, shapeCast, rankedReducedView,
        writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        // TODO: support mask.
        /*mask=*/Value(), inBoundsAttr);
    return success();
  }
};

/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
/// semantics to a contraction suitable for MMT (matrix matrix multiplication
/// with the RHS transposed) lowering.
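/// For example (an illustrative sketch; the %t name is introduced here), a
/// plain row-major matmul with maps (m, k), (k, n) -> (m, n):
/// ```
///   %0 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d0, d2)>,
///          affine_map<(d0, d1, d2) -> (d2, d1)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///     iterator_types = ["parallel", "parallel", "reduction"],
///     kind = #vector.kind<add>} %lhs, %rhs, %acc
///     : vector<2x4xf32>, vector<4x8xf32> into vector<2x8xf32>
/// ```
/// is rewritten so that the RHS is transposed and the contract uses the MMT
/// form (m, k), (n, k) -> (m, n):
/// ```
///   %t = vector.transpose %rhs, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
///   %0 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d0, d2)>,
///          affine_map<(d0, d1, d2) -> (d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///     iterator_types = ["parallel", "parallel", "reduction"],
///     kind = #vector.kind<add>} %lhs, %t, %acc
///     : vector<2x4xf32>, vector<8x4xf32> into vector<2x8xf32>
/// ```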
struct CanonicalizeContractMatmulToMMT final
    : OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
                                  FilterConstraintType constraint)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        filter(std::move(constraint)) {}

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(filter(op)))
      return failure();

    Location loc = op.getLoc();
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
    Value res = op.getAcc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [&](MapList m) {
      return AffineMap::inferFromExprList(m, op.getContext());
    };
    AffineExpr m;
    AffineExpr n;
    AffineExpr k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    if (iteratorTypes.size() != 3 ||
        !vector::isParallelIterator(iteratorTypes[0]) ||
        !vector::isParallelIterator(iteratorTypes[1]) ||
        !vector::isReductionIterator(iteratorTypes[2]))
      return rewriter.notifyMatchFailure(op, "contraction is not a gemm");

    // The canonical form is "TNT" = A row-major, B col-major, C row-major.
    const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
    if (maps == canonicalForm)
      return rewriter.notifyMatchFailure(op, "already in the canonical form");

    // Create a vector transpose, making sure to re-emit the zero/sign-extend
    // at the end.
    auto createTranspose = [&rewriter, loc](Value mat) -> Value {
      if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
        Value trans =
            rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm);
        VectorType newType =
            cast<VectorType>(trans.getType())
                .clone(cast<VectorType>(mat.getType()).getElementType());
        return rewriter.create<arith::ExtSIOp>(loc, newType, trans);
      }
      if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
        Value trans =
            rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm);
        VectorType newType =
            VectorType::get(cast<VectorType>(trans.getType()).getShape(),
                            cast<VectorType>(mat.getType()).getElementType());
        return rewriter.create<arith::ExtUIOp>(loc, newType, trans);
      }
      return rewriter.create<vector::TransposeOp>(loc, mat, perm);
    };

    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = createTranspose(rhs);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = createTranspose(lhs);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return rewriter.notifyMatchFailure(op, "unhandled contraction form");
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
        op.getIteratorTypes());
    return success();
  }

private:
  FilterConstraintType filter;
};

/// Pattern to fold arithmetic extensions on floating point data types into
/// vector contraction operations. linalg.matmul introduces arithmetic
/// extensions on its operands. See the MLIR snippet below for more details:
/// ```mlir
///   "linalg.matmul"(%lhs, %rhs, %acc) ({
///     ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
///       %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
///       %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
///       %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
///       %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
///       "linalg.yield"(%acc) : (f32) -> ()
///   })
/// ```
/// Keeping the extensions around blocks the native use of mixed precision
/// NVIDIA Ampere Tensor Cores, i.e., `mma.sync.*.f32.f16.f16.f32` and
/// `mma.sync.*.f32.bf16.bf16.f32`. This pattern folds the arithmetic
/// extensions into the vector contraction and enables the use of native mixed
/// precision Tensor Core instructions.
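///
/// As an illustrative sketch (shapes and attribute details are arbitrary, not
/// taken from this file), the vectorized form
/// ```mlir
///   %lhs_f32 = arith.extf %lhs : vector<4x8xf16> to vector<4x8xf32>
///   %rhs_f32 = arith.extf %rhs : vector<4x8xf16> to vector<4x8xf32>
///   %res = vector.contract {...} %lhs_f32, %rhs_f32, %acc
///            : vector<4x8xf32>, vector<4x8xf32> into vector<4x4xf32>
/// ```
/// folds to a mixed-precision contraction that consumes the narrow operands
/// directly (the indexing maps and iterator types are left unchanged):
/// ```mlir
///   %res = vector.contract {...} %lhs, %rhs, %acc
///            : vector<4x8xf16>, vector<4x8xf16> into vector<4x4xf32>
/// ```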
struct FoldArithExtIntoContractionOp
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {

    auto lhsDefOp = contractOp.getLhs().getDefiningOp<arith::ExtFOp>();
    auto rhsDefOp = contractOp.getRhs().getDefiningOp<arith::ExtFOp>();

    if (!lhsDefOp || !rhsDefOp) {
      return rewriter.notifyMatchFailure(contractOp,
                                         "no defining op on contract operands");
    }

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
        contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
        contractOp.getIteratorTypesAttr());

    return success();
  }
};

/// Pattern to fold a chained reduction into a series of vector additions and a
/// single final reduction. This form should require fewer subgroup operations.
///
/// ```mlir
/// %a = vector.reduction <add> %x, %acc
/// %b = vector.reduction <add> %y, %a
/// ==>
/// %a = arith.addf %x, %y
/// %b = vector.reduction <add> %a, %acc
/// ```
struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Handle other combining kinds.
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();

    // Accumulator is optional.
    Value acc = op.getAcc();
    if (!acc)
      return failure();

    if (!acc.getType().isIntOrFloat())
      return failure();

    auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
    if (!parentReduction)
      return failure();

    Location loc = op.getLoc();
    Value vAdd;
    if (isa<IntegerType>(acc.getType())) {
      vAdd = rewriter.createOrFold<arith::AddIOp>(
          loc, parentReduction.getVector(), op.getVector());
    } else {
      vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(),
                                            op.getVector());
    }
    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
                                                     parentReduction.getAcc());
    return success();
  }
};

/// For vectors with either a leading or a trailing unit dim, replaces:
///   elementwise(a, b)
/// with:
///   sc_a = shape_cast(a)
///   sc_b = shape_cast(b)
///   res = elementwise(sc_a, sc_b)
///   return shape_cast(res)
/// The newly inserted shape_cast ops drop the unit dim (before the elementwise
/// op) and then restore it (after the elementwise op). Vectors `a` and `b` are
/// required to have rank > 1.
///
/// Ex:
/// ```
/// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
/// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
/// ```
///
/// gets converted to:
///
/// ```
/// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
/// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
/// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
/// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
/// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
/// ```
///
/// Patterns for folding shape_casts should then immediately eliminate
/// `%cast_new` and `%cast`.
struct DropUnitDimFromElementwiseOps final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1 || op->getNumRegions() != 0)
      return failure();

    auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
    if (!resultVectorType)
      return failure();

    // Check the operand pre-conditions. For `Elementwise` ops all operands are
    // guaranteed to have identical shapes (with some exceptions such as
    // `arith.select`), so it suffices to check only one of them.
    auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
    if (!sourceVectorType)
      return failure();
    if (sourceVectorType.getRank() < 2)
      return failure();

    bool hasTrailingDimUnitFixed =
        ((sourceVectorType.getShape().back() == 1) &&
         (!sourceVectorType.getScalableDims().back()));
    bool hasLeadingDimUnitFixed =
        ((sourceVectorType.getShape().front() == 1) &&
         (!sourceVectorType.getScalableDims().front()));
    if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
      return failure();

    // Drop the leading/trailing unit dim by applying vector.shape_cast to all
    // operands.
    int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;

    SmallVector<Value> newOperands;
    auto loc = op->getLoc();
    for (auto operand : op->getOperands()) {
      auto opVectorType = cast<VectorType>(operand.getType());
      VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
      auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
      newOperands.push_back(opSC);
    }

    VectorType newResultVectorType =
        VectorType::Builder(resultVectorType).dropDim(dim);
    // Create an updated elementwise op without the leading/trailing unit dim.
    Operation *elementwiseOp =
        rewriter.create(loc, op->getName().getIdentifier(), newOperands,
                        newResultVectorType, op->getAttrs());

    // Restore the leading/trailing unit dim by applying vector.shape_cast to
    // the result.
    rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
                                             elementwiseOp->getResult(0));

    return success();
  }
};

/// Pattern to eliminate redundant zero-constants added to reduction operands.
/// It's enough for there to be one initial zero value, so we can eliminate the
/// extra ones that feed into `vector.reduction <add>`. These get created by the
/// `ChainedReduction` pattern.
///
/// ```mlir
/// %a = arith.addf %x, %zero
/// %b = arith.addf %a, %y
/// %c = vector.reduction <add> %b, %acc
/// ==>
/// %b = arith.addf %a, %y
/// %c = vector.reduction <add> %b, %acc
/// ```
struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Handle other reduction kinds and their identity values.
    if (op.getKind() != vector::CombiningKind::ADD)
      return failure();

    Type elemType = op.getSourceVectorType().getElementType();
    // The integer case should be handled by `arith.addi` folders, only check
    // for floats here.
    if (!isa<FloatType>(elemType))
      return failure();

    auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
    if (!vAdd)
      return failure();
    auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
    if (!addLhs)
      return failure();

    if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat()))
      return failure();

    auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(), addLhs.getLhs(),
                                                 vAdd.getRhs());
    rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd,
                                                     op.getAcc());
    return success();
  }
};

/// Example:
/// ```
///   %a = vector.reduction <add> %x : vector<2xf32> into f32
/// ```
/// is transformed into:
/// ```
///   %y = vector.extract %x[0] : f32 from vector<2xf32>
///   %z = vector.extract %x[1] : f32 from vector<2xf32>
///   %a = arith.addf %y, %z : f32
/// ```
struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
  BreakDownVectorReduction(MLIRContext *context,
                           unsigned maxNumElementsToExtract,
                           PatternBenefit benefit)
      : OpRewritePattern(context, benefit),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    VectorType type = op.getSourceVectorType();
    if (type.isScalable() || op.isMasked())
      return failure();
    assert(type.getRank() == 1 && "Expected a 1-d vector");

    int64_t numElems = type.getNumElements();
    if (numElems > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("has too many vector elements ({0}) to break down "
                            "(max allowed: {1})",
                            numElems, maxNumElementsToExtract));
    }

    Location loc = op.getLoc();
    SmallVector<Value> extracted(numElems, nullptr);
    for (auto [idx, extractedElem] : llvm::enumerate(extracted))
      extractedElem = rewriter.create<vector::ExtractOp>(
          loc, op.getVector(), static_cast<int64_t>(idx));

    Value res = extracted.front();
    for (auto extractedElem : llvm::drop_begin(extracted))
      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
                                       extractedElem, op.getFastmathAttr());
    if (Value acc = op.getAcc())
      res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
                                       op.getFastmathAttr());

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxNumElementsToExtract = 0;
};

} // namespace

void mlir::vector::populateFoldArithExtensionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FoldArithExtIntoContractionOp>(patterns.getContext());
}
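
// Note: a minimal usage sketch, not part of this file's API. It assumes a
// hypothetical pass (`MyVectorCleanupPass`) and shows how the populate*
// entry points above and below are typically combined with the greedy
// pattern rewrite driver (mlir/Transforms/GreedyPatternRewriteDriver.h):
//
//   void MyVectorCleanupPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     vector::populateFoldArithExtensionPatterns(patterns);
//     vector::populateChainedVectorReductionFoldingPatterns(patterns);
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }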

void mlir::vector::populateVectorMaskMaterializationPatterns(
    RewritePatternSet &patterns, bool force32BitVectorIndices,
    PatternBenefit benefit) {
  patterns.add<VectorCreateMaskOpConversion,
               MaterializeTransferMask<vector::TransferReadOp>,
               MaterializeTransferMask<vector::TransferWriteOp>>(
      patterns.getContext(), force32BitVectorIndices, benefit);
  patterns.add<FoldI1Select>(patterns.getContext(), benefit);
}

void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
                                                    PatternBenefit benefit) {
  patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);
}

void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateBubbleVectorBitCastOpPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<BubbleDownVectorBitCastForExtract,
               BubbleDownBitCastForStridedSliceExtract,
               BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
    RewritePatternSet &patterns,
    std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
  patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
                                       std::move(controlFn), benefit);
}

void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
    RewritePatternSet &patterns,
    std::function<LogicalResult(vector::ContractionOp)> constraint,
    PatternBenefit benefit) {
  patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
                                                std::move(constraint));
}

void mlir::vector::populateVectorReductionToContractPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<MultiReduceToContract, CombineContractBroadcast,
               CombineContractABTranspose, CombineContractResultTranspose,
               ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose>(
      patterns.getContext(), benefit);
}

void mlir::vector::
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
        RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<DropInnerMostUnitDimsTransferRead,
               DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
                                                   benefit);
}

void mlir::vector::populateSinkVectorBroadcastPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnBroadcast>(
      patterns.getContext(), benefit);
}

void mlir::vector::populateChainedVectorReductionFoldingPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ChainedReduction>(patterns.getContext(), benefit);
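  // ReduceRedundantZero gets a slightly higher benefit so that, when both
  // patterns match a vector.reduction, it is tried first and can clean up the
  // redundant zero-adds introduced by ChainedReduction (see the pattern
  // documentation above).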
  patterns.add<ReduceRedundantZero>(patterns.getContext(),
                                    PatternBenefit(benefit.getBenefit() + 1));
}

void mlir::vector::populateBreakDownVectorReductionPatterns(
    RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
    PatternBenefit benefit) {
  patterns.add<BreakDownVectorReduction>(patterns.getContext(),
                                         maxNumElementsToExtract, benefit);
}

//===----------------------------------------------------------------------===//
// TableGen'd enum attribute definitions
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"
