//===- LowerVectorTransfer.cpp - Vector transfer op lowering patterns -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
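// This file implements rewrite patterns for the permutation_map attribute of
// vector.transfer operations, as well as patterns that lower vector.transfer
// operations to vector.load/vector.store (and their masked variants).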
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/MemRef/IR/MemRef.h"
15#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
16
17using namespace mlir;
18using namespace mlir::vector;
19
20/// Transpose a vector transfer op's `in_bounds` attribute by applying reverse
21/// permutation based on the given indices.
22static ArrayAttr
23inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
24 const SmallVector<unsigned> &permutation) {
25 SmallVector<bool> newInBoundsValues(permutation.size());
26 size_t index = 0;
27 for (unsigned pos : permutation)
28 newInBoundsValues[pos] =
29 cast<BoolAttr>(Val: attr.getValue()[index++]).getValue();
30 return builder.getBoolArrayAttr(values: newInBoundsValues);
31}
32
33/// Extend the rank of a vector Value by `addedRanks` by adding outer unit
34/// dimensions.
35static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
36 int64_t addedRank) {
37 auto originalVecType = cast<VectorType>(Val: vec.getType());
38 SmallVector<int64_t> newShape(addedRank, 1);
39 newShape.append(in_start: originalVecType.getShape().begin(),
40 in_end: originalVecType.getShape().end());
41
42 SmallVector<bool> newScalableDims(addedRank, false);
43 newScalableDims.append(in_start: originalVecType.getScalableDims().begin(),
44 in_end: originalVecType.getScalableDims().end());
45 VectorType newVecType = VectorType::get(
46 shape: newShape, elementType: originalVecType.getElementType(), scalableDims: newScalableDims);
47 return builder.create<vector::BroadcastOp>(location: loc, args&: newVecType, args&: vec);
48}
49
50/// Extend the rank of a vector Value by `addedRanks` by adding inner unit
51/// dimensions.
52static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
53 int64_t addedRank) {
54 Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
55 SmallVector<int64_t> permutation;
56 for (int64_t i = addedRank,
57 e = cast<VectorType>(Val: broadcasted.getType()).getRank();
58 i < e; ++i)
59 permutation.push_back(Elt: i);
60 for (int64_t i = 0; i < addedRank; ++i)
61 permutation.push_back(Elt: i);
62 return builder.create<vector::TransposeOp>(location: loc, args&: broadcasted, args&: permutation);
63}
64
65//===----------------------------------------------------------------------===//
66// populateVectorTransferPermutationMapLoweringPatterns
67//===----------------------------------------------------------------------===//
68
69namespace {
70/// Lower transfer_read op with permutation into a transfer_read with a
71/// permutation map composed of leading zeros followed by a minor identiy +
72/// vector.transpose op.
73/// Ex:
74/// vector.transfer_read ...
75/// permutation_map: (d0, d1, d2) -> (0, d1)
76/// into:
77/// %v = vector.transfer_read ...
78/// permutation_map: (d0, d1, d2) -> (d1, 0)
79/// vector.transpose %v, [1, 0]
80///
81/// vector.transfer_read ...
82/// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
83/// into:
84/// %v = vector.transfer_read ...
85/// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
86/// vector.transpose %v, [0, 1, 3, 2, 4]
87/// Note that an alternative is to transform it to linalg.transpose +
88/// vector.transfer_read to do the transpose in memory instead.
89struct TransferReadPermutationLowering
90 : public MaskableOpRewritePattern<vector::TransferReadOp> {
91 using MaskableOpRewritePattern::MaskableOpRewritePattern;
92
93 FailureOr<mlir::Value>
94 matchAndRewriteMaskableOp(vector::TransferReadOp op,
95 MaskingOpInterface maskOp,
96 PatternRewriter &rewriter) const override {
97 // TODO: support 0-d corner case.
98 if (op.getTransferRank() == 0)
99 return rewriter.notifyMatchFailure(arg&: op, msg: "0-d corner case not supported");
100 // TODO: Support transfer_read inside MaskOp case.
101 if (maskOp)
102 return rewriter.notifyMatchFailure(arg&: op, msg: "Masked case not supported");
103
104 SmallVector<unsigned> permutation;
105 AffineMap map = op.getPermutationMap();
106 if (map.getNumResults() == 0)
107 return rewriter.notifyMatchFailure(arg&: op, msg: "0 result permutation map");
108 if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutedDims&: permutation)) {
109 return rewriter.notifyMatchFailure(
110 arg&: op, msg: "map is not permutable to minor identity, apply another pattern");
111 }
112 AffineMap permutationMap =
113 map.getPermutationMap(permutation, context: op.getContext());
114 if (permutationMap.isIdentity())
115 return rewriter.notifyMatchFailure(arg&: op, msg: "map is not identity");
116
117 permutationMap = map.getPermutationMap(permutation, context: op.getContext());
118 // Caluclate the map of the new read by applying the inverse permutation.
119 permutationMap = inversePermutation(map: permutationMap);
120 AffineMap newMap = permutationMap.compose(map);
121 // Apply the reverse transpose to deduce the type of the transfer_read.
122 ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
123 SmallVector<int64_t> newVectorShape(originalShape.size());
124 ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
125 SmallVector<bool> newScalableDims(originalShape.size());
126 for (const auto &pos : llvm::enumerate(First&: permutation)) {
127 newVectorShape[pos.value()] = originalShape[pos.index()];
128 newScalableDims[pos.value()] = originalScalableDims[pos.index()];
129 }
130
131 // Transpose in_bounds attribute.
132 ArrayAttr newInBoundsAttr =
133 inverseTransposeInBoundsAttr(builder&: rewriter, attr: op.getInBounds(), permutation);
134
135 // Generate new transfer_read operation.
136 VectorType newReadType = VectorType::get(
137 shape: newVectorShape, elementType: op.getVectorType().getElementType(), scalableDims: newScalableDims);
138 Value newRead = rewriter.create<vector::TransferReadOp>(
139 location: op.getLoc(), args&: newReadType, args: op.getBase(), args: op.getIndices(),
140 args: AffineMapAttr::get(value: newMap), args: op.getPadding(), args: op.getMask(),
141 args&: newInBoundsAttr);
142
143 // Transpose result of transfer_read.
144 SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
145 return rewriter
146 .create<vector::TransposeOp>(location: op.getLoc(), args&: newRead, args&: transposePerm)
147 .getResult();
148 }
149};
150
151/// Lower transfer_write op with permutation into a transfer_write with a
152/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
153/// Ex:
154/// vector.transfer_write %v ...
155/// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
156/// into:
157/// %tmp = vector.transpose %v, [2, 0, 1]
158/// vector.transfer_write %tmp ...
159/// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
160///
161/// vector.transfer_write %v ...
162/// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
163/// into:
164/// %tmp = vector.transpose %v, [1, 0]
165/// %v = vector.transfer_write %tmp ...
166/// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
167struct TransferWritePermutationLowering
168 : public MaskableOpRewritePattern<vector::TransferWriteOp> {
169 using MaskableOpRewritePattern::MaskableOpRewritePattern;
170
171 FailureOr<mlir::Value>
172 matchAndRewriteMaskableOp(vector::TransferWriteOp op,
173 MaskingOpInterface maskOp,
174 PatternRewriter &rewriter) const override {
175 // TODO: support 0-d corner case.
176 if (op.getTransferRank() == 0)
177 return rewriter.notifyMatchFailure(arg&: op, msg: "0-d corner case not supported");
178 // TODO: Support transfer_write inside MaskOp case.
179 if (maskOp)
180 return rewriter.notifyMatchFailure(arg&: op, msg: "Masked case not supported");
181
182 SmallVector<unsigned> permutation;
183 AffineMap map = op.getPermutationMap();
184 if (map.isMinorIdentity())
185 return rewriter.notifyMatchFailure(arg&: op, msg: "map is already minor identity");
186
187 if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutedDims&: permutation)) {
188 return rewriter.notifyMatchFailure(
189 arg&: op, msg: "map is not permutable to minor identity, apply another pattern");
190 }
191
192 // Remove unused dims from the permutation map. E.g.:
193 // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
194 // comp = (d0, d1, d2) -> (d2, d0, d1)
195 auto comp = compressUnusedDims(map);
196 AffineMap permutationMap = inversePermutation(map: comp);
197 // Get positions of remaining result dims.
198 SmallVector<int64_t> indices;
199 llvm::transform(Range: permutationMap.getResults(), d_first: std::back_inserter(x&: indices),
200 F: [](AffineExpr expr) {
201 return dyn_cast<AffineDimExpr>(Val&: expr).getPosition();
202 });
203
204 // Transpose in_bounds attribute.
205 ArrayAttr newInBoundsAttr =
206 inverseTransposeInBoundsAttr(builder&: rewriter, attr: op.getInBounds(), permutation);
207
208 // Generate new transfer_write operation.
209 Value newVec = rewriter.create<vector::TransposeOp>(
210 location: op.getLoc(), args: op.getVector(), args&: indices);
211 auto newMap = AffineMap::getMinorIdentityMap(
212 dims: map.getNumDims(), results: map.getNumResults(), context: rewriter.getContext());
213 auto newWrite = rewriter.create<vector::TransferWriteOp>(
214 location: op.getLoc(), args&: newVec, args: op.getBase(), args: op.getIndices(),
215 args: AffineMapAttr::get(value: newMap), args: op.getMask(), args&: newInBoundsAttr);
216 if (newWrite.hasPureTensorSemantics())
217 return newWrite.getResult();
218 // In the memref case there's no return value. Use empty value to signal
219 // success.
220 return Value();
221 }
222};
223
224/// Convert a transfer.write op with a map which isn't the permutation of a
225/// minor identity into a vector.broadcast + transfer_write with permutation of
226/// minor identity map by adding unit dim on inner dimension. Ex:
227/// ```
228/// vector.transfer_write %v
229/// {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} :
230/// vector<8x16xf32>
231/// ```
232/// into:
233/// ```
234/// %v1 = vector.broadcast %v : vector<8x16xf32> to vector<1x8x16xf32>
235/// vector.transfer_write %v1
236/// {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>} :
237/// vector<1x8x16xf32>
238/// ```
239struct TransferWriteNonPermutationLowering
240 : public MaskableOpRewritePattern<vector::TransferWriteOp> {
241 using MaskableOpRewritePattern::MaskableOpRewritePattern;
242
243 FailureOr<mlir::Value>
244 matchAndRewriteMaskableOp(vector::TransferWriteOp op,
245 MaskingOpInterface maskOp,
246 PatternRewriter &rewriter) const override {
247 // TODO: support 0-d corner case.
248 if (op.getTransferRank() == 0)
249 return rewriter.notifyMatchFailure(arg&: op, msg: "0-d corner case not supported");
250 // TODO: Support transfer_write inside MaskOp case.
251 if (maskOp)
252 return rewriter.notifyMatchFailure(arg&: op, msg: "Masked case not supported");
253
254 SmallVector<unsigned> permutation;
255 AffineMap map = op.getPermutationMap();
256 if (map.isPermutationOfMinorIdentityWithBroadcasting(permutedDims&: permutation)) {
257 return rewriter.notifyMatchFailure(
258 arg&: op,
259 msg: "map is already permutable to minor identity, apply another pattern");
260 }
261
262 // Missing outer dimensions are allowed, find the most outer existing
263 // dimension then deduce the missing inner dimensions.
264 SmallVector<bool> foundDim(map.getNumDims(), false);
265 for (AffineExpr exp : map.getResults())
266 foundDim[cast<AffineDimExpr>(Val&: exp).getPosition()] = true;
267 SmallVector<AffineExpr> exprs;
268 bool foundFirstDim = false;
269 SmallVector<int64_t> missingInnerDim;
270 for (size_t i = 0; i < foundDim.size(); i++) {
271 if (foundDim[i]) {
272 foundFirstDim = true;
273 continue;
274 }
275 if (!foundFirstDim)
276 continue;
277 // Once we found one outer dimension existing in the map keep track of all
278 // the missing dimensions after that.
279 missingInnerDim.push_back(Elt: i);
280 exprs.push_back(Elt: rewriter.getAffineDimExpr(position: i));
281 }
282 // Vector: add unit dims at the beginning of the shape.
283 Value newVec = extendVectorRank(builder&: rewriter, loc: op.getLoc(), vec: op.getVector(),
284 addedRank: missingInnerDim.size());
285 // Mask: add unit dims at the end of the shape.
286 Value newMask;
287 if (op.getMask())
288 newMask = extendMaskRank(builder&: rewriter, loc: op.getLoc(), vec: op.getMask(),
289 addedRank: missingInnerDim.size());
290 exprs.append(in_start: map.getResults().begin(), in_end: map.getResults().end());
291 AffineMap newMap =
292 AffineMap::get(dimCount: map.getNumDims(), symbolCount: 0, results: exprs, context: op.getContext());
293 // All the new dimensions added are inbound.
294 SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
295 for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i) {
296 newInBoundsValues.push_back(Elt: op.isDimInBounds(dim: i));
297 }
298 ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(values: newInBoundsValues);
299 auto newWrite = rewriter.create<vector::TransferWriteOp>(
300 location: op.getLoc(), args&: newVec, args: op.getBase(), args: op.getIndices(),
301 args: AffineMapAttr::get(value: newMap), args&: newMask, args&: newInBoundsAttr);
302 if (newWrite.hasPureTensorSemantics())
303 return newWrite.getResult();
304 // In the memref case there's no return value. Use empty value to signal
305 // success.
306 return Value();
307 }
308};
309
310/// Lower transfer_read op with broadcast in the leading dimensions into
311/// transfer_read of lower rank + vector.broadcast.
312/// Ex: vector.transfer_read ...
313/// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
314/// into:
315/// %v = vector.transfer_read ...
316/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
317/// vector.broadcast %v
318struct TransferOpReduceRank
319 : public MaskableOpRewritePattern<vector::TransferReadOp> {
320 using MaskableOpRewritePattern::MaskableOpRewritePattern;
321
322 FailureOr<mlir::Value>
323 matchAndRewriteMaskableOp(vector::TransferReadOp op,
324 MaskingOpInterface maskOp,
325 PatternRewriter &rewriter) const override {
326 // TODO: support 0-d corner case.
327 if (op.getTransferRank() == 0)
328 return rewriter.notifyMatchFailure(arg&: op, msg: "0-d corner case not supported");
329 // TODO: support masked case.
330 if (maskOp)
331 return rewriter.notifyMatchFailure(arg&: op, msg: "Masked case not supported");
332
333 AffineMap map = op.getPermutationMap();
334 unsigned numLeadingBroadcast = 0;
335 for (auto expr : map.getResults()) {
336 auto dimExpr = dyn_cast<AffineConstantExpr>(Val&: expr);
337 if (!dimExpr || dimExpr.getValue() != 0)
338 break;
339 numLeadingBroadcast++;
340 }
341 // If there are no leading zeros in the map there is nothing to do.
342 if (numLeadingBroadcast == 0)
343 return rewriter.notifyMatchFailure(arg&: op, msg: "no leading broadcasts in map");
344
345 VectorType originalVecType = op.getVectorType();
346 unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
347 // Calculate new map, vector type and masks without the leading zeros.
348 AffineMap newMap = AffineMap::get(
349 dimCount: map.getNumDims(), symbolCount: 0, results: map.getResults().take_back(N: reducedShapeRank),
350 context: op.getContext());
351 // Only remove the leading zeros if the rest of the map is a minor identity
352 // with broadasting. Otherwise we first want to permute the map.
353 if (!newMap.isMinorIdentityWithBroadcasting()) {
354 return rewriter.notifyMatchFailure(
355 arg&: op, msg: "map is not a minor identity with broadcasting");
356 }
357
358 SmallVector<int64_t> newShape(
359 originalVecType.getShape().take_back(N: reducedShapeRank));
360 SmallVector<bool> newScalableDims(
361 originalVecType.getScalableDims().take_back(N: reducedShapeRank));
362
363 VectorType newReadType = VectorType::get(
364 shape: newShape, elementType: originalVecType.getElementType(), scalableDims: newScalableDims);
365 ArrayAttr newInBoundsAttr =
366 op.getInBounds()
367 ? rewriter.getArrayAttr(
368 value: op.getInBoundsAttr().getValue().take_back(N: reducedShapeRank))
369 : ArrayAttr();
370 Value newRead = rewriter.create<vector::TransferReadOp>(
371 location: op.getLoc(), args&: newReadType, args: op.getBase(), args: op.getIndices(),
372 args: AffineMapAttr::get(value: newMap), args: op.getPadding(), args: op.getMask(),
373 args&: newInBoundsAttr);
374 return rewriter
375 .create<vector::BroadcastOp>(location: op.getLoc(), args&: originalVecType, args&: newRead)
376 .getVector();
377 }
378};
379
380} // namespace
381
382void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
383 RewritePatternSet &patterns, PatternBenefit benefit) {
384 patterns
385 .add<TransferReadPermutationLowering, TransferWritePermutationLowering,
386 TransferOpReduceRank, TransferWriteNonPermutationLowering>(
387 arg: patterns.getContext(), args&: benefit);
388}
389
390//===----------------------------------------------------------------------===//
391// populateVectorTransferLoweringPatterns
392//===----------------------------------------------------------------------===//
393
394namespace {
395/// Progressive lowering of transfer_read. This pattern supports lowering of
396/// `vector.transfer_read` to a combination of `vector.load` and
397/// `vector.broadcast` if all of the following hold:
398/// - Stride of most minor memref dimension must be 1.
399/// - Out-of-bounds masking is not required.
400/// - If the memref's element type is a vector type then it coincides with the
401/// result type.
402/// - The permutation map doesn't perform permutation (broadcasting is allowed).
403struct TransferReadToVectorLoadLowering
404 : public MaskableOpRewritePattern<vector::TransferReadOp> {
405 TransferReadToVectorLoadLowering(MLIRContext *context,
406 std::optional<unsigned> maxRank,
407 PatternBenefit benefit = 1)
408 : MaskableOpRewritePattern<vector::TransferReadOp>(context, benefit),
409 maxTransferRank(maxRank) {}
410
411 FailureOr<mlir::Value>
412 matchAndRewriteMaskableOp(vector::TransferReadOp read,
413 MaskingOpInterface maskOp,
414 PatternRewriter &rewriter) const override {
415 if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
416 return rewriter.notifyMatchFailure(
417 arg&: read, msg: "vector type is greater than max transfer rank");
418 }
419
420 if (maskOp)
421 return rewriter.notifyMatchFailure(arg&: read, msg: "Masked case not supported");
422 SmallVector<unsigned> broadcastedDims;
423 // Permutations are handled by VectorToSCF or
424 // populateVectorTransferPermutationMapLoweringPatterns.
425 // We let the 0-d corner case pass-through as it is supported.
426 if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
427 broadcastedDims: &broadcastedDims))
428 return rewriter.notifyMatchFailure(arg&: read, msg: "not minor identity + bcast");
429
430 auto memRefType = dyn_cast<MemRefType>(Val: read.getShapedType());
431 if (!memRefType)
432 return rewriter.notifyMatchFailure(arg&: read, msg: "not a memref source");
433
434 // Non-unit strides are handled by VectorToSCF.
435 if (!memRefType.isLastDimUnitStride())
436 return rewriter.notifyMatchFailure(arg&: read, msg: "!= 1 stride needs VectorToSCF");
437
438 // If there is broadcasting involved then we first load the unbroadcasted
439 // vector, and then broadcast it with `vector.broadcast`.
440 ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
441 SmallVector<int64_t> unbroadcastedVectorShape(vectorShape);
442 for (unsigned i : broadcastedDims)
443 unbroadcastedVectorShape[i] = 1;
444 VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
445 shape: unbroadcastedVectorShape, elementType: read.getVectorType().getElementType());
446
447 // `vector.load` supports vector types as memref's elements only when the
448 // resulting vector type is the same as the element type.
449 auto memrefElTy = memRefType.getElementType();
450 if (isa<VectorType>(Val: memrefElTy) && memrefElTy != unbroadcastedVectorType)
451 return rewriter.notifyMatchFailure(arg&: read, msg: "incompatible element type");
452
453 // Otherwise, element types of the memref and the vector must match.
454 if (!isa<VectorType>(Val: memrefElTy) &&
455 memrefElTy != read.getVectorType().getElementType())
456 return rewriter.notifyMatchFailure(arg&: read, msg: "non-matching element type");
457
458 // Out-of-bounds dims are handled by MaterializeTransferMask.
459 if (read.hasOutOfBoundsDim())
460 return rewriter.notifyMatchFailure(arg&: read, msg: "out-of-bounds needs mask");
461
462 // Create vector load op.
463 Operation *res;
464 if (read.getMask()) {
465 if (read.getVectorType().getRank() != 1)
466 // vector.maskedload operates on 1-D vectors.
467 return rewriter.notifyMatchFailure(
468 arg&: read, msg: "vector type is not rank 1, can't create masked load, needs "
469 "VectorToSCF");
470
471 Value fill = rewriter.create<vector::SplatOp>(
472 location: read.getLoc(), args&: unbroadcastedVectorType, args: read.getPadding());
473 res = rewriter.create<vector::MaskedLoadOp>(
474 location: read.getLoc(), args&: unbroadcastedVectorType, args: read.getBase(),
475 args: read.getIndices(), args: read.getMask(), args&: fill);
476 } else {
477 res = rewriter.create<vector::LoadOp>(location: read.getLoc(),
478 args&: unbroadcastedVectorType,
479 args: read.getBase(), args: read.getIndices());
480 }
481
482 // Insert a broadcasting op if required.
483 if (!broadcastedDims.empty())
484 res = rewriter.create<vector::BroadcastOp>(
485 location: read.getLoc(), args: read.getVectorType(), args: res->getResult(idx: 0));
486 return res->getResult(idx: 0);
487 }
488
489 std::optional<unsigned> maxTransferRank;
490};
491
492/// Progressive lowering of transfer_write. This pattern supports lowering of
493/// `vector.transfer_write` to `vector.store` if all of the following hold:
494/// - Stride of most minor memref dimension must be 1.
495/// - Out-of-bounds masking is not required.
496/// - If the memref's element type is a vector type then it coincides with the
497/// type of the written value.
498/// - The permutation map is the minor identity map (neither permutation nor
499/// broadcasting is allowed).
500struct TransferWriteToVectorStoreLowering
501 : public MaskableOpRewritePattern<vector::TransferWriteOp> {
502 TransferWriteToVectorStoreLowering(MLIRContext *context,
503 std::optional<unsigned> maxRank,
504 PatternBenefit benefit = 1)
505 : MaskableOpRewritePattern<vector::TransferWriteOp>(context, benefit),
506 maxTransferRank(maxRank) {}
507
508 FailureOr<mlir::Value>
509 matchAndRewriteMaskableOp(vector::TransferWriteOp write,
510 MaskingOpInterface maskOp,
511 PatternRewriter &rewriter) const override {
512 if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
513 return rewriter.notifyMatchFailure(
514 arg&: write, msg: "vector type is greater than max transfer rank");
515 }
516 if (maskOp)
517 return rewriter.notifyMatchFailure(arg&: write, msg: "Masked case not supported");
518
519 // Permutations are handled by VectorToSCF or
520 // populateVectorTransferPermutationMapLoweringPatterns.
521 if ( // pass-through for the 0-d corner case.
522 !write.getPermutationMap().isMinorIdentity())
523 return rewriter.notifyMatchFailure(loc: write.getLoc(), reasonCallback: [=](Diagnostic &diag) {
524 diag << "permutation map is not minor identity: " << write;
525 });
526
527 auto memRefType = dyn_cast<MemRefType>(Val: write.getShapedType());
528 if (!memRefType)
529 return rewriter.notifyMatchFailure(loc: write.getLoc(), reasonCallback: [=](Diagnostic &diag) {
530 diag << "not a memref type: " << write;
531 });
532
533 // Non-unit strides are handled by VectorToSCF.
534 if (!memRefType.isLastDimUnitStride())
535 return rewriter.notifyMatchFailure(loc: write.getLoc(), reasonCallback: [=](Diagnostic &diag) {
536 diag << "most minor stride is not 1: " << write;
537 });
538
539 // `vector.store` supports vector types as memref's elements only when the
540 // type of the vector value being written is the same as the element type.
541 auto memrefElTy = memRefType.getElementType();
542 if (isa<VectorType>(Val: memrefElTy) && memrefElTy != write.getVectorType())
543 return rewriter.notifyMatchFailure(loc: write.getLoc(), reasonCallback: [=](Diagnostic &diag) {
544 diag << "elemental type mismatch: " << write;
545 });
546
547 // Otherwise, element types of the memref and the vector must match.
548 if (!isa<VectorType>(Val: memrefElTy) &&
549 memrefElTy != write.getVectorType().getElementType())
550 return rewriter.notifyMatchFailure(loc: write.getLoc(), reasonCallback: [=](Diagnostic &diag) {
551 diag << "elemental type mismatch: " << write;
552 });
553
554 // Out-of-bounds dims are handled by MaterializeTransferMask.
555 if (write.hasOutOfBoundsDim())
556 return rewriter.notifyMatchFailure(loc: write.getLoc(), reasonCallback: [=](Diagnostic &diag) {
557 diag << "out of bounds dim: " << write;
558 });
559 if (write.getMask()) {
560 if (write.getVectorType().getRank() != 1)
561 // vector.maskedstore operates on 1-D vectors.
562 return rewriter.notifyMatchFailure(
563 loc: write.getLoc(), reasonCallback: [=](Diagnostic &diag) {
564 diag << "vector type is not rank 1, can't create masked store, "
565 "needs VectorToSCF: "
566 << write;
567 });
568
569 rewriter.create<vector::MaskedStoreOp>(
570 location: write.getLoc(), args: write.getBase(), args: write.getIndices(), args: write.getMask(),
571 args: write.getVector());
572 } else {
573 rewriter.create<vector::StoreOp>(location: write.getLoc(), args: write.getVector(),
574 args: write.getBase(), args: write.getIndices());
575 }
576 // There's no return value for StoreOps. Use Value() to signal success to
577 // matchAndRewrite.
578 return Value();
579 }
580
581 std::optional<unsigned> maxTransferRank;
582};
583} // namespace
584
585void mlir::vector::populateVectorTransferLoweringPatterns(
586 RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
587 PatternBenefit benefit) {
588 patterns.add<TransferReadToVectorLoadLowering,
589 TransferWriteToVectorStoreLowering>(arg: patterns.getContext(),
590 args&: maxTransferRank, args&: benefit);
591}
592

source code of mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp