//===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <numeric>

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "vector-drop-unit-dim"

using namespace mlir;
using namespace mlir::vector;

// Trims leading one dimensions from `oldType` and returns the result type.
// Returns `vector<1xT>` if `oldType` only has one element.
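// For example (illustrative shapes, not from a specific test):
//   vector<1x1x8x[4]xf32> -> vector<8x[4]xf32>
//   vector<1x1xf32>       -> vector<1xf32>
// Scalable unit dimensions (`[1]`) are never dropped.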
static VectorType trimLeadingOneDims(VectorType oldType) {
  ArrayRef<int64_t> oldShape = oldType.getShape();
  ArrayRef<int64_t> newShape = oldShape;

  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
  ArrayRef<bool> newScalableDims = oldScalableDims;

  while (!newShape.empty() && newShape.front() == 1 &&
         !newScalableDims.front()) {
    newShape = newShape.drop_front(1);
    newScalableDims = newScalableDims.drop_front(1);
  }

  // Make sure we have at least 1 dimension per vector type requirements.
  if (newShape.empty()) {
    newShape = oldShape.take_back();
    newScalableDims = oldType.getScalableDims().take_back();
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}

/// Return a SmallVector of size `rank` containing all zeros.
static SmallVector<int64_t> splatZero(int64_t rank) {
  return SmallVector<int64_t>(rank, 0);
}

namespace {

// Casts away leading one dimensions in vector.extract_strided_slice's vector
// input by inserting vector.broadcast.
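// A sketch of the rewrite on illustrative IR (shapes chosen for exposition):
//   %s = vector.extract_strided_slice %v
//          {offsets = [0, 1], sizes = [1, 4], strides = [1, 1]}
//          : vector<1x8xf32> to vector<1x4xf32>
// becomes:
//   %e = vector.extract %v[0] : vector<8xf32> from vector<1x8xf32>
//   %t = vector.extract_strided_slice %e
//          {offsets = [1], sizes = [4], strides = [1]}
//          : vector<8xf32> to vector<4xf32>
//   %s = vector.broadcast %t : vector<4xf32> to vector<1x4xf32>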
struct CastAwayExtractStridedSliceLeadingOneDim
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    // vector.extract_strided_slice requires the input and output vectors to
    // have the same rank. Here we drop leading one dimensions from the input
    // vector type to make sure we don't cause a mismatch.
    VectorType oldSrcType = extractOp.getSourceVectorType();
    VectorType newSrcType = trimLeadingOneDims(oldSrcType);

    if (newSrcType.getRank() == oldSrcType.getRank())
      return failure();

    int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();

    VectorType oldDstType = extractOp.getType();
    VectorType newDstType =
        VectorType::get(oldDstType.getShape().drop_front(dropCount),
                        oldDstType.getElementType(),
                        oldDstType.getScalableDims().drop_front(dropCount));

    Location loc = extractOp.getLoc();

    Value newSrcVector = rewriter.create<vector::ExtractOp>(
        loc, extractOp.getVector(), splatZero(dropCount));

    // The offsets/sizes/strides attributes can have fewer elements than the
    // input vector's rank; in that case they apply to the leading dimensions.
    auto newOffsets = rewriter.getArrayAttr(
        extractOp.getOffsets().getValue().drop_front(dropCount));
    auto newSizes = rewriter.getArrayAttr(
        extractOp.getSizes().getValue().drop_front(dropCount));
    auto newStrides = rewriter.getArrayAttr(
        extractOp.getStrides().getValue().drop_front(dropCount));

    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
        loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
                                                     newExtractOp);

    return success();
  }
};

// Casts away leading one dimensions in vector.insert_strided_slice's vector
// inputs by inserting vector.broadcast.
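// A sketch on illustrative IR:
//   %r = vector.insert_strided_slice %src, %dst
//          {offsets = [0, 0, 2], strides = [1, 1]}
//          : vector<1x4xf32> into vector<1x1x8xf32>
// becomes:
//   %s = vector.extract %src[0] : vector<4xf32> from vector<1x4xf32>
//   %d = vector.extract %dst[0, 0] : vector<8xf32> from vector<1x1x8xf32>
//   %i = vector.insert_strided_slice %s, %d {offsets = [2], strides = [1]}
//          : vector<4xf32> into vector<8xf32>
//   %r = vector.broadcast %i : vector<8xf32> to vector<1x1x8xf32>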
struct CastAwayInsertStridedSliceLeadingOneDim
    : public OpRewritePattern<vector::InsertStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    VectorType oldSrcType = insertOp.getSourceVectorType();
    VectorType newSrcType = trimLeadingOneDims(oldSrcType);
    VectorType oldDstType = insertOp.getDestVectorType();
    VectorType newDstType = trimLeadingOneDims(oldDstType);

    int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
    int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
    if (srcDropCount == 0 && dstDropCount == 0)
      return failure();

    // Trim leading one dimensions from both operands.
    Location loc = insertOp.getLoc();

    Value newSrcVector = rewriter.create<vector::ExtractOp>(
        loc, insertOp.getSource(), splatZero(srcDropCount));
    Value newDstVector = rewriter.create<vector::ExtractOp>(
        loc, insertOp.getDest(), splatZero(dstDropCount));

    auto newOffsets = rewriter.getArrayAttr(
        insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
    auto newStrides = rewriter.getArrayAttr(
        insertOp.getStrides().getValue().take_back(newSrcType.getRank()));

    auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
        loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
                                                     newInsertOp);

    return success();
  }
};

// Casts away leading one dimensions in vector.insert's vector inputs by
// inserting vector.broadcast.
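// A sketch on illustrative IR:
//   %r = vector.insert %src, %dst[0, 1] : vector<4xf32> into vector<1x2x4xf32>
// becomes:
//   %d = vector.extract %dst[0] : vector<2x4xf32> from vector<1x2x4xf32>
//   %i = vector.insert %src, %d[1] : vector<4xf32> into vector<2x4xf32>
//   %r = vector.broadcast %i : vector<2x4xf32> to vector<1x2x4xf32>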
struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    Type oldSrcType = insertOp.getSourceType();
    Type newSrcType = oldSrcType;
    int64_t oldSrcRank = 0, newSrcRank = 0;
    if (auto type = dyn_cast<VectorType>(oldSrcType)) {
      newSrcType = trimLeadingOneDims(type);
      oldSrcRank = type.getRank();
      newSrcRank = cast<VectorType>(newSrcType).getRank();
    }

    VectorType oldDstType = insertOp.getDestVectorType();
    VectorType newDstType = trimLeadingOneDims(oldDstType);

    int64_t srcDropCount = oldSrcRank - newSrcRank;
    int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
    if (srcDropCount == 0 && dstDropCount == 0)
      return failure();

    // Trim leading one dimensions from both operands.
    Location loc = insertOp.getLoc();

    Value newSrcVector = insertOp.getSource();
    if (oldSrcRank != 0) {
      newSrcVector = rewriter.create<vector::ExtractOp>(
          loc, insertOp.getSource(), splatZero(srcDropCount));
    }
    Value newDstVector = rewriter.create<vector::ExtractOp>(
        loc, insertOp.getDest(), splatZero(dstDropCount));

    // The new position rank is computed in two steps: (1) if the destination
    // type has leading unit dims, we trim the position array accordingly;
    // then (2) if the source type also has leading unit dims, we append
    // zeros to the position array accordingly.
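    // For example (hypothetical shapes): inserting vector<1x4xf32> into
    // vector<1x2x1x4xf32> at position [0, %c] drops one leading unit dim from
    // each operand; the position is first trimmed to [%c] and then padded
    // with a trailing zero to [%c, 0], so that the rank-reduced source
    // vector<4xf32> lands at the right spot in the new vector<2x1x4xf32>
    // destination.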
    unsigned oldPosRank = insertOp.getNumIndices();
    unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
    SmallVector<OpFoldResult> oldPosition = insertOp.getMixedPosition();
    SmallVector<OpFoldResult> newPosition =
        llvm::to_vector(ArrayRef(oldPosition).take_back(newPosRank));
    newPosition.resize(newDstType.getRank() - newSrcRank,
                       rewriter.getI64IntegerAttr(0));

    auto newInsertOp = rewriter.create<vector::InsertOp>(
        loc, newSrcVector, newDstVector, newPosition);

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
                                                     newInsertOp);

    return success();
  }
};

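// Computes the rank-reduced mask for a transfer op whose vector type has had
// its leading unit dims dropped. For example (illustrative): a mask of type
// vector<1x1x4xi1> whose inferred new type is vector<4xi1> is rewritten as
// vector.extract %mask[0, 0], since vector<4xi1> broadcasts to the old mask
// type; when it does not, a vector.shape_cast is used instead.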
static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask,
                                  VectorType newType, AffineMap newMap,
                                  VectorType oldMaskType) {
  // Infer the type of the new mask from the new map.
  VectorType newMaskType = inferTransferOpMaskType(newType, newMap);

  // If the new mask is broadcastable to the old result type, we can safely
  // use a `vector.extract` to get the new mask. Otherwise the best we can
  // do is shape cast.
  if (vector::isBroadcastableTo(newMaskType, oldMaskType) ==
      BroadcastableToResult::Success) {
    int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank();
    return b.create<vector::ExtractOp>(loc, mask, splatZero(dropDim));
  }
  return b.create<vector::ShapeCastOp>(loc, newMaskType, mask);
}

// Turns vector.transfer_read on a vector with leading 1 dimensions into
// vector.transfer_read on a vector without leading 1 dimensions, followed by
// vector.broadcast back to the original vector type.
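// A sketch on illustrative IR:
//   %v = vector.transfer_read %src[%c0, %c0], %pad
//          : memref<8x8xf32>, vector<1x4xf32>
// becomes:
//   %r = vector.transfer_read %src[%c0, %c0], %pad
//          : memref<8x8xf32>, vector<4xf32>
//   %v = vector.broadcast %r : vector<4xf32> to vector<1x4xf32>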
struct CastAwayTransferReadLeadingOneDim
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    // TODO(#78787): Masked ops are not supported yet.
    if (cast<MaskableOpInterface>(read.getOperation()).isMasked())
      return failure();
    // TODO: support 0-d corner case.
    if (read.getTransferRank() == 0)
      return failure();

    auto shapedType = cast<ShapedType>(read.getSource().getType());
    if (shapedType.getElementType() != read.getVectorType().getElementType())
      return failure();

    VectorType oldType = read.getVectorType();
    VectorType newType = trimLeadingOneDims(oldType);

    if (newType == oldType)
      return failure();

    AffineMap oldMap = read.getPermutationMap();
    ArrayRef<AffineExpr> newResults =
        oldMap.getResults().take_back(newType.getRank());
    AffineMap newMap =
        AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
                       rewriter.getContext());

    ArrayAttr inBoundsAttr;
    if (read.getInBounds())
      inBoundsAttr = rewriter.getArrayAttr(
          read.getInBoundsAttr().getValue().take_back(newType.getRank()));

    Value mask = Value();
    if (read.getMask()) {
      VectorType maskType = read.getMaskType();
      mask = dropUnitDimsFromMask(rewriter, read.getLoc(), read.getMask(),
                                  newType, newMap, maskType);
    }

    auto newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), newType, read.getSource(), read.getIndices(),
        AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);

    return success();
  }
};

// Turns vector.transfer_write on a vector with leading 1 dimensions into
// vector.extract on the vector operand followed by vector.transfer_write on
// the vector without leading 1 dimensions.
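// A sketch on illustrative IR:
//   vector.transfer_write %v, %dst[%c0, %c0]
//     : vector<1x4xf32>, memref<8x8xf32>
// becomes:
//   %e = vector.extract %v[0] : vector<4xf32> from vector<1x4xf32>
//   vector.transfer_write %e, %dst[%c0, %c0]
//     : vector<4xf32>, memref<8x8xf32>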
struct CastAwayTransferWriteLeadingOneDim
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    // TODO(#78787): Masked ops are not supported yet.
    if (cast<MaskableOpInterface>(write.getOperation()).isMasked())
      return failure();
    // TODO: support 0-d corner case.
    if (write.getTransferRank() == 0)
      return failure();

    auto shapedType = cast<ShapedType>(write.getSource().getType());
    if (shapedType.getElementType() != write.getVectorType().getElementType())
      return failure();

    VectorType oldType = write.getVectorType();
    VectorType newType = trimLeadingOneDims(oldType);
    if (newType == oldType)
      return failure();
    int64_t dropDim = oldType.getRank() - newType.getRank();

    AffineMap oldMap = write.getPermutationMap();
    ArrayRef<AffineExpr> newResults =
        oldMap.getResults().take_back(newType.getRank());
    AffineMap newMap =
        AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
                       rewriter.getContext());

    ArrayAttr inBoundsAttr;
    if (write.getInBounds())
      inBoundsAttr = rewriter.getArrayAttr(
          write.getInBoundsAttr().getValue().take_back(newType.getRank()));

    auto newVector = rewriter.create<vector::ExtractOp>(
        write.getLoc(), write.getVector(), splatZero(dropDim));

    if (write.getMask()) {
      VectorType maskType = write.getMaskType();
      Value newMask = dropUnitDimsFromMask(
          rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
      rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
          write, newVector, write.getSource(), write.getIndices(),
          AffineMapAttr::get(newMap), newMask, inBoundsAttr);
      return success();
    }

    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        write, newVector, write.getSource(), write.getIndices(),
        AffineMapAttr::get(newMap), inBoundsAttr);
    return success();
  }
};

} // namespace

FailureOr<Value>
mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
                                               MaskingOpInterface maskingOp,
                                               RewriterBase &rewriter) {
  VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
  if (oldAccType == nullptr)
    return failure();
  if (oldAccType.getRank() < 2)
    return failure();
  if (oldAccType.getShape()[0] != 1)
    return failure();
  // Currently we support only dropping one dim, but the pattern can be
  // applied greedily to drop more.
  int64_t dropDim = 1;

  auto oldIndexingMaps = contractOp.getIndexingMapsArray();
  SmallVector<AffineMap> newIndexingMaps;

  auto oldIteratorTypes = contractOp.getIteratorTypes();
  SmallVector<Attribute> newIteratorTypes;

  int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);

  // Only parallel-type iterators can be dropped.
  if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
    return failure();

  for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
    int64_t currDim = it.index();
    if (currDim == dimToDrop)
      continue;
    newIteratorTypes.push_back(it.value());
  }

  SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
                                 contractOp.getAcc()};
  SmallVector<Value> newOperands;
  auto loc = contractOp.getLoc();

  for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
    // Check if the dim to be dropped exists as a leading dim in the operand;
    // if it does, then we use vector.extract to drop it.
    bool validExtract = false;
    SmallVector<AffineExpr> results;
    auto map = it.value();
    int64_t originalZeroDim = it.value().getDimPosition(0);
    if (originalZeroDim != dimToDrop) {
      // There are two reasons to be in this path: 1. We need to transpose the
      // operand to make the dim to be dropped leading. 2. The dim to be
      // dropped does not exist, in which case we don't want to add a unit
      // transpose, but we must check all the indices to make sure this is
      // the case.
      bool transposeNeeded = false;
      SmallVector<int64_t> perm;
      SmallVector<AffineExpr> transposeResults;

      for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
        int64_t currDim = map.getDimPosition(i);
        if (currDim == dimToDrop) {
          transposeNeeded = true;
          perm.insert(perm.begin(), i);
          auto targetExpr = rewriter.getAffineDimExpr(currDim);
          transposeResults.insert(transposeResults.begin(), targetExpr);
        } else {
          perm.push_back(i);
          auto targetExpr = rewriter.getAffineDimExpr(currDim);
          transposeResults.push_back(targetExpr);
        }
      }

      // Checks if only the outer, unit dimensions (of size 1) are permuted.
      // Such transposes do not materially affect the underlying vector and
      // can be omitted. E.g. perm [1, 0, 2] applied to vector<1x1x8xi32>.
      bool transposeNonOuterUnitDims = false;
      auto operandShape = cast<ShapedType>(operands[it.index()].getType());
      for (auto [index, dim] :
           llvm::enumerate(ArrayRef<int64_t>(perm).drop_back(1))) {
        if (dim != static_cast<int64_t>(index) &&
            operandShape.getDimSize(index) != 1) {
          transposeNonOuterUnitDims = true;
          break;
        }
      }

      // Do the transpose now if needed so that we can drop the correct dim
      // using extract later.
      if (transposeNeeded) {
        map = AffineMap::get(map.getNumDims(), 0, transposeResults,
                             contractOp.getContext());
        if (transposeNonOuterUnitDims) {
          operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>(
              loc, operands[it.index()], perm);
        }
      }
    }
    // We have taken care to have the dim to be dropped be the leading dim.
    // If it's still not leading, that means it does not exist in this
    // operand, and hence we do not need an extract.
    if (map.getDimPosition(0) == dimToDrop)
      validExtract = true;

    for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
      int64_t currDim = map.getDimPosition(i);
      if (currDim == dimToDrop)
        // This is the dim we are dropping.
        continue;
      auto targetExpr = rewriter.getAffineDimExpr(
          currDim < dimToDrop ? currDim : currDim - 1);
      results.push_back(targetExpr);
    }
    newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
                                             contractOp.getContext()));
    // Extract if it's a valid extraction, otherwise use the operand without
    // extraction.
    newOperands.push_back(
        validExtract ? rewriter.create<vector::ExtractOp>(
                           loc, operands[it.index()], splatZero(dropDim))
                     : operands[it.index()]);
  }

  // Depending on whether this vector.contract is masked, the replacing Op
  // should either be a new vector.contract Op or vector.mask Op.
  Operation *newOp = rewriter.create<vector::ContractionOp>(
      loc, newOperands[0], newOperands[1], newOperands[2],
      rewriter.getAffineMapArrayAttr(newIndexingMaps),
      rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());

  if (maskingOp) {
    auto newMask = rewriter.create<vector::ExtractOp>(loc, maskingOp.getMask(),
                                                      splatZero(dropDim));

    newOp = mlir::vector::maskOperation(rewriter, newOp, newMask);
  }

  return rewriter
      .create<vector::BroadcastOp>(loc, contractOp->getResultTypes()[0],
                                   newOp->getResults()[0])
      .getResult();
}

namespace {

/// Turns vector.contract on a vector with leading 1 dimensions into
/// vector.extract followed by vector.contract on a vector without leading
/// 1 dimensions. Also performs a transpose of the lhs and rhs operands if
/// required prior to the extract.
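/// A sketch on illustrative IR (indexing maps and iterator types elided):
///   %r = vector.contract ... %lhs, %rhs, %acc
///          : vector<1x8x4xf32>, vector<4x16xf32> into vector<1x8x16xf32>
/// becomes:
///   %l = vector.extract %lhs[0] : vector<8x4xf32> from vector<1x8x4xf32>
///   %a = vector.extract %acc[0] : vector<8x16xf32> from vector<1x8x16xf32>
///   %c = vector.contract ... %l, %rhs, %a
///          : vector<8x4xf32>, vector<4x16xf32> into vector<8x16xf32>
///   %r = vector.broadcast %c : vector<8x16xf32> to vector<1x8x16xf32>
/// (%rhs is used unchanged because its indexing map does not reference the
/// dropped dimension.)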
struct CastAwayContractionLeadingOneDim
    : public MaskableOpRewritePattern<vector::ContractionOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
                            MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    return castAwayContractionLeadingOneDim(contractOp, maskingOp, rewriter);
  }
};

/// Looks at elementwise operations on vectors with at least one leading
/// dimension equal to 1, e.g. vector<1x[4]x1xf32> (but not
/// vector<2x[4]x1xf32>), casts away the leading one dimensions (_plural_),
/// and then broadcasts the result.
///
/// Example before:
///   %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32>
/// Example after:
///   %2 = arith.mulf %0, %1 : vector<4x1xf32>
///   %3 = vector.broadcast %2 : vector<4x1xf32> to vector<1x4x1xf32>
///
/// Supports scalable vectors.
class CastAwayElementwiseLeadingOneDim : public RewritePattern {
public:
  CastAwayElementwiseLeadingOneDim(MLIRContext *context,
                                   PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0]);
    if (!vecType)
      return failure();
    VectorType newVecType = trimLeadingOneDims(vecType);
    if (newVecType == vecType)
      return failure();
    int64_t dropDim = vecType.getRank() - newVecType.getRank();
    SmallVector<Value, 4> newOperands;
    for (Value operand : op->getOperands()) {
      if (auto opVecType = dyn_cast<VectorType>(operand.getType())) {
        newOperands.push_back(rewriter.create<vector::ExtractOp>(
            op->getLoc(), operand, splatZero(dropDim)));
      } else {
        newOperands.push_back(operand);
      }
    }
    Operation *newOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        newOperands, newVecType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
                                                     newOp->getResult(0));
    return success();
  }
};

// Drops leading 1 dimensions from vector.constant_mask and inserts a
// vector.broadcast back to the original shape.
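// A sketch on illustrative IR:
//   %m = vector.constant_mask [1, 3] : vector<1x4xi1>
// becomes:
//   %c = vector.constant_mask [3] : vector<4xi1>
//   %m = vector.broadcast %c : vector<4xi1> to vector<1x4xi1>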
struct CastAwayConstantMaskLeadingOneDim
    : public OpRewritePattern<vector::ConstantMaskOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ConstantMaskOp mask,
                                PatternRewriter &rewriter) const override {
    VectorType oldType = mask.getType();
    VectorType newType = trimLeadingOneDims(oldType);

    if (newType == oldType)
      return failure();

    int64_t dropDim = oldType.getRank() - newType.getRank();
    SmallVector<int64_t> dimSizes;
    for (auto attr : mask.getMaskDimSizes())
      dimSizes.push_back(llvm::cast<IntegerAttr>(attr).getInt());
    // If any of the dropped unit dims has a size of `0`, the entire mask is a
    // zero mask; otherwise the unit dims have no effect on the mask.
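    // E.g. (illustrative): for vector<1x1x4xi1> with mask sizes [1, 1, 3],
    // the flattened leading size is 1 * 1 * 3 = 3, giving new sizes [3];
    // with mask sizes [0, 1, 3] it is 0, giving the all-false mask [0].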
    int64_t flatLeadingSize =
        std::accumulate(dimSizes.begin(), dimSizes.begin() + dropDim + 1,
                        static_cast<int64_t>(1), std::multiplies<int64_t>());
    SmallVector<int64_t> newDimSizes({flatLeadingSize});
    newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());

    auto newMask = rewriter.create<vector::ConstantMaskOp>(
        mask.getLoc(), newType, rewriter.getI64ArrayAttr(newDimSizes));
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
    return success();
  }
};

} // namespace

void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<CastAwayExtractStridedSliceLeadingOneDim,
           CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
           CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
           CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
           CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
}