//===- LowerVectorTransfer.cpp - Lower 'vector.transfer' operations ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewrite patterns for the permutation_map attribute of
// vector.transfer operations, as well as lowering patterns for
// vector.transfer_read/write to vector.load/store.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Interfaces/VectorInterfaces.h"

using namespace mlir;
using namespace mlir::vector;

/// Transpose a vector transfer op's `in_bounds` attribute by applying the
/// inverse of the given permutation: entry `i` of `attr` moves to position
/// `permutation[i]`.
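/// For example (illustrative values): with permutation = [2, 0, 1] and
/// attr = [true, false, false], entry 0 moves to position 2, entry 1 to
/// position 0, and entry 2 to position 1, yielding [false, false, true].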
static ArrayAttr
inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
                             const SmallVector<unsigned> &permutation) {
  SmallVector<bool> newInBoundsValues(permutation.size());
  size_t index = 0;
  for (unsigned pos : permutation)
    newInBoundsValues[pos] =
        cast<BoolAttr>(attr.getValue()[index++]).getValue();
  return builder.getBoolArrayAttr(newInBoundsValues);
}

/// Extend the rank of a vector Value by `addedRank` by adding outer unit
/// dimensions.
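/// For example (illustrative types): extending a vector<2x3xf32> by
/// addedRank = 2 yields a vector<1x1x2x3xf32> via a single vector.broadcast.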
static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
                              int64_t addedRank) {
  auto originalVecType = cast<VectorType>(vec.getType());
  SmallVector<int64_t> newShape(addedRank, 1);
  newShape.append(originalVecType.getShape().begin(),
                  originalVecType.getShape().end());

  SmallVector<bool> newScalableDims(addedRank, false);
  newScalableDims.append(originalVecType.getScalableDims().begin(),
                         originalVecType.getScalableDims().end());
  VectorType newVecType = VectorType::get(
      newShape, originalVecType.getElementType(), newScalableDims);
  return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
}

/// Extend the rank of a vector Value by `addedRank` by adding inner unit
/// dimensions.
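/// For example (illustrative types): extending a vector<2x3xi1> mask by
/// addedRank = 1 first broadcasts to vector<1x2x3xi1>, then transposes with
/// permutation [1, 2, 0] to yield vector<2x3x1xi1>.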
static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
                            int64_t addedRank) {
  Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
  SmallVector<int64_t> permutation;
  for (int64_t i = addedRank,
               e = cast<VectorType>(broadcasted.getType()).getRank();
       i < e; ++i)
    permutation.push_back(i);
  for (int64_t i = 0; i < addedRank; ++i)
    permutation.push_back(i);
  return builder.create<vector::TransposeOp>(loc, broadcasted, permutation);
}

//===----------------------------------------------------------------------===//
// populateVectorTransferPermutationMapLoweringPatterns
//===----------------------------------------------------------------------===//

namespace {
/// Lower transfer_read op with permutation into a transfer_read with a
/// permutation map composed of leading zeros followed by a minor identity +
/// vector.transpose op.
/// Ex:
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (0, d1)
///     into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (d1, 0)
///     vector.transpose %v, [1, 0]
///
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
///     into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
///     vector.transpose %v, [0, 1, 3, 2, 4]
/// Note that an alternative is to transform it to linalg.transpose +
/// vector.transfer_read to do the transpose in memory instead.
struct TransferReadPermutationLowering
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.getNumResults() == 0)
      return rewriter.notifyMatchFailure(op, "0 result permutation map");
    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }
    AffineMap permutationMap =
        map.getPermutationMap(permutation, op.getContext());
    if (permutationMap.isIdentity())
      return rewriter.notifyMatchFailure(
          op, "map is already identity, apply another pattern");

    // Calculate the map of the new read by applying the inverse permutation.
    permutationMap = inversePermutation(permutationMap);
    AffineMap newMap = permutationMap.compose(map);
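    // E.g., for the first Ex above, map = (d0, d1, d2) -> (0, d1) gives
    // permutation = [1, 0], so newMap = (d0, d1, d2) -> (d1, 0) and the
    // transpose emitted below uses [1, 0].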
    // Apply the reverse transpose to deduce the type of the transfer_read.
    ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
    SmallVector<int64_t> newVectorShape(originalShape.size());
    ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
    SmallVector<bool> newScalableDims(originalShape.size());
    for (const auto &pos : llvm::enumerate(permutation)) {
      newVectorShape[pos.value()] = originalShape[pos.index()];
      newScalableDims[pos.value()] = originalScalableDims[pos.index()];
    }

    // Transpose in_bounds attribute.
    ArrayAttr newInBoundsAttr =
        op.getInBounds() ? inverseTransposeInBoundsAttr(
                               rewriter, op.getInBounds().value(), permutation)
                         : ArrayAttr();

    // Generate new transfer_read operation.
    VectorType newReadType = VectorType::get(
        newVectorShape, op.getVectorType().getElementType(), newScalableDims);
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);

    // Transpose result of transfer_read.
    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
                                                     transposePerm);
    return success();
  }
};

/// Lower transfer_write op with permutation into a transfer_write with a
/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
/// Ex:
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
///     into:
///     %tmp = vector.transpose %v, [2, 0, 1]
///     vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
///
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
///     into:
///     %tmp = vector.transpose %v, [1, 0]
///     vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
struct TransferWritePermutationLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.isMinorIdentity())
      return rewriter.notifyMatchFailure(op, "map is already minor identity");

    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }

    // Remove unused dims from the permutation map. E.g.:
    //   (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
    // becomes:
    //   comp = (d0, d1, d2) -> (d2, d0, d1)
    auto comp = compressUnusedDims(map);
    AffineMap permutationMap = inversePermutation(comp);
    // Get positions of remaining result dims.
    SmallVector<int64_t> indices;
    llvm::transform(permutationMap.getResults(), std::back_inserter(indices),
                    [](AffineExpr expr) {
                      return cast<AffineDimExpr>(expr).getPosition();
                    });

    // Transpose in_bounds attribute.
    ArrayAttr newInBoundsAttr =
        op.getInBounds() ? inverseTransposeInBoundsAttr(
                               rewriter, op.getInBounds().value(), permutation)
                         : ArrayAttr();

    // Generate new transfer_write operation.
    Value newVec = rewriter.create<vector::TransposeOp>(
        op.getLoc(), op.getVector(), indices);
    auto newMap = AffineMap::getMinorIdentityMap(
        map.getNumDims(), map.getNumResults(), rewriter.getContext());
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
        op.getMask(), newInBoundsAttr);

    return success();
  }
};

/// Convert a transfer_write op with a map which is not a permutation of a
/// minor identity into a vector.broadcast + transfer_write with a map that is
/// a permutation of a minor identity, by adding a unit dim for each missing
/// inner dimension. Ex:
/// ```
///  vector.transfer_write %v
///    {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} :
///    vector<8x16xf32>
/// ```
/// into:
/// ```
///  %v1 = vector.broadcast %v : vector<8x16xf32> to vector<1x8x16xf32>
///  vector.transfer_write %v1
///    {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>} :
///    vector<1x8x16xf32>
/// ```
struct TransferWriteNonPermutationLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op,
          "map is already permutable to minor identity, apply another pattern");
    }

    // Missing outer dimensions are allowed: find the outermost existing
    // dimension, then deduce the missing inner dimensions after it.
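    // E.g., for the map (d0, d1, d2, d3) -> (d1, d2) of the example above, d0
    // is a missing outer dimension (allowed as-is), whereas d3 is a missing
    // inner dimension, so missingInnerDim = [3] and exprs starts as [d3].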
    SmallVector<bool> foundDim(map.getNumDims(), false);
    for (AffineExpr exp : map.getResults())
      foundDim[cast<AffineDimExpr>(exp).getPosition()] = true;
    SmallVector<AffineExpr> exprs;
    bool foundFirstDim = false;
    SmallVector<int64_t> missingInnerDim;
    for (size_t i = 0; i < foundDim.size(); i++) {
      if (foundDim[i]) {
        foundFirstDim = true;
        continue;
      }
      if (!foundFirstDim)
        continue;
      // Once we have found one outer dimension existing in the map, keep track
      // of all the missing dimensions after it.
      missingInnerDim.push_back(i);
      exprs.push_back(rewriter.getAffineDimExpr(i));
    }
    // Vector: add unit dims at the beginning of the shape.
    Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
                                    missingInnerDim.size());
    // Mask: add unit dims at the end of the shape.
    Value newMask;
    if (op.getMask())
      newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
                               missingInnerDim.size());
    exprs.append(map.getResults().begin(), map.getResults().end());
    AffineMap newMap =
        AffineMap::get(map.getNumDims(), 0, exprs, op.getContext());
    // All the newly added dimensions are in-bounds.
    SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
    for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i) {
      newInBoundsValues.push_back(op.isDimInBounds(i));
    }
    ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
        newMask, newInBoundsAttr);
    return success();
  }
};

/// Lower transfer_read op with broadcast in the leading dimensions into
/// transfer_read of lower rank + vector.broadcast.
/// Ex: vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
///     into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
///     vector.broadcast %v
struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    AffineMap map = op.getPermutationMap();
    unsigned numLeadingBroadcast = 0;
    for (auto expr : map.getResults()) {
      auto constExpr = dyn_cast<AffineConstantExpr>(expr);
      if (!constExpr || constExpr.getValue() != 0)
        break;
      numLeadingBroadcast++;
    }
    // If there are no leading zeros in the map there is nothing to do.
    if (numLeadingBroadcast == 0)
      return rewriter.notifyMatchFailure(op, "no leading broadcasts in map");

    VectorType originalVecType = op.getVectorType();
    unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
    // Calculate new map, vector type and masks without the leading zeros.
    AffineMap newMap = AffineMap::get(
        map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
        op.getContext());
    // Only remove the leading zeros if the rest of the map is a minor identity
    // with broadcasting. Otherwise we first want to permute the map.
    if (!newMap.isMinorIdentityWithBroadcasting()) {
      return rewriter.notifyMatchFailure(
          op, "map is not a minor identity with broadcasting");
    }

    // TODO: support zero-dimension vectors natively. See:
    // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
    // In the meantime, lower these to a scalar load when they pop up.
    if (reducedShapeRank == 0) {
      Value newRead;
      if (isa<TensorType>(op.getShapedType())) {
        newRead = rewriter.create<tensor::ExtractOp>(
            op.getLoc(), op.getSource(), op.getIndices());
      } else {
        newRead = rewriter.create<memref::LoadOp>(
            op.getLoc(), originalVecType.getElementType(), op.getSource(),
            op.getIndices());
      }
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
                                                       newRead);
      return success();
    }

    SmallVector<int64_t> newShape(
        originalVecType.getShape().take_back(reducedShapeRank));
    SmallVector<bool> newScalableDims(
        originalVecType.getScalableDims().take_back(reducedShapeRank));
    // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
    if (newShape.empty())
      return rewriter.notifyMatchFailure(op, "rank-reduced vector is 0-d");

    VectorType newReadType = VectorType::get(
        newShape, originalVecType.getElementType(), newScalableDims);
    ArrayAttr newInBoundsAttr =
        op.getInBounds()
            ? rewriter.getArrayAttr(
                  op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
            : ArrayAttr();
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
                                                     newRead);
    return success();
  }
};

} // namespace

void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadPermutationLowering, TransferWritePermutationLowering,
           TransferOpReduceRank, TransferWriteNonPermutationLowering>(
          patterns.getContext(), benefit);
}

//===----------------------------------------------------------------------===//
// populateVectorTransferLoweringPatterns
//===----------------------------------------------------------------------===//

namespace {
/// Progressive lowering of transfer_read. This pattern supports lowering of
/// `vector.transfer_read` to a combination of `vector.load` and
/// `vector.broadcast` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   result type.
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
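/// For example (illustrative IR; operand names are placeholders):
///   %v = vector.transfer_read %mem[%i], %pad {in_bounds = [true]}
///       : memref<?xf32>, vector<4xf32>
/// lowers to:
///   %v = vector.load %mem[%i] : memref<?xf32>, vector<4xf32>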
struct TransferReadToVectorLoadLowering
    : public OpRewritePattern<vector::TransferReadOp> {
  TransferReadToVectorLoadLowering(MLIRContext *context,
                                   std::optional<unsigned> maxRank,
                                   PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          read, "vector type is greater than max transfer rank");
    }

    SmallVector<unsigned> broadcastedDims;
    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    // We let the 0-d corner case pass through as it is supported.
    if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
            &broadcastedDims))
      return rewriter.notifyMatchFailure(read, "not minor identity + bcast");

    auto memRefType = dyn_cast<MemRefType>(read.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(read, "not a memref source");

    // Non-unit strides are handled by VectorToSCF.
    if (!isLastMemrefDimUnitStride(memRefType))
      return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF");

    // If there is broadcasting involved then we first load the unbroadcasted
    // vector, and then broadcast it with `vector.broadcast`.
    ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
    SmallVector<int64_t> unbroadcastedVectorShape(vectorShape.begin(),
                                                  vectorShape.end());
    for (unsigned i : broadcastedDims)
      unbroadcastedVectorShape[i] = 1;
    VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
        unbroadcastedVectorShape, read.getVectorType().getElementType());
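    // E.g. (illustrative): a read of vector<4x8xf32> with permutation map
    // (d0, d1) -> (0, d1) has broadcastedDims = [0], so a vector<1x8xf32> is
    // loaded and then broadcast back to vector<4x8xf32> below.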

    // `vector.load` supports vector types as memref's elements only when the
    // resulting vector type is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
      return rewriter.notifyMatchFailure(read, "incompatible element type");

    // Otherwise, element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != read.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(read, "non-matching element type");

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (read.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask");

    // Create vector load op.
    Operation *loadOp;
    if (read.getMask()) {
      if (read.getVectorType().getRank() != 1)
        // vector.maskedload operates on 1-D vectors.
        return rewriter.notifyMatchFailure(
            read, "vector type is not rank 1, can't create masked load, needs "
                  "VectorToSCF");

      Value fill = rewriter.create<vector::SplatOp>(
          read.getLoc(), unbroadcastedVectorType, read.getPadding());
      loadOp = rewriter.create<vector::MaskedLoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices(), read.getMask(), fill);
    } else {
      loadOp = rewriter.create<vector::LoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices());
    }

    // Insert a broadcasting op if required.
    if (!broadcastedDims.empty()) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          read, read.getVectorType(), loadOp->getResult(0));
    } else {
      rewriter.replaceOp(read, loadOp->getResult(0));
    }

    return success();
  }

  std::optional<unsigned> maxTransferRank;
};

/// Replace a 0-d (or any other single-element) vector.load with a memref.load
/// plus a vector.broadcast.
// TODO: we shouldn't cross the vector/scalar domains just for this
// but atm we lack the infra to avoid it. Possible solutions include:
// - go directly to LLVM + bitcast
// - introduce a bitcast op and likely a new pointer dialect
// - let memref.load/store additionally support the 0-d vector case
// There are still deeper data layout issues lingering even in this
// trivial case (for architectures for which this matters).
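/// For example (illustrative IR; operand names are placeholders):
///   %0 = vector.load %mem[%i] : memref<8xf32>, vector<1xf32>
/// lowers to:
///   %s = memref.load %mem[%i] : memref<8xf32>
///   %0 = vector.broadcast %s : f32 to vector<1xf32>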
struct VectorLoadToMemrefLoadLowering
    : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = loadOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(loadOp, "not a single element vector");

    auto memrefLoad = rewriter.create<memref::LoadOp>(
        loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
                                                     memrefLoad);
    return success();
  }
};

/// Replace a 0-d (or any other single-element) vector.store with a scalar
/// extract plus a memref.store.
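/// For example (illustrative IR; operand names are placeholders, and the
/// vector.extract syntax is abbreviated):
///   vector.store %v, %mem[%i] : memref<8xf32>, vector<1xf32>
/// lowers to:
///   %s = vector.extract %v[0] : f32 from vector<1xf32>
///   memref.store %s, %mem[%i] : memref<8xf32>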
struct VectorStoreToMemrefStoreLowering
    : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = storeOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(storeOp, "not single element vector");

    Value extracted;
    if (vecType.getRank() == 0) {
      // TODO: Unify once ExtractOp supports 0-d vectors.
      extracted = rewriter.create<vector::ExtractElementOp>(
          storeOp.getLoc(), storeOp.getValueToStore());
    } else {
      SmallVector<int64_t> indices(vecType.getRank(), 0);
      extracted = rewriter.create<vector::ExtractOp>(
          storeOp.getLoc(), storeOp.getValueToStore(), indices);
    }

    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
    return success();
  }
};

/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   type of the written value.
/// - The permutation map is the minor identity map (neither permutation nor
///   broadcasting is allowed).
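/// For example (illustrative IR; operand names are placeholders):
///   vector.transfer_write %v, %mem[%i] {in_bounds = [true]}
///       : vector<4xf32>, memref<?xf32>
/// lowers to:
///   vector.store %v, %mem[%i] : memref<?xf32>, vector<4xf32>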
struct TransferWriteToVectorStoreLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  TransferWriteToVectorStoreLowering(MLIRContext *context,
                                     std::optional<unsigned> maxRank,
                                     PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          write, "vector type is greater than max transfer rank");
    }

    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    // The 0-d corner case is intentionally let through.
    if (!write.getPermutationMap().isMinorIdentity())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "permutation map is not minor identity: " << write;
      });

    auto memRefType = dyn_cast<MemRefType>(write.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "not a memref type: " << write;
      });

    // Non-unit strides are handled by VectorToSCF.
    if (!isLastMemrefDimUnitStride(memRefType))
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "most minor stride is not 1: " << write;
      });

    // `vector.store` supports vector types as memref's elements only when the
    // type of the vector value being written is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });

    // Otherwise, element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != write.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (write.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "out of bounds dim: " << write;
      });
    if (write.getMask()) {
      if (write.getVectorType().getRank() != 1)
        // vector.maskedstore operates on 1-D vectors.
        return rewriter.notifyMatchFailure(
            write.getLoc(), [=](Diagnostic &diag) {
              diag << "vector type is not rank 1, can't create masked store, "
                      "needs VectorToSCF: "
                   << write;
            });

      rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
          write, write.getSource(), write.getIndices(), write.getMask(),
          write.getVector());
    } else {
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          write, write.getVector(), write.getSource(), write.getIndices());
    }
    return success();
  }

  std::optional<unsigned> maxTransferRank;
};
} // namespace

void mlir::vector::populateVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
    PatternBenefit benefit) {
  patterns.add<TransferReadToVectorLoadLowering,
               TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                   maxTransferRank, benefit);
  patterns
      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
          patterns.getContext(), benefit);
}
