//===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <numeric>

#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "vector-drop-unit-dim"

using namespace mlir;
using namespace mlir::vector;
// Trims leading unit (size-1, non-scalable) dimensions from `oldType` and
// returns the resulting type. If every dimension would be dropped, the
// trailing dimension is kept so the result remains a valid vector type,
// e.g. `vector<1xT>` for an all-unit shape.
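//
// Examples (derived from the logic below):
//   vector<1x1x4xf32>   -> vector<4xf32>
//   vector<1x[4]xf32>   -> vector<[4]xf32>
//   vector<1x[1]x4xf32> -> vector<[1]x4xf32> (scalable unit dims are kept)
//   vector<1x1xf32>     -> vector<1xf32>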
static VectorType trimLeadingOneDims(VectorType oldType) {
  ArrayRef<int64_t> oldShape = oldType.getShape();
  ArrayRef<int64_t> newShape = oldShape;

  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
  ArrayRef<bool> newScalableDims = oldScalableDims;

  while (!newShape.empty() && newShape.front() == 1 &&
         !newScalableDims.front()) {
    newShape = newShape.drop_front(1);
    newScalableDims = newScalableDims.drop_front(1);
  }

  // Make sure we have at least 1 dimension per vector type requirements.
  if (newShape.empty()) {
    newShape = oldShape.take_back();
    newScalableDims = oldType.getScalableDims().take_back();
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}

/// Return a SmallVector of size `rank` containing all zeros.
static SmallVector<int64_t> splatZero(int64_t rank) {
  return SmallVector<int64_t>(rank, 0);
}

namespace {

// Casts away leading one dimensions in vector.extract_strided_slice's vector
// input by inserting vector.broadcast.
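//
// An illustrative example (types and attributes chosen for exposition; the
// exact IR depends on the matched op):
//
//  Example before:
//    %0 = vector.extract_strided_slice %arg0
//           {offsets = [0, 2], sizes = [1, 4], strides = [1, 1]}
//           : vector<1x8xf32> to vector<1x4xf32>
//  Example after:
//    %1 = vector.extract %arg0[0] : vector<8xf32> from vector<1x8xf32>
//    %2 = vector.extract_strided_slice %1
//           {offsets = [2], sizes = [4], strides = [1]}
//           : vector<8xf32> to vector<4xf32>
//    %0 = vector.broadcast %2 : vector<4xf32> to vector<1x4xf32>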
struct CastAwayExtractStridedSliceLeadingOneDim
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    // vector.extract_strided_slice requires the input and output vectors to
    // have the same rank. Here we drop leading one dimensions from the input
    // vector type to make sure we don't create a mismatch.
    VectorType oldSrcType = extractOp.getSourceVectorType();
    VectorType newSrcType = trimLeadingOneDims(oldSrcType);

    if (newSrcType.getRank() == oldSrcType.getRank())
      return failure();

    int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();

    VectorType oldDstType = extractOp.getType();
    VectorType newDstType =
        VectorType::get(oldDstType.getShape().drop_front(dropCount),
                        oldDstType.getElementType(),
                        oldDstType.getScalableDims().drop_front(dropCount));

    Location loc = extractOp.getLoc();

    Value newSrcVector = rewriter.create<vector::ExtractOp>(
        loc, extractOp.getVector(), splatZero(dropCount));

    // The offsets/sizes/strides attribute can have fewer elements than the
    // input vector's rank; they apply to the leading dimensions.
    auto newOffsets = rewriter.getArrayAttr(
        extractOp.getOffsets().getValue().drop_front(dropCount));
    auto newSizes = rewriter.getArrayAttr(
        extractOp.getSizes().getValue().drop_front(dropCount));
    auto newStrides = rewriter.getArrayAttr(
        extractOp.getStrides().getValue().drop_front(dropCount));

    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
        loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
                                                     newExtractOp);

    return success();
  }
};

// Casts away leading one dimensions in vector.insert_strided_slice's vector
// inputs by inserting vector.broadcast.
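//
// An illustrative example (types and attributes chosen for exposition):
//
//  Example before:
//    %0 = vector.insert_strided_slice %src, %dst
//           {offsets = [0, 2, 0], strides = [1, 1]}
//           : vector<1x4xf32> into vector<1x8x4xf32>
//  Example after:
//    %s = vector.extract %src[0] : vector<4xf32> from vector<1x4xf32>
//    %d = vector.extract %dst[0] : vector<8x4xf32> from vector<1x8x4xf32>
//    %1 = vector.insert_strided_slice %s, %d {offsets = [2, 0], strides = [1]}
//           : vector<4xf32> into vector<8x4xf32>
//    %0 = vector.broadcast %1 : vector<8x4xf32> to vector<1x8x4xf32>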
struct CastAwayInsertStridedSliceLeadingOneDim
    : public OpRewritePattern<vector::InsertStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    VectorType oldSrcType = insertOp.getSourceVectorType();
    VectorType newSrcType = trimLeadingOneDims(oldSrcType);
    VectorType oldDstType = insertOp.getDestVectorType();
    VectorType newDstType = trimLeadingOneDims(oldDstType);

    int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
    int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
    if (srcDropCount == 0 && dstDropCount == 0)
      return failure();

    // Trim leading one dimensions from both operands.
    Location loc = insertOp.getLoc();

    Value newSrcVector = rewriter.create<vector::ExtractOp>(
        loc, insertOp.getValueToStore(), splatZero(srcDropCount));
    Value newDstVector = rewriter.create<vector::ExtractOp>(
        loc, insertOp.getDest(), splatZero(dstDropCount));

    auto newOffsets = rewriter.getArrayAttr(
        insertOp.getOffsets().getValue().take_back(newDstType.getRank()));
    auto newStrides = rewriter.getArrayAttr(
        insertOp.getStrides().getValue().take_back(newSrcType.getRank()));

    auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
        loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
                                                     newInsertOp);

    return success();
  }
};

// Casts away leading one dimensions in vector.insert's vector inputs by
// inserting vector.broadcast.
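//
// An illustrative example (types and positions chosen for exposition):
//
//  Example before:
//    %0 = vector.insert %src, %dst [0, 1]
//           : vector<4xf32> into vector<1x2x4xf32>
//  Example after:
//    %d = vector.extract %dst[0] : vector<2x4xf32> from vector<1x2x4xf32>
//    %1 = vector.insert %src, %d [1] : vector<4xf32> into vector<2x4xf32>
//    %0 = vector.broadcast %1 : vector<2x4xf32> to vector<1x2x4xf32>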
struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    Type oldSrcType = insertOp.getValueToStoreType();
    Type newSrcType = oldSrcType;
    int64_t oldSrcRank = 0, newSrcRank = 0;
    if (auto type = dyn_cast<VectorType>(oldSrcType)) {
      newSrcType = trimLeadingOneDims(type);
      oldSrcRank = type.getRank();
      newSrcRank = cast<VectorType>(newSrcType).getRank();
    }

    VectorType oldDstType = insertOp.getDestVectorType();
    VectorType newDstType = trimLeadingOneDims(oldDstType);

    int64_t srcDropCount = oldSrcRank - newSrcRank;
    int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
    if (srcDropCount == 0 && dstDropCount == 0)
      return failure();

    // Trim leading one dimensions from both operands.
    Location loc = insertOp.getLoc();

    Value newSrcVector = insertOp.getValueToStore();
    if (oldSrcRank != 0) {
      newSrcVector = rewriter.create<vector::ExtractOp>(
          loc, insertOp.getValueToStore(), splatZero(srcDropCount));
    }
    Value newDstVector = rewriter.create<vector::ExtractOp>(
        loc, insertOp.getDest(), splatZero(dstDropCount));

    // The new position rank is computed in two steps: (1) if the destination
    // type has leading unit dims, trim the position array accordingly, then
    // (2) if the source type also has leading unit dims, append zeros to the
    // position array accordingly.
    unsigned oldPosRank = insertOp.getNumIndices();
    unsigned newPosRank = std::max<int64_t>(0, oldPosRank - dstDropCount);
    SmallVector<OpFoldResult> oldPosition = insertOp.getMixedPosition();
    SmallVector<OpFoldResult> newPosition =
        llvm::to_vector(ArrayRef(oldPosition).take_back(newPosRank));
    newPosition.resize(newDstType.getRank() - newSrcRank,
                       rewriter.getI64IntegerAttr(0));

    auto newInsertOp = rewriter.create<vector::InsertOp>(
        loc, newSrcVector, newDstVector, newPosition);

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
                                                     newInsertOp);

    return success();
  }
};

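/// Drops the leading unit dimensions from the mask of a trimmed transfer op,
/// producing a mask that matches the new transfer type `newType` and
/// permutation map `newMap`. Uses `vector.extract` when the new mask type is
/// broadcastable back to `oldMaskType`; otherwise falls back to
/// `vector.shape_cast`.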
static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask,
                                  VectorType newType, AffineMap newMap,
                                  VectorType oldMaskType) {
  // Infer the type of the new mask from the new map.
  VectorType newMaskType = inferTransferOpMaskType(newType, newMap);

  // If the new mask is broadcastable to the old result type, we can safely
  // use a `vector.extract` to get the new mask. Otherwise the best we can
  // do is shape cast.
  if (vector::isBroadcastableTo(newMaskType, oldMaskType) ==
      BroadcastableToResult::Success) {
    int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank();
    return b.create<vector::ExtractOp>(loc, mask, splatZero(dropDim));
  }
  return b.create<vector::ShapeCastOp>(loc, newMaskType, mask);
}

// Turns vector.transfer_read on a vector with leading 1 dimensions into
// vector.transfer_read on a vector without leading 1 dimensions, followed by
// vector.broadcast back to the original shape.
struct CastAwayTransferReadLeadingOneDim
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    // TODO(#78787): Masked ops are not supported yet.
    if (cast<MaskableOpInterface>(read.getOperation()).isMasked())
      return failure();
    // TODO: Support the 0-d corner case.
    if (read.getTransferRank() == 0)
      return failure();

    auto shapedType = cast<ShapedType>(read.getBase().getType());
    if (shapedType.getElementType() != read.getVectorType().getElementType())
      return failure();

    VectorType oldType = read.getVectorType();
    VectorType newType = trimLeadingOneDims(oldType);

    if (newType == oldType)
      return failure();

    AffineMap oldMap = read.getPermutationMap();
    ArrayRef<AffineExpr> newResults =
        oldMap.getResults().take_back(newType.getRank());
    AffineMap newMap =
        AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
                       rewriter.getContext());

    ArrayAttr inBoundsAttr;
    if (read.getInBounds())
      inBoundsAttr = rewriter.getArrayAttr(
          read.getInBoundsAttr().getValue().take_back(newType.getRank()));

    Value mask = Value();
    if (read.getMask()) {
      VectorType maskType = read.getMaskType();
      mask = dropUnitDimsFromMask(rewriter, read.getLoc(), read.getMask(),
                                  newType, newMap, maskType);
    }

    auto newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), newType, read.getBase(), read.getIndices(),
        AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);

    return success();
  }
};

// Turns vector.transfer_write on a vector with leading 1 dimensions into
// vector.extract followed by vector.transfer_write on the vector without
// leading 1 dimensions.
struct CastAwayTransferWriteLeadingOneDim
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    // TODO(#78787): Masked ops are not supported yet.
    if (cast<MaskableOpInterface>(write.getOperation()).isMasked())
      return failure();
    // TODO: Support the 0-d corner case.
    if (write.getTransferRank() == 0)
      return failure();

    // The base of a transfer op is always shaped, so a checked cast is not
    // needed (and the result was previously dereferenced unconditionally).
    auto shapedType = cast<ShapedType>(write.getBase().getType());
    if (shapedType.getElementType() != write.getVectorType().getElementType())
      return failure();

    VectorType oldType = write.getVectorType();
    VectorType newType = trimLeadingOneDims(oldType);
    if (newType == oldType)
      return failure();
    int64_t dropDim = oldType.getRank() - newType.getRank();

    AffineMap oldMap = write.getPermutationMap();
    ArrayRef<AffineExpr> newResults =
        oldMap.getResults().take_back(newType.getRank());
    AffineMap newMap =
        AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
                       rewriter.getContext());

    ArrayAttr inBoundsAttr;
    if (write.getInBounds())
      inBoundsAttr = rewriter.getArrayAttr(
          write.getInBoundsAttr().getValue().take_back(newType.getRank()));

    auto newVector = rewriter.create<vector::ExtractOp>(
        write.getLoc(), write.getVector(), splatZero(dropDim));

    if (write.getMask()) {
      VectorType maskType = write.getMaskType();
      Value newMask = dropUnitDimsFromMask(
          rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
      rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
          write, newVector, write.getBase(), write.getIndices(),
          AffineMapAttr::get(newMap), newMask, inBoundsAttr);
      return success();
    }

    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        write, newVector, write.getBase(), write.getIndices(),
        AffineMapAttr::get(newMap), inBoundsAttr);
    return success();
  }
};

} // namespace

FailureOr<Value>
mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
                                               MaskingOpInterface maskingOp,
                                               RewriterBase &rewriter) {
  VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
  if (oldAccType == nullptr)
    return failure();
  if (oldAccType.getRank() < 2)
    return failure();
  if (oldAccType.getShape()[0] != 1)
    return failure();
  // Currently we support only dropping one dim, but the pattern can be applied
  // greedily to drop more.
  int64_t dropDim = 1;

  auto oldIndexingMaps = contractOp.getIndexingMapsArray();
  SmallVector<AffineMap> newIndexingMaps;

  auto oldIteratorTypes = contractOp.getIteratorTypes();
  SmallVector<Attribute> newIteratorTypes;

  int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);

  // Only parallel iterators can be dropped.
  if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
    return failure();

  for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
    int64_t currDim = it.index();
    if (currDim == dimToDrop)
      continue;
    newIteratorTypes.push_back(it.value());
  }

  SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
                                 contractOp.getAcc()};
  SmallVector<Value> newOperands;
  auto loc = contractOp.getLoc();

  for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
    // Check if the dim to be dropped exists as a leading dim in the operand;
    // if it does, then we use vector.extract to drop it.
    bool validExtract = false;
    SmallVector<AffineExpr> results;
    auto map = it.value();
    int64_t originalZeroDim = it.value().getDimPosition(0);
    if (originalZeroDim != dimToDrop) {
      // There are two reasons to be on this path: 1. We need to transpose the
      // operand to make the dim to be dropped leading. 2. The dim to be
      // dropped does not exist, in which case we don't want to add a unit
      // transpose, but we must check all the indices to make sure this is the
      // case.
      bool transposeNeeded = false;
      SmallVector<int64_t> perm;
      SmallVector<AffineExpr> transposeResults;

      for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
        int64_t currDim = map.getDimPosition(i);
        if (currDim == dimToDrop) {
          transposeNeeded = true;
          perm.insert(perm.begin(), i);
          auto targetExpr = rewriter.getAffineDimExpr(currDim);
          transposeResults.insert(transposeResults.begin(), targetExpr);
        } else {
          perm.push_back(i);
          auto targetExpr = rewriter.getAffineDimExpr(currDim);
          transposeResults.push_back(targetExpr);
        }
      }

      // Check if only the outer unit dimensions (of size 1) are permuted.
      // Such transposes do not materially affect the underlying vector and
      // can be omitted, e.g. perm [1, 0, 2] applied to vector<1x1x8xi32>.
      bool transposeNonOuterUnitDims = false;
      auto operandShape = cast<ShapedType>(operands[it.index()].getType());
      for (auto [index, dim] :
           llvm::enumerate(ArrayRef<int64_t>(perm).drop_back(1))) {
        if (dim != static_cast<int64_t>(index) &&
            operandShape.getDimSize(index) != 1) {
          transposeNonOuterUnitDims = true;
          break;
        }
      }

      // Do the transpose now if needed so that we can drop the correct dim
      // using extract later.
      if (transposeNeeded) {
        map = AffineMap::get(map.getNumDims(), 0, transposeResults,
                             contractOp.getContext());
        if (transposeNonOuterUnitDims) {
          operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>(
              loc, operands[it.index()], perm);
        }
      }
    }
    // We have taken care to make the dim to be dropped the leading dim. If it
    // is still not leading, that means it does not exist in this operand, and
    // hence we do not need an extract.
    if (map.getDimPosition(0) == dimToDrop)
      validExtract = true;

    for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
      int64_t currDim = map.getDimPosition(i);
      if (currDim == dimToDrop)
        // This is the dim we are dropping.
        continue;
      auto targetExpr = rewriter.getAffineDimExpr(
          currDim < dimToDrop ? currDim : currDim - 1);
      results.push_back(targetExpr);
    }
    newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
                                             contractOp.getContext()));
    // Extract if it is a valid extraction, otherwise use the operand without
    // extraction.
    newOperands.push_back(
        validExtract ? rewriter.create<vector::ExtractOp>(
                           loc, operands[it.index()], splatZero(dropDim))
                     : operands[it.index()]);
  }

  // Depending on whether this vector.contract is masked, the replacing op is
  // either a new vector.contract op or a vector.mask op wrapping it.
  Operation *newOp = rewriter.create<vector::ContractionOp>(
      loc, newOperands[0], newOperands[1], newOperands[2],
      rewriter.getAffineMapArrayAttr(newIndexingMaps),
      rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());

  if (maskingOp) {
    auto newMask = rewriter.create<vector::ExtractOp>(loc, maskingOp.getMask(),
                                                      splatZero(dropDim));

    newOp = mlir::vector::maskOperation(rewriter, newOp, newMask);
  }

  return rewriter
      .create<vector::BroadcastOp>(loc, contractOp->getResultTypes()[0],
                                   newOp->getResults()[0])
      .getResult();
}

namespace {

/// Turns vector.contract on a vector with leading 1 dimensions into
/// vector.extract followed by vector.contract on a vector without leading
/// 1 dimensions. Also transposes the lhs and rhs operands if required prior
/// to the extract.
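///
/// An illustrative example (maps and types chosen for exposition; masked
/// variants additionally extract from the mask):
///
///  Example before:
///    %0 = vector.contract
///           {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
///                             affine_map<(d0, d1, d2) -> (d2, d1)>,
///                             affine_map<(d0, d1, d2) -> (d0, d1)>],
///            iterator_types = ["parallel", "parallel", "reduction"],
///            kind = #vector.kind<add>}
///           %lhs, %rhs, %acc
///           : vector<1x16xf32>, vector<16x8xf32> into vector<1x8xf32>
///  Example after:
///    %1 = vector.extract %lhs[0] : vector<16xf32> from vector<1x16xf32>
///    %2 = vector.extract %acc[0] : vector<8xf32> from vector<1x8xf32>
///    %3 = vector.contract
///           {indexing_maps = [affine_map<(d0, d1) -> (d1)>,
///                             affine_map<(d0, d1) -> (d1, d0)>,
///                             affine_map<(d0, d1) -> (d0)>],
///            iterator_types = ["parallel", "reduction"],
///            kind = #vector.kind<add>}
///           %1, %rhs, %2
///           : vector<16xf32>, vector<16x8xf32> into vector<8xf32>
///    %0 = vector.broadcast %3 : vector<8xf32> to vector<1x8xf32>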
struct CastAwayContractionLeadingOneDim
    : public MaskableOpRewritePattern<vector::ContractionOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
                            MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    return castAwayContractionLeadingOneDim(contractOp, maskingOp, rewriter);
  }
};

/// Looks at elementwise operations on vectors with at least one leading
/// dimension equal to 1, e.g. vector<1x[4]x1xf32> (but not
/// vector<2x[4]x1xf32>), casts away the leading one dimensions (note the
/// plural) and then broadcasts the result back.
///
/// Example before:
///   %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32>
/// Example after:
///   %2 = arith.mulf %0, %1 : vector<4x1xf32>
///   %3 = vector.broadcast %2 : vector<4x1xf32> to vector<1x4x1xf32>
///
/// Supports scalable vectors.
class CastAwayElementwiseLeadingOneDim : public RewritePattern {
public:
  CastAwayElementwiseLeadingOneDim(MLIRContext *context,
                                   PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0]);
    if (!vecType)
      return failure();
    VectorType newVecType = trimLeadingOneDims(vecType);
    if (newVecType == vecType)
      return failure();
    int64_t dropDim = vecType.getRank() - newVecType.getRank();
    SmallVector<Value, 4> newOperands;
    for (Value operand : op->getOperands()) {
      if (auto opVecType = dyn_cast<VectorType>(operand.getType())) {
        newOperands.push_back(rewriter.create<vector::ExtractOp>(
            op->getLoc(), operand, splatZero(dropDim)));
      } else {
        newOperands.push_back(operand);
      }
    }
    Operation *newOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        newOperands, newVecType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
                                                     newOp->getResult(0));
    return success();
  }
};

// Drops leading 1 dimensions from vector.constant_mask and inserts a
// vector.broadcast back to the original shape.
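//
// An illustrative example (note that a dropped unit dim of size 0 zeroes the
// whole mask, which is why the leading mask sizes are multiplied together):
//
//  Example before:
//    %0 = vector.constant_mask [1, 3] : vector<1x4xi1>
//  Example after:
//    %1 = vector.constant_mask [3] : vector<4xi1>
//    %0 = vector.broadcast %1 : vector<4xi1> to vector<1x4xi1>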
struct CastAwayConstantMaskLeadingOneDim
    : public OpRewritePattern<vector::ConstantMaskOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ConstantMaskOp mask,
                                PatternRewriter &rewriter) const override {
    VectorType oldType = mask.getType();
    VectorType newType = trimLeadingOneDims(oldType);

    if (newType == oldType)
      return failure();

    int64_t dropDim = oldType.getRank() - newType.getRank();
    ArrayRef<int64_t> dimSizes = mask.getMaskDimSizes();

    // If any of the dropped unit dims has a size of `0`, the entire mask is a
    // zero mask; otherwise the unit dims have no effect on the mask.
    int64_t flatLeadingSize =
        std::accumulate(dimSizes.begin(), dimSizes.begin() + dropDim + 1,
                        static_cast<int64_t>(1), std::multiplies<int64_t>());
    SmallVector<int64_t> newDimSizes = {flatLeadingSize};
    newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());

    auto newMask = rewriter.create<vector::ConstantMaskOp>(
        mask.getLoc(), newType, newDimSizes);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
    return success();
  }
};


} // namespace

void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<CastAwayExtractStridedSliceLeadingOneDim,
           CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
           CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
           CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
           CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
}