//===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewrite patterns for the permutation_map attribute of
// vector.transfer operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Interfaces/VectorInterfaces.h"

using namespace mlir;
using namespace mlir::vector;

/// Transpose a vector transfer op's `in_bounds` attribute by applying the
/// inverse of the given permutation.
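/// For example, with `permutation` = [2, 0, 1] and `attr` =
/// [true, false, true], the result is [false, true, true].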
static ArrayAttr
inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
                             const SmallVector<unsigned> &permutation) {
  SmallVector<bool> newInBoundsValues(permutation.size());
  size_t index = 0;
  for (unsigned pos : permutation)
    newInBoundsValues[pos] =
        cast<BoolAttr>(attr.getValue()[index++]).getValue();
  return builder.getBoolArrayAttr(newInBoundsValues);
}

/// Extend the rank of a vector Value by `addedRank` by adding outer unit
/// dimensions.
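/// For example, with `addedRank` = 2, a vector<2x3xf32> operand is broadcast
/// to vector<1x1x2x3xf32>.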
static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
                              int64_t addedRank) {
  auto originalVecType = cast<VectorType>(vec.getType());
  SmallVector<int64_t> newShape(addedRank, 1);
  newShape.append(originalVecType.getShape().begin(),
                  originalVecType.getShape().end());

  SmallVector<bool> newScalableDims(addedRank, false);
  newScalableDims.append(originalVecType.getScalableDims().begin(),
                         originalVecType.getScalableDims().end());
  VectorType newVecType = VectorType::get(
      newShape, originalVecType.getElementType(), newScalableDims);
  return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
}

/// Extend the rank of a vector Value by `addedRank` by adding inner unit
/// dimensions.
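/// For example, with `addedRank` = 1, a vector<2x3xi1> mask is first
/// broadcast to vector<1x2x3xi1> and then transposed with permutation
/// [1, 2, 0] to yield vector<2x3x1xi1>.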
static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
                            int64_t addedRank) {
  Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
  SmallVector<int64_t> permutation;
  for (int64_t i = addedRank,
               e = cast<VectorType>(broadcasted.getType()).getRank();
       i < e; ++i)
    permutation.push_back(i);
  for (int64_t i = 0; i < addedRank; ++i)
    permutation.push_back(i);
  return builder.create<vector::TransposeOp>(loc, broadcasted, permutation);
}

//===----------------------------------------------------------------------===//
// populateVectorTransferPermutationMapLoweringPatterns
//===----------------------------------------------------------------------===//

namespace {
/// Lower transfer_read op with permutation into a transfer_read with a
/// permutation map composed of leading zeros followed by a minor identity +
/// vector.transpose op.
/// Ex:
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (0, d1)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (d1, 0)
///     vector.transpose %v, [1, 0]
///
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
///     vector.transpose %v, [0, 1, 3, 2, 4]
/// Note that an alternative is to transform it to linalg.transpose +
/// vector.transfer_read to do the transpose in memory instead.
struct TransferReadPermutationLowering
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
    // TODO: Support transfer_read inside MaskOp case.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.getNumResults() == 0)
      return rewriter.notifyMatchFailure(op, "0 result permutation map");
    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }
    AffineMap permutationMap =
        map.getPermutationMap(permutation, op.getContext());
    if (permutationMap.isIdentity())
      return rewriter.notifyMatchFailure(op,
                                         "map is already identity, no-op");

    // Calculate the map of the new read by applying the inverse permutation.
    permutationMap = inversePermutation(permutationMap);
    AffineMap newMap = permutationMap.compose(map);
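    // E.g., for the map (d0, d1, d2) -> (0, d1) from the example above,
    // `permutation` is [1, 0], so the new read uses the inversely permuted
    // map (d0, d1, d2) -> (d1, 0) and the result is transposed back.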
    // Apply the reverse transpose to deduce the type of the transfer_read.
    ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
    SmallVector<int64_t> newVectorShape(originalShape.size());
    ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
    SmallVector<bool> newScalableDims(originalShape.size());
    for (const auto &pos : llvm::enumerate(permutation)) {
      newVectorShape[pos.value()] = originalShape[pos.index()];
      newScalableDims[pos.value()] = originalScalableDims[pos.index()];
    }

    // Transpose in_bounds attribute.
    ArrayAttr newInBoundsAttr =
        inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation);

    // Generate new transfer_read operation.
    VectorType newReadType = VectorType::get(
        newVectorShape, op.getVectorType().getElementType(), newScalableDims);
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getBase(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);

    // Transpose result of transfer_read.
    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
    return rewriter
        .create<vector::TransposeOp>(op.getLoc(), newRead, transposePerm)
        .getResult();
  }
};

/// Lower transfer_write op with permutation into a transfer_write with a
/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
/// Ex:
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
/// into:
///     %tmp = vector.transpose %v, [2, 0, 1]
///     vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
///
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
/// into:
///     %tmp = vector.transpose %v, [1, 0]
///     %v = vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
struct TransferWritePermutationLowering
    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
    // TODO: Support transfer_write inside MaskOp case.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.isMinorIdentity())
      return rewriter.notifyMatchFailure(op, "map is already minor identity");

    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }

    // Remove unused dims from the permutation map. E.g.:
    //   (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
    // becomes:
    //   comp = (d0, d1, d2) -> (d2, d0, d1)
    auto comp = compressUnusedDims(map);
    AffineMap permutationMap = inversePermutation(comp);
    // Get positions of remaining result dims.
    SmallVector<int64_t> indices;
    llvm::transform(permutationMap.getResults(), std::back_inserter(indices),
                    [](AffineExpr expr) {
                      return dyn_cast<AffineDimExpr>(expr).getPosition();
                    });

    // Transpose in_bounds attribute.
    ArrayAttr newInBoundsAttr =
        inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation);

    // Generate new transfer_write operation.
    Value newVec = rewriter.create<vector::TransposeOp>(
        op.getLoc(), op.getVector(), indices);
    auto newMap = AffineMap::getMinorIdentityMap(
        map.getNumDims(), map.getNumResults(), rewriter.getContext());
    auto newWrite = rewriter.create<vector::TransferWriteOp>(
        op.getLoc(), newVec, op.getBase(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr);
    if (newWrite.hasPureTensorSemantics())
      return newWrite.getResult();
    // In the memref case there's no return value. Use an empty Value to
    // signal success.
    return Value();
  }
};

/// Convert a transfer_write op with a map which isn't a permutation of a
/// minor identity into a vector.broadcast + transfer_write with a permutation
/// of a minor identity map, by adding unit dims for the missing inner
/// dimensions. Ex:
/// ```
/// vector.transfer_write %v
///   {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} :
///   vector<8x16xf32>
/// ```
/// into:
/// ```
/// %v1 = vector.broadcast %v : vector<8x16xf32> to vector<1x8x16xf32>
/// vector.transfer_write %v1
///   {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>} :
///   vector<1x8x16xf32>
/// ```
struct TransferWriteNonPermutationLowering
    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
    // TODO: Support transfer_write inside MaskOp case.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op,
          "map is already permutable to minor identity, apply another pattern");
    }

    // Missing outer dimensions are allowed: find the outermost existing
    // dimension, then deduce the missing inner dimensions.
    SmallVector<bool> foundDim(map.getNumDims(), false);
    for (AffineExpr exp : map.getResults())
      foundDim[cast<AffineDimExpr>(exp).getPosition()] = true;
    SmallVector<AffineExpr> exprs;
    bool foundFirstDim = false;
    SmallVector<int64_t> missingInnerDim;
    for (size_t i = 0; i < foundDim.size(); i++) {
      if (foundDim[i]) {
        foundFirstDim = true;
        continue;
      }
      if (!foundFirstDim)
        continue;
      // Once one dimension present in the map has been found, keep track of
      // all missing dimensions that follow it.
      missingInnerDim.push_back(i);
      exprs.push_back(rewriter.getAffineDimExpr(i));
    }
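    // E.g., for the map (d0, d1, d2, d3) -> (d1, d2) from the example above:
    // foundDim = [false, true, true, false], so d3 is the only missing inner
    // dimension; appending the original results below yields the new map
    // (d0, d1, d2, d3) -> (d3, d1, d2).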
    // Vector: add unit dims at the beginning of the shape.
    Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
                                    missingInnerDim.size());
    // Mask: add unit dims at the end of the shape.
    Value newMask;
    if (op.getMask())
      newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
                               missingInnerDim.size());
    exprs.append(map.getResults().begin(), map.getResults().end());
    AffineMap newMap =
        AffineMap::get(map.getNumDims(), 0, exprs, op.getContext());
    // All the newly added dimensions are in-bounds.
    SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
    for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i) {
      newInBoundsValues.push_back(op.isDimInBounds(i));
    }
    ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
    auto newWrite = rewriter.create<vector::TransferWriteOp>(
        op.getLoc(), newVec, op.getBase(), op.getIndices(),
        AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
    if (newWrite.hasPureTensorSemantics())
      return newWrite.getResult();
    // In the memref case there's no return value. Use an empty Value to
    // signal success.
    return Value();
  }
};

/// Lower transfer_read op with broadcast in the leading dimensions into
/// transfer_read of lower rank + vector.broadcast.
/// Ex: vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
///     vector.broadcast %v
struct TransferOpReduceRank
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
    // TODO: support masked case.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    AffineMap map = op.getPermutationMap();
    unsigned numLeadingBroadcast = 0;
    for (auto expr : map.getResults()) {
      auto dimExpr = dyn_cast<AffineConstantExpr>(expr);
      if (!dimExpr || dimExpr.getValue() != 0)
        break;
      numLeadingBroadcast++;
    }
    // If there are no leading zeros in the map there is nothing to do.
    if (numLeadingBroadcast == 0)
      return rewriter.notifyMatchFailure(op, "no leading broadcasts in map");

    VectorType originalVecType = op.getVectorType();
    unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
    // Calculate new map, vector type and masks without the leading zeros.
    AffineMap newMap = AffineMap::get(
        map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
        op.getContext());
    // Only remove the leading zeros if the rest of the map is a minor identity
    // with broadcasting. Otherwise we first want to permute the map.
    if (!newMap.isMinorIdentityWithBroadcasting()) {
      return rewriter.notifyMatchFailure(
          op, "map is not a minor identity with broadcasting");
    }

    SmallVector<int64_t> newShape(
        originalVecType.getShape().take_back(reducedShapeRank));
    SmallVector<bool> newScalableDims(
        originalVecType.getScalableDims().take_back(reducedShapeRank));

    VectorType newReadType = VectorType::get(
        newShape, originalVecType.getElementType(), newScalableDims);
    ArrayAttr newInBoundsAttr =
        op.getInBounds()
            ? rewriter.getArrayAttr(
                  op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
            : ArrayAttr();
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getBase(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);
    return rewriter
        .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
        .getVector();
  }
};

} // namespace

void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadPermutationLowering, TransferWritePermutationLowering,
           TransferOpReduceRank, TransferWriteNonPermutationLowering>(
          patterns.getContext(), benefit);
}
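
// A minimal usage sketch: these patterns are typically applied with the
// greedy pattern rewrite driver from inside a pass, e.g.:
//   RewritePatternSet patterns(op->getContext());
//   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));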

//===----------------------------------------------------------------------===//
// populateVectorTransferLoweringPatterns
//===----------------------------------------------------------------------===//

namespace {
/// Progressive lowering of transfer_read. This pattern supports lowering of
/// `vector.transfer_read` to a combination of `vector.load` and
/// `vector.broadcast` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   result type.
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
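/// For example (an illustrative case satisfying the conditions above), a read
///   %0 = vector.transfer_read %mem[%i, %j], %pad {in_bounds = [true]}
///       : memref<?x?xf32>, vector<4xf32>
/// lowers to
///   %0 = vector.load %mem[%i, %j] : memref<?x?xf32>, vector<4xf32>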
struct TransferReadToVectorLoadLowering
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  TransferReadToVectorLoadLowering(MLIRContext *context,
                                   std::optional<unsigned> maxRank,
                                   PatternBenefit benefit = 1)
      : MaskableOpRewritePattern<vector::TransferReadOp>(context, benefit),
        maxTransferRank(maxRank) {}

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp read,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          read, "vector type is greater than max transfer rank");
    }

    if (maskOp)
      return rewriter.notifyMatchFailure(read, "Masked case not supported");
    SmallVector<unsigned> broadcastedDims;
    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    // We let the 0-d corner case pass-through as it is supported.
    if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
            &broadcastedDims))
      return rewriter.notifyMatchFailure(read, "not minor identity + bcast");

    auto memRefType = dyn_cast<MemRefType>(read.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(read, "not a memref source");

    // Non-unit strides are handled by VectorToSCF.
    if (!memRefType.isLastDimUnitStride())
      return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF");

    // If there is broadcasting involved then we first load the unbroadcasted
    // vector, and then broadcast it with `vector.broadcast`.
    ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
    SmallVector<int64_t> unbroadcastedVectorShape(vectorShape);
    for (unsigned i : broadcastedDims)
      unbroadcastedVectorShape[i] = 1;
    VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
        unbroadcastedVectorShape, read.getVectorType().getElementType());

    // `vector.load` supports vector types as memref's elements only when the
    // resulting vector type is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
      return rewriter.notifyMatchFailure(read, "incompatible element type");

    // Otherwise, element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != read.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(read, "non-matching element type");

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (read.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask");

    // Create vector load op.
    Operation *res;
    if (read.getMask()) {
      if (read.getVectorType().getRank() != 1)
        // vector.maskedload operates on 1-D vectors.
        return rewriter.notifyMatchFailure(
            read, "vector type is not rank 1, can't create masked load, needs "
                  "VectorToSCF");

      Value fill = rewriter.create<vector::SplatOp>(
          read.getLoc(), unbroadcastedVectorType, read.getPadding());
      res = rewriter.create<vector::MaskedLoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getBase(),
          read.getIndices(), read.getMask(), fill);
    } else {
      res = rewriter.create<vector::LoadOp>(read.getLoc(),
                                            unbroadcastedVectorType,
                                            read.getBase(), read.getIndices());
    }

    // Insert a broadcasting op if required.
    if (!broadcastedDims.empty())
      res = rewriter.create<vector::BroadcastOp>(
          read.getLoc(), read.getVectorType(), res->getResult(0));
    return res->getResult(0);
  }

  std::optional<unsigned> maxTransferRank;
};

/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   type of the written value.
/// - The permutation map is the minor identity map (neither permutation nor
///   broadcasting is allowed).
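/// For example (an illustrative case satisfying the conditions above), a write
///   vector.transfer_write %v, %mem[%i, %j] {in_bounds = [true]}
///       : vector<4xf32>, memref<?x?xf32>
/// lowers to
///   vector.store %v, %mem[%i, %j] : memref<?x?xf32>, vector<4xf32>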
struct TransferWriteToVectorStoreLowering
    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
  TransferWriteToVectorStoreLowering(MLIRContext *context,
                                     std::optional<unsigned> maxRank,
                                     PatternBenefit benefit = 1)
      : MaskableOpRewritePattern<vector::TransferWriteOp>(context, benefit),
        maxTransferRank(maxRank) {}

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp write,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          write, "vector type is greater than max transfer rank");
    }
    if (maskOp)
      return rewriter.notifyMatchFailure(write, "Masked case not supported");

    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    if ( // pass-through for the 0-d corner case.
        !write.getPermutationMap().isMinorIdentity())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "permutation map is not minor identity: " << write;
      });

    auto memRefType = dyn_cast<MemRefType>(write.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "not a memref type: " << write;
      });

    // Non-unit strides are handled by VectorToSCF.
    if (!memRefType.isLastDimUnitStride())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "most minor stride is not 1: " << write;
      });

    // `vector.store` supports vector types as memref's elements only when the
    // type of the vector value being written is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });

    // Otherwise, element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != write.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (write.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "out of bounds dim: " << write;
      });
    if (write.getMask()) {
      if (write.getVectorType().getRank() != 1)
        // vector.maskedstore operates on 1-D vectors.
        return rewriter.notifyMatchFailure(
            write.getLoc(), [=](Diagnostic &diag) {
              diag << "vector type is not rank 1, can't create masked store, "
                      "needs VectorToSCF: "
                   << write;
            });

      rewriter.create<vector::MaskedStoreOp>(
          write.getLoc(), write.getBase(), write.getIndices(), write.getMask(),
          write.getVector());
    } else {
      rewriter.create<vector::StoreOp>(write.getLoc(), write.getVector(),
                                       write.getBase(), write.getIndices());
    }
    // There's no return value for StoreOps. Use Value() to signal success to
    // matchAndRewrite.
    return Value();
  }

  std::optional<unsigned> maxTransferRank;
};
} // namespace

void mlir::vector::populateVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
    PatternBenefit benefit) {
  patterns.add<TransferReadToVectorLoadLowering,
               TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                   maxTransferRank, benefit);
}