1//===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering of vector transfer operations to SCF.
10//
11//===----------------------------------------------------------------------===//
12
13#include <numeric>
14#include <optional>
15#include <type_traits>
16
17#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
18
19#include "mlir/Dialect/Affine/IR/AffineOps.h"
20#include "mlir/Dialect/Arith/IR/Arith.h"
21#include "mlir/Dialect/MemRef/IR/MemRef.h"
22#include "mlir/Dialect/SCF/IR/SCF.h"
23#include "mlir/Dialect/Tensor/IR/Tensor.h"
24#include "mlir/Dialect/Vector/IR/VectorOps.h"
25#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
26#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
27#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
28#include "mlir/IR/Builders.h"
29#include "mlir/IR/ImplicitLocOpBuilder.h"
30#include "mlir/Pass/Pass.h"
31#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32#include "mlir/Transforms/Passes.h"
33
34namespace mlir {
35#define GEN_PASS_DEF_CONVERTVECTORTOSCF
36#include "mlir/Conversion/Passes.h.inc"
37} // namespace mlir
38
39using namespace mlir;
40using vector::TransferReadOp;
41using vector::TransferWriteOp;
42
43namespace {
44
45/// Attribute name used for labeling transfer ops during progressive lowering.
46static const char kPassLabel[] = "__vector_to_scf_lowering__";
47
48/// Return true if this transfer op operates on a source tensor.
49static bool isTensorOp(VectorTransferOpInterface xferOp) {
50 if (isa<RankedTensorType>(xferOp.getShapedType())) {
51 if (isa<vector::TransferWriteOp>(xferOp)) {
52 // TransferWriteOps on tensors have a result.
53 assert(xferOp->getNumResults() > 0);
54 }
55 return true;
56 }
57 return false;
58}
59
60/// Patterns that inherit from this struct have access to
61/// VectorTransferToSCFOptions.
62template <typename OpTy>
63struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
64 explicit VectorToSCFPattern(MLIRContext *context,
65 VectorTransferToSCFOptions opt)
66 : OpRewritePattern<OpTy>(context), options(opt) {}
67
68 LogicalResult checkLowerTensors(VectorTransferOpInterface xferOp,
69 PatternRewriter &rewriter) const {
70 if (isTensorOp(xferOp) && !options.lowerTensors) {
71 return rewriter.notifyMatchFailure(
72 xferOp, "lowering tensor transfers is disabled");
73 }
74 return success();
75 }
76
77 VectorTransferToSCFOptions options;
78};
79
80/// Given a vector transfer op, calculate which dimension of the `source`
81/// memref should be unpacked in the next application of TransferOpConversion.
82/// A return value of std::nullopt indicates a broadcast.
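/// For illustration (maps invented for this example): with permutation map
/// (d0, d1, d2) -> (d2, d1), the first result is d2, so dimension 2 of the
/// source is unpacked next; with (d0, d1) -> (0, d1), the first result is the
/// constant 0 (a broadcast dim), so std::nullopt is returned.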
83template <typename OpTy>
84static std::optional<int64_t> unpackedDim(OpTy xferOp) {
85 // TODO: support 0-d corner case.
86 assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
87 auto map = xferOp.getPermutationMap();
88 if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
89 return expr.getPosition();
90 }
91 assert(xferOp.isBroadcastDim(0) &&
92 "Expected AffineDimExpr or AffineConstantExpr");
93 return std::nullopt;
94}
95
96/// Compute the permutation map for the new (N-1)-D vector transfer op. This
97/// map is identical to the current permutation map, but the first result is
98/// omitted.
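/// For illustration (an invented example): a permutation map
/// (d0, d1, d2) -> (d2, d1) becomes (d0, d1, d2) -> (d1) for the new
/// (N-1)-D transfer op.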
99template <typename OpTy>
100static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
101 // TODO: support 0-d corner case.
102 assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
103 auto map = xferOp.getPermutationMap();
104 return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
105 b.getContext());
106}
107
108/// Calculate the indices for the new vector transfer op.
109///
/// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
///                                 ^^^^^^^
113/// `iv` is the iteration variable of the (new) surrounding loop.
114template <typename OpTy>
115static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
116 SmallVector<Value, 8> &indices) {
117 typename OpTy::Adaptor adaptor(xferOp);
118 // Corresponding memref dim of the vector dim that is unpacked.
119 auto dim = unpackedDim(xferOp);
120 auto prevIndices = adaptor.getIndices();
121 indices.append(prevIndices.begin(), prevIndices.end());
122
123 Location loc = xferOp.getLoc();
124 bool isBroadcast = !dim.has_value();
125 if (!isBroadcast) {
126 AffineExpr d0, d1;
127 bindDims(xferOp.getContext(), d0, d1);
128 Value offset = adaptor.getIndices()[*dim];
129 indices[*dim] =
130 affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
131 }
132}
133
134static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
135 Value value) {
136 if (hasRetVal) {
137 assert(value && "Expected non-empty value");
138 b.create<scf::YieldOp>(loc, value);
139 } else {
140 b.create<scf::YieldOp>(loc);
141 }
142}
143
/// Generates a boolean Value that is true if the iv-th element of xferOp's
/// mask is set to true. No such check is generated under the following
/// circumstances:
146/// * xferOp does not have a mask.
147/// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
148/// computed and attached to the new transfer op in the pattern.)
149/// * The to-be-unpacked dim of xferOp is a broadcast.
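/// When a check is generated, it looks roughly like this (illustrative mask
/// type and value names):
/// ```
/// %mask_bit = vector.extractelement %mask[%iv : index] : vector<5xi1>
/// ```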
150template <typename OpTy>
151static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
152 if (!xferOp.getMask())
153 return Value();
154 if (xferOp.getMaskType().getRank() != 1)
155 return Value();
156 if (xferOp.isBroadcastDim(0))
157 return Value();
158
159 Location loc = xferOp.getLoc();
160 return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
161}
162
/// Helper function for TransferOpConversion and TransferOp1dConversion.
164/// Generate an in-bounds check if the transfer op may go out-of-bounds on the
165/// specified dimension `dim` with the loop iteration variable `iv`.
166/// E.g., when unpacking dimension 0 from:
167/// ```
168/// %vec = vector.transfer_read %A[%a, %b] %cst
169/// : vector<5x4xf32>, memref<?x?xf32>
170/// ```
171/// An if check similar to this will be generated inside the loop:
172/// ```
173/// %d = memref.dim %A, %c0 : memref<?x?xf32>
174/// if (%a + iv < %d) {
175/// (in-bounds case)
176/// } else {
177/// (out-of-bounds case)
178/// }
179/// ```
180///
/// If the transfer is 1D and has a mask, this function generates a more
/// complex check that also accounts for potentially masked-out elements.
183///
184/// This function variant returns the value returned by `inBoundsCase` or
185/// `outOfBoundsCase`. The MLIR type of the return value must be specified in
186/// `resultTypes`.
187template <typename OpTy>
188static Value generateInBoundsCheck(
189 OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
190 TypeRange resultTypes,
191 function_ref<Value(OpBuilder &, Location)> inBoundsCase,
192 function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
193 bool hasRetVal = !resultTypes.empty();
194 Value cond; // Condition to be built...
195
196 // Condition check 1: Access in-bounds?
197 bool isBroadcast = !dim; // No in-bounds check for broadcasts.
198 Location loc = xferOp.getLoc();
199 ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
200 if (!xferOp.isDimInBounds(0) && !isBroadcast) {
    Value memrefDim =
        vector::createOrFoldDimOp(b, loc, xferOp.getBase(), *dim);
202 AffineExpr d0, d1;
203 bindDims(xferOp.getContext(), d0, d1);
204 Value base = xferOp.getIndices()[*dim];
205 Value memrefIdx =
206 affine::makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
207 cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
208 memrefIdx);
209 }
210
211 // Condition check 2: Masked in?
212 if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
213 if (cond)
214 cond = lb.create<arith::AndIOp>(cond, maskCond);
215 else
216 cond = maskCond;
217 }
218
219 // If the condition is non-empty, generate an SCF::IfOp.
220 if (cond) {
221 auto check = lb.create<scf::IfOp>(
222 cond,
223 /*thenBuilder=*/
224 [&](OpBuilder &b, Location loc) {
225 maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
226 },
227 /*elseBuilder=*/
228 [&](OpBuilder &b, Location loc) {
229 if (outOfBoundsCase) {
230 maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
231 } else {
232 b.create<scf::YieldOp>(loc);
233 }
234 });
235
236 return hasRetVal ? check.getResult(0) : Value();
237 }
238
239 // Condition is empty, no need for an SCF::IfOp.
240 return inBoundsCase(b, loc);
241}
242
243/// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
244/// a return value. Consequently, this function does not have a return value.
245template <typename OpTy>
246static void generateInBoundsCheck(
247 OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
248 function_ref<void(OpBuilder &, Location)> inBoundsCase,
249 function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
250 generateInBoundsCheck(
251 b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
252 /*inBoundsCase=*/
253 [&](OpBuilder &b, Location loc) {
254 inBoundsCase(b, loc);
255 return Value();
256 },
257 /*outOfBoundsCase=*/
258 [&](OpBuilder &b, Location loc) {
259 if (outOfBoundsCase)
260 outOfBoundsCase(b, loc);
261 return Value();
262 });
263}
264
265/// Given an ArrayAttr, return a copy where the first element is dropped.
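/// E.g., roughly: an in_bounds attribute [true, false, true] becomes
/// [false, true] (the entry of the unpacked dim is dropped).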
266static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
267 if (!attr)
268 return attr;
269 return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
270}
271
/// Add the pass label to a vector transfer op if its vector rank is greater
/// than the target rank.
274template <typename OpTy>
275static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
276 unsigned targetRank) {
277 if (newXferOp.getVectorType().getRank() > targetRank)
278 newXferOp->setAttr(kPassLabel, b.getUnitAttr());
279}
280
281namespace lowering_n_d {
282
283/// Helper data structure for data and mask buffers.
284struct BufferAllocs {
285 Value dataBuffer;
286 Value maskBuffer;
287};
288
289// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
290static Operation *getAutomaticAllocationScope(Operation *op) {
291 Operation *scope =
292 op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
293 assert(scope && "Expected op to be inside automatic allocation scope");
294 return scope;
295}
296
297/// Allocate temporary buffers for data (vector) and mask (if present).
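/// E.g., for a transfer op with vector type vector<5x4xf32> and a mask of
/// type vector<5xi1>, roughly the following allocations are created at the
/// top of the enclosing allocation scope (illustrative types and names):
/// ```
/// %data_buf = memref.alloca() : memref<vector<5x4xf32>>
/// %mask_buf = memref.alloca() : memref<vector<5xi1>>
/// ```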
298template <typename OpTy>
299static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
300 Location loc = xferOp.getLoc();
301 OpBuilder::InsertionGuard guard(b);
302 Operation *scope = getAutomaticAllocationScope(xferOp);
303 assert(scope->getNumRegions() == 1 &&
304 "AutomaticAllocationScope with >1 regions");
  b.setInsertionPointToStart(&scope->getRegion(0).front());
306
307 BufferAllocs result;
308 auto bufferType = MemRefType::get({}, xferOp.getVectorType());
309 result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);
310
311 if (xferOp.getMask()) {
312 auto maskType = MemRefType::get({}, xferOp.getMask().getType());
313 auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
314 b.setInsertionPoint(xferOp);
315 b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer);
316 result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer, ValueRange());
317 }
318
319 return result;
320}
321
322/// Given a MemRefType with VectorType element type, unpack one dimension from
323/// the VectorType into the MemRefType.
324///
325/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
326static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
327 auto vectorType = dyn_cast<VectorType>(type.getElementType());
  // Vectors with leading scalable dims are not supported. It may be possible
  // to support these in the future by using dynamic memref dims.
330 if (vectorType.getScalableDims().front())
331 return failure();
332 auto memrefShape = type.getShape();
333 SmallVector<int64_t, 8> newMemrefShape;
334 newMemrefShape.append(memrefShape.begin(), memrefShape.end());
335 newMemrefShape.push_back(vectorType.getDimSize(0));
336 return MemRefType::get(newMemrefShape,
337 VectorType::Builder(vectorType).dropDim(0));
338}
339
340/// Given a transfer op, find the memref from which the mask is loaded. This
341/// is similar to Strategy<TransferWriteOp>::getBuffer.
342template <typename OpTy>
343static Value getMaskBuffer(OpTy xferOp) {
344 assert(xferOp.getMask() && "Expected that transfer op has mask");
345 auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
346 assert(loadOp && "Expected transfer op mask produced by LoadOp");
347 return loadOp.getMemRef();
348}
349
350/// Codegen strategy, depending on the operation.
351template <typename OpTy>
352struct Strategy;
353
/// Codegen strategy for vector TransferReadOp.
355template <>
356struct Strategy<TransferReadOp> {
357 /// Find the StoreOp that is used for writing the current TransferReadOp's
358 /// result to the temporary buffer allocation.
359 static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
360 assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
361 auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
362 assert(storeOp && "Expected TransferReadOp result used by StoreOp");
363 return storeOp;
364 }
365
366 /// Find the temporary buffer allocation. All labeled TransferReadOps are
367 /// used like this, where %buf is either the buffer allocation or a type cast
368 /// of the buffer allocation:
369 /// ```
370 /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
371 /// memref.store %vec, %buf[...] ...
372 /// ```
373 static Value getBuffer(TransferReadOp xferOp) {
374 return getStoreOp(xferOp).getMemRef();
375 }
376
377 /// Retrieve the indices of the current StoreOp that stores into the buffer.
378 static void getBufferIndices(TransferReadOp xferOp,
379 SmallVector<Value, 8> &indices) {
380 auto storeOp = getStoreOp(xferOp);
381 auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
382 indices.append(prevIndices.begin(), prevIndices.end());
383 }
384
385 /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
386 /// accesses on the to-be-unpacked dimension.
387 ///
388 /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
389 /// variable `iv`.
390 /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
391 ///
392 /// E.g.:
393 /// ```
394 /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
395 /// : memref<?x?x?xf32>, vector<4x3xf32>
396 /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
397 /// ```
398 /// Is rewritten to:
399 /// ```
400 /// %casted = vector.type_cast %buf
401 /// : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
402 /// for %j = 0 to 4 {
403 /// %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
404 /// : memref<?x?x?xf32>, vector<3xf32>
405 /// memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
406 /// }
407 /// ```
408 ///
409 /// Note: The loop and type cast are generated in TransferOpConversion.
410 /// The original TransferReadOp and store op are deleted in `cleanup`.
411 /// Note: The `mask` operand is set in TransferOpConversion.
412 static TransferReadOp rewriteOp(OpBuilder &b,
413 VectorTransferToSCFOptions options,
414 TransferReadOp xferOp, Value buffer, Value iv,
415 ValueRange /*loopState*/) {
416 SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
418 storeIndices.push_back(iv);
419
420 SmallVector<Value, 8> xferIndices;
421 getXferIndices(b, xferOp, iv, xferIndices);
422
423 Location loc = xferOp.getLoc();
424 auto bufferType = dyn_cast<ShapedType>(buffer.getType());
425 auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
426 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
427 auto newXferOp = b.create<vector::TransferReadOp>(
428 loc, vecType, xferOp.getBase(), xferIndices,
429 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
430 xferOp.getPadding(), Value(), inBoundsAttr);
431
432 maybeApplyPassLabel(b, newXferOp, options.targetRank);
433
434 b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices);
435 return newXferOp;
436 }
437
438 /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
439 /// padding value to the temporary buffer.
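  /// Roughly the following is generated (illustrative values):
  /// ```
  /// %pad = vector.splat %cst : vector<4x3xf32>
  /// memref.store %pad, %buf[%i] : memref<5xvector<4x3xf32>>
  /// ```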
440 static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
441 Value buffer, Value iv,
442 ValueRange /*loopState*/) {
443 SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
445 storeIndices.push_back(iv);
446
447 Location loc = xferOp.getLoc();
448 auto bufferType = dyn_cast<ShapedType>(buffer.getType());
449 auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
450 auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
451 b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
452
453 return Value();
454 }
455
456 /// Cleanup after rewriting the op.
457 static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
458 scf::ForOp /*forOp*/) {
    rewriter.eraseOp(getStoreOp(xferOp));
    rewriter.eraseOp(xferOp);
461 }
462
463 /// Return the initial loop state for the generated scf.for loop.
464 static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
465};
466
467/// Codegen strategy for vector TransferWriteOp.
468template <>
469struct Strategy<TransferWriteOp> {
470 /// Find the temporary buffer allocation. All labeled TransferWriteOps are
471 /// used like this, where %buf is either the buffer allocation or a type cast
472 /// of the buffer allocation:
473 /// ```
474 /// %vec = memref.load %buf[...] ...
475 /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
476 /// ```
477 static Value getBuffer(TransferWriteOp xferOp) {
478 auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
479 assert(loadOp && "Expected transfer op vector produced by LoadOp");
480 return loadOp.getMemRef();
481 }
482
483 /// Retrieve the indices of the current LoadOp that loads from the buffer.
484 static void getBufferIndices(TransferWriteOp xferOp,
485 SmallVector<Value, 8> &indices) {
486 auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
487 auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
488 indices.append(prevIndices.begin(), prevIndices.end());
489 }
490
491 /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
492 /// accesses on the to-be-unpacked dimension.
493 ///
494 /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
495 /// using the loop iteration variable `iv`.
496 /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
497 /// to memory.
498 ///
499 /// Note: For more details, see comments on Strategy<TransferReadOp>.
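  /// E.g., roughly (mirroring the TransferReadOp example above), each loop
  /// iteration contains:
  /// ```
  /// %vec = memref.load %casted[%i, %j] : memref<5x4xvector<3xf32>>
  /// vector.transfer_write %vec, %A[%a+%i, %b+%j, %c]
  ///     : vector<3xf32>, memref<?x?x?xf32>
  /// ```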
500 static TransferWriteOp rewriteOp(OpBuilder &b,
501 VectorTransferToSCFOptions options,
502 TransferWriteOp xferOp, Value buffer,
503 Value iv, ValueRange loopState) {
504 SmallVector<Value, 8> loadIndices;
    getBufferIndices(xferOp, loadIndices);
506 loadIndices.push_back(iv);
507
508 SmallVector<Value, 8> xferIndices;
509 getXferIndices(b, xferOp, iv, xferIndices);
510
511 Location loc = xferOp.getLoc();
512 auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
513 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
514 auto source = loopState.empty() ? xferOp.getBase() : loopState[0];
515 Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
516 auto newXferOp = b.create<vector::TransferWriteOp>(
517 loc, type, vec, source, xferIndices,
518 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
519 inBoundsAttr);
520
521 maybeApplyPassLabel(b, newXferOp, options.targetRank);
522
523 return newXferOp;
524 }
525
526 /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
527 static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
528 Value buffer, Value iv,
529 ValueRange loopState) {
530 return isTensorOp(xferOp) ? loopState[0] : Value();
531 }
532
533 /// Cleanup after rewriting the op.
534 static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
535 scf::ForOp forOp) {
536 if (isTensorOp(xferOp)) {
537 assert(forOp->getNumResults() == 1 && "Expected one for loop result");
538 rewriter.replaceOp(xferOp, forOp->getResult(0));
539 } else {
      rewriter.eraseOp(xferOp);
541 }
542 }
543
544 /// Return the initial loop state for the generated scf.for loop.
545 static Value initialLoopState(TransferWriteOp xferOp) {
546 return isTensorOp(xferOp) ? xferOp.getBase() : Value();
547 }
548};
549
550template <typename OpTy>
551static LogicalResult checkPrepareXferOp(OpTy xferOp, PatternRewriter &rewriter,
552 VectorTransferToSCFOptions options) {
553 if (xferOp->hasAttr(kPassLabel))
554 return rewriter.notifyMatchFailure(
555 xferOp, "kPassLabel is present (vector-to-scf lowering in progress)");
556 if (xferOp.getVectorType().getRank() <= options.targetRank)
557 return rewriter.notifyMatchFailure(
558 xferOp, "xferOp vector rank <= transformation target rank");
559 if (xferOp.getVectorType().getScalableDims().front())
560 return rewriter.notifyMatchFailure(
561 xferOp, "Unpacking of the leading dimension into the memref is not yet "
562 "supported for scalable dims");
563 if (isTensorOp(xferOp) && !options.lowerTensors)
564 return rewriter.notifyMatchFailure(
565 xferOp, "Unpacking for tensors has been disabled.");
566 if (xferOp.getVectorType().getElementType() !=
567 xferOp.getShapedType().getElementType())
568 return rewriter.notifyMatchFailure(
569 xferOp, "Mismatching source and destination element types.");
570
571 return success();
572}
573
574/// Prepare a TransferReadOp for progressive lowering.
575///
576/// 1. Allocate a temporary buffer.
577/// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
578/// 3. Store the result of the TransferReadOp into the temporary buffer.
579/// 4. Load the result from the temporary buffer and replace all uses of the
580/// original TransferReadOp with this load.
581///
582/// E.g.:
583/// ```
584/// %vec = vector.transfer_read %A[%a, %b, %c], %cst
585/// : vector<5x4xf32>, memref<?x?x?xf32>
586/// ```
587/// is rewritten to:
588/// ```
589/// %0 = memref.alloca() : memref<vector<5x4xf32>>
590/// %1 = vector.transfer_read %A[%a, %b, %c], %cst
591/// { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32>
592/// memref.store %1, %0[] : memref<vector<5x4xf32>>
593/// %vec = memref.load %0[] : memref<vector<5x4xf32>>
594/// ```
595///
596/// Note: A second temporary buffer may be allocated for the `mask` operand.
597struct PrepareTransferReadConversion
598 : public VectorToSCFPattern<TransferReadOp> {
599 using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
600
601 LogicalResult matchAndRewrite(TransferReadOp xferOp,
602 PatternRewriter &rewriter) const override {
603 if (checkPrepareXferOp(xferOp, rewriter, options).failed())
604 return rewriter.notifyMatchFailure(
605 xferOp, "checkPrepareXferOp conditions not met!");
606
607 auto buffers = allocBuffers(rewriter, xferOp);
608 auto *newXfer = rewriter.clone(*xferOp.getOperation());
609 newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
610 if (xferOp.getMask()) {
611 dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
612 buffers.maskBuffer);
613 }
614
615 Location loc = xferOp.getLoc();
616 rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
617 buffers.dataBuffer);
618 rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
619
620 return success();
621 }
622};
623
624/// Prepare a TransferWriteOp for progressive lowering.
625///
626/// 1. Allocate a temporary buffer.
627/// 2. Store the vector into the buffer.
628/// 3. Load the vector from the buffer again.
629/// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
630/// marking it eligible for progressive lowering via TransferOpConversion.
631///
632/// E.g.:
633/// ```
634/// vector.transfer_write %vec, %A[%a, %b, %c]
635/// : vector<5x4xf32>, memref<?x?x?xf32>
636/// ```
637/// is rewritten to:
638/// ```
639/// %0 = memref.alloca() : memref<vector<5x4xf32>>
640/// memref.store %vec, %0[] : memref<vector<5x4xf32>>
641/// %1 = memref.load %0[] : memref<vector<5x4xf32>>
642/// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
643/// : vector<5x4xf32>, memref<?x?x?xf32>
644/// ```
645///
646/// Note: A second temporary buffer may be allocated for the `mask` operand.
647struct PrepareTransferWriteConversion
648 : public VectorToSCFPattern<TransferWriteOp> {
649 using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
650
651 LogicalResult matchAndRewrite(TransferWriteOp xferOp,
652 PatternRewriter &rewriter) const override {
653 if (checkPrepareXferOp(xferOp, rewriter, options).failed())
654 return rewriter.notifyMatchFailure(
655 xferOp, "checkPrepareXferOp conditions not met!");
656
657 Location loc = xferOp.getLoc();
658 auto buffers = allocBuffers(rewriter, xferOp);
659 rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
660 buffers.dataBuffer);
661 auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
662 rewriter.modifyOpInPlace(xferOp, [&]() {
663 xferOp.getValueToStoreMutable().assign(loadedVec);
664 xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
665 });
666
667 if (xferOp.getMask()) {
668 rewriter.modifyOpInPlace(xferOp, [&]() {
669 xferOp.getMaskMutable().assign(buffers.maskBuffer);
670 });
671 }
672
673 return success();
674 }
675};
676
/// Decompose an n-D PrintOp into a loop of elementary/scalar prints. This
/// allows printing both 1-D scalable vectors and n-D fixed-size vectors.
679///
680/// E.g.:
681/// ```
682/// vector.print %v : vector<[4]xi32>
683/// ```
684/// is rewritten to:
685/// ```
686/// %c0 = arith.constant 0 : index
687/// %c4 = arith.constant 4 : index
688/// %c1 = arith.constant 1 : index
689/// %vscale = vector.vscale
690/// %length = arith.muli %vscale, %c4 : index
691/// %lastIndex = arith.subi %length, %c1 : index
692/// vector.print punctuation <open>
693/// scf.for %i = %c0 to %length step %c1 {
694/// %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
695/// vector.print %el : i32 punctuation <no_punctuation>
696/// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
697/// scf.if %notLastIndex {
698/// vector.print punctuation <comma>
699/// }
700/// }
701/// vector.print punctuation <close>
702/// vector.print
703/// ```
704struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
705 using VectorToSCFPattern<vector::PrintOp>::VectorToSCFPattern;
706 LogicalResult matchAndRewrite(vector::PrintOp printOp,
707 PatternRewriter &rewriter) const override {
708 if (!printOp.getSource())
709 return failure();
710
711 VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
712 if (!vectorType)
713 return failure();
714
715 // Currently >= 2D scalable vectors are not supported.
716 // These can't be lowered to LLVM (as LLVM does not support scalable vectors
717 // of scalable vectors), and due to limitations of current ops can't be
718 // indexed with SSA values or flattened. This may change after
719 // https://reviews.llvm.org/D155034, though there still needs to be a path
720 // for lowering to LLVM.
721 if (vectorType.getRank() > 1 && vectorType.isScalable())
722 return failure();
723
724 auto loc = printOp.getLoc();
725 auto value = printOp.getSource();
726
727 if (auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) {
728 // Oddly sized integers are (somewhat) buggy on a lot of backends, so to
729 // avoid issues extend them to a more standard size.
730 // https://github.com/llvm/llvm-project/issues/30613
731 auto width = intTy.getWidth();
      auto legalWidth = llvm::NextPowerOf2(std::max(8u, width) - 1);
733 auto legalIntTy = IntegerType::get(rewriter.getContext(), legalWidth,
734 intTy.getSignedness());
735 // arith can only take signless integers, so we must cast back and forth.
      auto signlessSourceVectorType =
          vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy));
      auto signlessTargetVectorType =
          vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
740 auto targetVectorType = vectorType.cloneWith({}, legalIntTy);
741 value = rewriter.create<vector::BitCastOp>(loc, signlessSourceVectorType,
742 value);
743 if (value.getType() != signlessTargetVectorType) {
744 if (width == 1 || intTy.isUnsigned())
745 value = rewriter.create<arith::ExtUIOp>(loc, signlessTargetVectorType,
746 value);
747 else
748 value = rewriter.create<arith::ExtSIOp>(loc, signlessTargetVectorType,
749 value);
750 }
751 value = rewriter.create<vector::BitCastOp>(loc, targetVectorType, value);
752 vectorType = targetVectorType;
753 }
754
755 auto scalableDimensions = vectorType.getScalableDims();
756 auto shape = vectorType.getShape();
757 constexpr int64_t singletonShape[] = {1};
758 if (vectorType.getRank() == 0)
759 shape = singletonShape;
760
761 if (vectorType.getRank() != 1) {
762 // Flatten n-D vectors to 1D. This is done to allow indexing with a
763 // non-constant value (which can currently only be done via
764 // vector.extractelement for 1D vectors).
765 auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
766 std::multiplies<int64_t>());
767 auto flatVectorType =
768 VectorType::get({flatLength}, vectorType.getElementType());
769 value = rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, value);
770 }
771
772 vector::PrintOp firstClose;
773 SmallVector<Value, 8> loopIndices;
774 for (unsigned d = 0; d < shape.size(); d++) {
775 // Setup loop bounds and step.
776 Value lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
777 Value upperBound = rewriter.create<arith::ConstantIndexOp>(loc, shape[d]);
778 Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
779 if (!scalableDimensions.empty() && scalableDimensions[d]) {
780 auto vscale = rewriter.create<vector::VectorScaleOp>(
781 loc, rewriter.getIndexType());
782 upperBound = rewriter.create<arith::MulIOp>(loc, upperBound, vscale);
783 }
784 auto lastIndex = rewriter.create<arith::SubIOp>(loc, upperBound, step);
785
786 // Create a loop to print the elements surrounded by parentheses.
787 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
788 auto loop =
789 rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
790 auto printClose = rewriter.create<vector::PrintOp>(
791 loc, vector::PrintPunctuation::Close);
792 if (!firstClose)
793 firstClose = printClose;
794
795 auto loopIdx = loop.getInductionVar();
796 loopIndices.push_back(loopIdx);
797
798 // Print a comma after all but the last element.
799 rewriter.setInsertionPointToStart(loop.getBody());
800 auto notLastIndex = rewriter.create<arith::CmpIOp>(
801 loc, arith::CmpIPredicate::ult, loopIdx, lastIndex);
802 rewriter.create<scf::IfOp>(loc, notLastIndex,
803 [&](OpBuilder &builder, Location loc) {
804 builder.create<vector::PrintOp>(
805 loc, vector::PrintPunctuation::Comma);
806 builder.create<scf::YieldOp>(loc);
807 });
808
809 rewriter.setInsertionPointToStart(loop.getBody());
810 }
811
812 // Compute the flattened index.
    // Note: For vectors of rank > 1, this assumes all dims are non-scalable.
814 Value flatIndex;
815 auto currentStride = 1;
816 for (int d = shape.size() - 1; d >= 0; d--) {
817 auto stride = rewriter.create<arith::ConstantIndexOp>(loc, currentStride);
818 auto index = rewriter.create<arith::MulIOp>(loc, stride, loopIndices[d]);
819 if (flatIndex)
820 flatIndex = rewriter.create<arith::AddIOp>(loc, flatIndex, index);
821 else
822 flatIndex = index;
823 currentStride *= shape[d];
824 }
825
    // Print the scalar elements in the innermost loop.
827 auto element =
828 rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
829 rewriter.create<vector::PrintOp>(loc, element,
830 vector::PrintPunctuation::NoPunctuation);
831
832 rewriter.setInsertionPointAfter(firstClose);
833 rewriter.create<vector::PrintOp>(loc, printOp.getPunctuation());
    rewriter.eraseOp(printOp);
835 return success();
836 }
837
838 static IntegerType getIntTypeWithSignlessSemantics(IntegerType intTy) {
839 return IntegerType::get(intTy.getContext(), intTy.getWidth(),
840 IntegerType::Signless);
841 };
842};
843
844/// Progressive lowering of vector transfer ops: Unpack one dimension.
845///
846/// 1. Unpack one dimension from the current buffer type and cast the buffer
847/// to that new type. E.g.:
848/// ```
849/// %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
850/// vector.transfer_write %vec ...
851/// ```
852/// The following cast is generated:
853/// ```
854/// %casted = vector.type_cast %0
855/// : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
856/// ```
857/// 2. Generate a for loop and rewrite the transfer op according to the
858/// corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
859/// out-of-bounds, generate an if-check and handle both cases separately.
860/// 3. Clean up according to the corresponding Strategy<OpTy>.
861///
862/// Note: If the transfer op is a TransferWriteOp and operates on a tensor
863/// source (as opposed to a memref source), then each iteration of the generated
864/// scf.for loop yields the new tensor value. E.g.:
865/// ```
866/// %result = scf.for i = 0 to 5 {
867/// %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
868/// %1 = vector.transfer_write %0, %source[...]
869/// : vector<4x3xf32>, tensor<5x4x3xf32>
870/// scf.yield %1 : tensor<5x4x3xf32>
871/// }
872/// ```
873template <typename OpTy>
874struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
875 using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
876
877 void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
880 this->setHasBoundedRewriteRecursion();
881 }
882
883 static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
884 SmallVectorImpl<Value> &loadIndices,
885 Value iv) {
886 assert(xferOp.getMask() && "Expected transfer op to have mask");
887
    // Add load indices from the previous iteration.
    // The mask buffer depends on the permutation map, which makes determining
    // the indices quite complex; that is why we need to "look back" to the
    // previous iteration to find the right indices.
892 Value maskBuffer = getMaskBuffer(xferOp);
893 for (Operation *user : maskBuffer.getUsers()) {
894 // If there is no previous load op, then the indices are empty.
895 if (auto loadOp = dyn_cast<memref::LoadOp>(user)) {
896 Operation::operand_range prevIndices = loadOp.getIndices();
897 loadIndices.append(prevIndices.begin(), prevIndices.end());
898 break;
899 }
900 }
901
902 // In case of broadcast: Use same indices to load from memref
903 // as before.
904 if (!xferOp.isBroadcastDim(0))
      loadIndices.push_back(iv);
906 }
907
908 LogicalResult matchAndRewrite(OpTy xferOp,
909 PatternRewriter &rewriter) const override {
    if (!xferOp->hasAttr(kPassLabel))
      return rewriter.notifyMatchFailure(
          xferOp, "kPassLabel is absent (op not prepared for lowering)");
913
914 // Find and cast data buffer. How the buffer can be found depends on OpTy.
915 ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
916 Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
917 auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
918 FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
919 if (failed(castedDataType))
920 return rewriter.notifyMatchFailure(xferOp,
921 "Failed to unpack one vector dim.");
922
923 auto castedDataBuffer =
924 locB.create<vector::TypeCastOp>(*castedDataType, dataBuffer);
925
926 // If the xferOp has a mask: Find and cast mask buffer.
927 Value castedMaskBuffer;
928 if (xferOp.getMask()) {
929 Value maskBuffer = getMaskBuffer(xferOp);
930 if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
931 // Do not unpack a dimension of the mask, if:
932 // * To-be-unpacked transfer op dimension is a broadcast.
933 // * Mask is 1D, i.e., the mask cannot be further unpacked.
934 // (That means that all remaining dimensions of the transfer op must
935 // be broadcasted.)
936 castedMaskBuffer = maskBuffer;
937 } else {
938 // It's safe to assume the mask buffer can be unpacked if the data
939 // buffer was unpacked.
940 auto maskBufferType = cast<MemRefType>(maskBuffer.getType());
941 MemRefType castedMaskType = *unpackOneDim(maskBufferType);
942 castedMaskBuffer =
943 locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
944 }
945 }
946
947 // Loop bounds and step.
948 auto lb = locB.create<arith::ConstantIndexOp>(0);
949 auto ub = locB.create<arith::ConstantIndexOp>(
950 castedDataType->getDimSize(castedDataType->getRank() - 1));
951 auto step = locB.create<arith::ConstantIndexOp>(1);
952 // TransferWriteOps that operate on tensors return the modified tensor and
953 // require a loop state.
954 auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
955
956 // Generate for loop.
957 auto result = locB.create<scf::ForOp>(
958 lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
959 [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
960 Type stateType = loopState.empty() ? Type() : loopState[0].getType();
961
962 auto result = generateInBoundsCheck(
963 b, xferOp, iv, unpackedDim(xferOp),
964 stateType ? TypeRange(stateType) : TypeRange(),
965 /*inBoundsCase=*/
966 [&](OpBuilder &b, Location loc) {
967 // Create new transfer op.
968 OpTy newXfer = Strategy<OpTy>::rewriteOp(
969 b, this->options, xferOp, castedDataBuffer, iv, loopState);
970
971 // If old transfer op has a mask: Set mask on new transfer op.
972 // Special case: If the mask of the old transfer op is 1D and
973 // the unpacked dim is not a broadcast, no mask is needed on
974 // the new transfer op.
975 if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
976 xferOp.getMaskType().getRank() > 1)) {
977 OpBuilder::InsertionGuard guard(b);
978 b.setInsertionPoint(newXfer); // Insert load before newXfer.
979
980 SmallVector<Value, 8> loadIndices;
981 getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
982 loadIndices, iv);
983 auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
984 loadIndices);
985 rewriter.modifyOpInPlace(newXfer, [&]() {
986 newXfer.getMaskMutable().assign(mask);
987 });
988 }
989
990 return loopState.empty() ? Value() : newXfer->getResult(0);
991 },
992 /*outOfBoundsCase=*/
993 [&](OpBuilder &b, Location /*loc*/) {
994 return Strategy<OpTy>::handleOutOfBoundsDim(
995 b, xferOp, castedDataBuffer, iv, loopState);
996 });
997
998 maybeYieldValue(b, loc, !loopState.empty(), result);
999 });
1000
1001 Strategy<OpTy>::cleanup(rewriter, xferOp, result);
1002 return success();
1003 }
1004};
1005
/// Retrieves the dimension sizes of a mask. Currently supports CreateMaskOp
1007/// and ConstantMaskOp.
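/// E.g. (illustrative): for `vector.create_mask %a, %b : vector<4x[4]xi1>`
/// the result is {%a, %b}; for a `vector.constant_mask` with a scalable dim,
/// the corresponding size is computed as `vscale * dimSize`.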
1008template <typename VscaleConstantBuilder>
1009static FailureOr<SmallVector<OpFoldResult>>
1010getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
1011 if (!mask)
1012 return SmallVector<OpFoldResult>{};
1013 if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
1014 return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
1015 return OpFoldResult(dimSize);
1016 });
1017 }
1018 if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
1019 int dimIdx = 0;
1020 VectorType maskType = constantMask.getVectorType();
1021 auto indexType = IndexType::get(mask.getContext());
1022 return llvm::map_to_vector(
1023 constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
1024 // A scalable dim in a constant_mask means vscale x dimSize.
1025 if (maskType.getScalableDims()[dimIdx++])
1026 return OpFoldResult(createVscaleMultiple(dimSize));
1027 return OpFoldResult(IntegerAttr::get(indexType, dimSize));
1028 });
1029 }
1030 return failure();
1031}
1032
1033/// Scalable vector lowering of transfer_write(transpose). This lowering only
1034/// supports rank 2 (scalable) vectors, but can be used in conjunction with
1035/// `UnrollTransferWriteConversion` to support n-D cases. The unroll conversion
1036/// unrolls until the first scalable dimension.
1037///
1038/// Example:
1039///
1040/// BEFORE:
1041/// ```mlir
1042/// %transpose = vector.transpose %vec, [1, 0]
1043/// : vector<4x[4]xf32> to vector<[4]x4xf32>
1044/// vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]}
1045/// : vector<[4]x4xf32>, memref<?x?xf32>
1046/// ```
1047///
1048/// AFTER:
1049/// ```mlir
1050/// %c1 = arith.constant 1 : index
1051/// %c4 = arith.constant 4 : index
1052/// %c0 = arith.constant 0 : index
1053/// %0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32>
1054/// %1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32>
1055/// %2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32>
1056/// %3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32>
1057/// %vscale = vector.vscale
1058/// %c4_vscale = arith.muli %vscale, %c4 : index
1059/// scf.for %idx = %c0 to %c4_vscale step %c1 {
1060/// %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
1061/// %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
1062/// %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
1063/// %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
1064/// %slice_i = affine.apply #map(%idx)[%i]
1065/// %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
1066/// vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
1067/// : vector<4xf32>, memref<?x?xf32>
1068/// }
1069/// ```
1070struct ScalableTransposeTransferWriteConversion
1071 : VectorToSCFPattern<vector::TransferWriteOp> {
1072 using VectorToSCFPattern::VectorToSCFPattern;
1073
1074 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
1075 PatternRewriter &rewriter) const override {
1076 if (failed(checkLowerTensors(writeOp, rewriter)))
1077 return failure();
1078
1079 VectorType vectorType = writeOp.getVectorType();
1080
    // Note: By comparing the scalable dims to an ArrayRef of length two, this
    // implicitly checks that the rank is also two.
1083 ArrayRef<bool> scalableFlags = vectorType.getScalableDims();
1084 if (scalableFlags != ArrayRef<bool>{true, false}) {
1085 return rewriter.notifyMatchFailure(
1086 writeOp, "expected vector of the form vector<[N]xMxty>");
1087 }
1088
1089 auto permutationMap = writeOp.getPermutationMap();
1090 if (!permutationMap.isIdentity()) {
1091 return rewriter.notifyMatchFailure(
1092 writeOp, "non-identity permutations are unsupported (lower first)");
1093 }
1094
1095 // Note: This pattern is only lowering the leading dimension (to a loop),
1096 // so we only check if the leading dimension is in bounds. The in-bounds
1097 // attribute for the trailing dimension will be propagated.
1098 if (!writeOp.isDimInBounds(0)) {
1099 return rewriter.notifyMatchFailure(
1100 writeOp, "out-of-bounds dims are unsupported (use masking)");
1101 }
1102
1103 Value vector = writeOp.getVector();
1104 auto transposeOp = vector.getDefiningOp<vector::TransposeOp>();
1105 if (!transposeOp ||
1106 transposeOp.getPermutation() != ArrayRef<int64_t>{1, 0}) {
1107 return rewriter.notifyMatchFailure(writeOp, "source not transpose");
1108 }
1109
1110 auto loc = writeOp.getLoc();
1111 auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);
1113
1114 auto maskDims = getMaskDimSizes(writeOp.getMask(), createVscaleMultiple);
1115 if (failed(maskDims)) {
1116 return rewriter.notifyMatchFailure(writeOp,
1117 "failed to resolve mask dims");
1118 }
1119
1120 int64_t fixedDimSize = vectorType.getDimSize(1);
1121 auto fixedDimOffsets = llvm::seq(fixedDimSize);
1122
1123 // Extract all slices from the source of the transpose.
1124 auto transposeSource = transposeOp.getVector();
1125 SmallVector<Value> transposeSourceSlices =
1126 llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
1127 return rewriter.create<vector::ExtractOp>(loc, transposeSource, idx);
1128 });
1129
1130 // Loop bounds and step.
1131 auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1132 auto ub =
1133 maskDims->empty()
1134 ? Value(createVscaleMultiple(vectorType.getDimSize(0)))
            : vector::getAsValues(rewriter, loc, maskDims->front()).front();
1136 auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
1137
1138 // Generate a new mask for the slice.
1139 VectorType sliceType = VectorType::Builder(vectorType).dropDim(0);
1140 Value sliceMask = nullptr;
1141 if (!maskDims->empty()) {
1142 sliceMask = rewriter.create<vector::CreateMaskOp>(
1143 loc, sliceType.clone(rewriter.getI1Type()),
1144 ArrayRef<OpFoldResult>(*maskDims).drop_front());
1145 }
1146
1147 Value initDest = isTensorOp(writeOp) ? writeOp.getBase() : Value{};
1148 ValueRange initLoopArgs = initDest ? initDest : ValueRange{};
1149 auto result = rewriter.create<scf::ForOp>(
1150 loc, lb, ub, step, initLoopArgs,
1151 [&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) {
1152 // Indices for the new transfer op.
1153 SmallVector<Value, 8> xferIndices;
1154 getXferIndices(b, writeOp, iv, xferIndices);
1155
1156 // Extract a transposed slice from the source vector.
1157 SmallVector<Value> transposeElements =
1158 llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
1159 return b.create<vector::ExtractOp>(
1160 loc, transposeSourceSlices[idx], iv);
1161 });
1162 auto sliceVec = b.create<vector::FromElementsOp>(loc, sliceType,
1163 transposeElements);
1164
1165 // Create the transfer_write for the slice.
1166 Value dest =
1167 loopIterArgs.empty() ? writeOp.getBase() : loopIterArgs.front();
1168 auto newWriteOp = b.create<vector::TransferWriteOp>(
1169 loc, sliceVec, dest, xferIndices,
1170 ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front());
1171 if (sliceMask)
1172 newWriteOp.getMaskMutable().assign(sliceMask);
1173
1174 // Yield from the loop.
1175 b.create<scf::YieldOp>(loc, loopIterArgs.empty()
1176 ? ValueRange{}
1177 : newWriteOp.getResult());
1178 });
1179
1180 if (isTensorOp(writeOp))
1181 rewriter.replaceOp(writeOp, result);
1182 else
      rewriter.eraseOp(writeOp);
1184
1185 return success();
1186 }
1187};
1188
1189} // namespace lowering_n_d
1190
1191namespace lowering_n_d_unrolled {
1192
1193/// If the original transfer op has a mask, compute the mask of the new transfer
1194/// op (for the current iteration `i`) and assign it.
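/// E.g. (illustrative): for a mask of type vector<5x4xi1> and iteration i,
/// the new mask is roughly `vector.extract %mask[i] : vector<4xi1> from
/// vector<5x4xi1>`.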
1195template <typename OpTy>
1196static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
1197 int64_t i) {
1198 if (!xferOp.getMask())
1199 return;
1200
1201 if (xferOp.isBroadcastDim(0)) {
1202 // To-be-unpacked dimension is a broadcast, which does not have a
1203 // corresponding mask dimension. Mask attribute remains unchanged.
1204 newXferOp.getMaskMutable().assign(xferOp.getMask());
1205 return;
1206 }
1207
1208 if (xferOp.getMaskType().getRank() > 1) {
1209 // Unpack one dimension of the mask.
1210 OpBuilder::InsertionGuard guard(b);
1211 b.setInsertionPoint(newXferOp); // Insert load before newXfer.
1212
1213 llvm::SmallVector<int64_t, 1> indices({i});
1214 Location loc = xferOp.getLoc();
1215 auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices);
1216 newXferOp.getMaskMutable().assign(newMask);
1217 }
1218
1219 // If we end up here: The mask of the old transfer op is 1D and the unpacked
1220 // dim is not a broadcast, so no mask is needed on the new transfer op.
1221 // `generateInBoundsCheck` will have evaluated the mask already.
1222}
1223
1224/// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
1225/// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
1226/// memref buffer is allocated and the SCF loop is fully unrolled.
1227///
/// E.g.:
/// ```
1231/// %vec = vector.transfer_read %A[%a, %b, %c], %padding
1232/// : memref<?x?x?xf32>, vector<5x4xf32>
1233/// ```
1234/// is rewritten to IR such as (simplified):
1235/// ```
1236/// %v_init = splat %padding : vector<5x4xf32>
1237/// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
1238/// : memref<?x?x?xf32>, vector<4xf32>
1239/// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
1240/// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
1241/// : memref<?x?x?xf32>, vector<4xf32>
1242/// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
1243/// ...
1244/// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
1245/// : memref<?x?x?xf32>, vector<4xf32>
/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
1247/// ```
1248///
1249/// Note: As an optimization, if the result of the original TransferReadOp
1250/// was directly inserted into another vector, no new %v_init vector is created.
1251/// Instead, the new TransferReadOp results are inserted into that vector.
1252struct UnrollTransferReadConversion
1253 : public VectorToSCFPattern<TransferReadOp> {
1254 using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
1255
1256 void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
1259 setHasBoundedRewriteRecursion();
1260 }
1261
1262 /// Get or build the vector into which the newly created TransferReadOp
1263 /// results are inserted.
1264 Value buildResultVector(PatternRewriter &rewriter,
1265 TransferReadOp xferOp) const {
1266 if (auto insertOp = getInsertOp(xferOp))
1267 return insertOp.getDest();
1268 Location loc = xferOp.getLoc();
1269 return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
1270 xferOp.getPadding());
1271 }
1272
1273 /// If the result of the TransferReadOp has exactly one user, which is a
1274 /// vector::InsertOp, return that operation.
1275 vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
1276 if (xferOp->hasOneUse()) {
1277 Operation *xferOpUser = *xferOp->getUsers().begin();
1278 if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
1279 return insertOp;
1280 }
1281
1282 return vector::InsertOp();
1283 }
1284
1285 /// If the result of the TransferReadOp has exactly one user, which is a
1286 /// vector::InsertOp, return that operation's indices.
1287 void getInsertionIndices(TransferReadOp xferOp,
1288 SmallVectorImpl<OpFoldResult> &indices) const {
1289 if (auto insertOp = getInsertOp(xferOp)) {
1290 auto pos = insertOp.getMixedPosition();
1291 indices.append(pos.begin(), pos.end());
1292 }
1293 }
1294
1295 /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
1296 /// accesses, and broadcasts and transposes in permutation maps.
1297 LogicalResult matchAndRewrite(TransferReadOp xferOp,
1298 PatternRewriter &rewriter) const override {
1299 if (xferOp.getVectorType().getRank() <= options.targetRank)
1300 return rewriter.notifyMatchFailure(
          xferOp, "vector rank is less than or equal to the target rank");
1302 if (failed(checkLowerTensors(xferOp, rewriter)))
1303 return failure();
1304 if (xferOp.getVectorType().getElementType() !=
1305 xferOp.getShapedType().getElementType())
1306 return rewriter.notifyMatchFailure(
1307 xferOp, "not yet supported: element type mismatch");
1308 auto xferVecType = xferOp.getVectorType();
1309 if (xferVecType.getScalableDims()[0]) {
1310 return rewriter.notifyMatchFailure(
1311 xferOp, "scalable dimensions cannot be unrolled at compile time");
1312 }
1313
1314 auto insertOp = getInsertOp(xferOp);
    auto vec = buildResultVector(rewriter, xferOp);
1316 auto vecType = dyn_cast<VectorType>(vec.getType());
1317
1318 VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0);
1319
1320 int64_t dimSize = xferVecType.getShape()[0];
1321
1322 // Generate fully unrolled loop of transfer ops.
1323 Location loc = xferOp.getLoc();
1324 for (int64_t i = 0; i < dimSize; ++i) {
1325 Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
1326
1327 vec = generateInBoundsCheck(
1328 rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
1329 /*inBoundsCase=*/
1330 [&](OpBuilder &b, Location loc) {
1331 // Indices for the new transfer op.
1332 SmallVector<Value, 8> xferIndices;
1333 getXferIndices(b, xferOp, iv, xferIndices);
1334
1335 // Indices for the new vector.insert op.
1336 SmallVector<OpFoldResult, 8> insertionIndices;
            getInsertionIndices(xferOp, insertionIndices);
1338 insertionIndices.push_back(rewriter.getIndexAttr(i));
1339
1340 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1341 auto newXferOp = b.create<vector::TransferReadOp>(
1342 loc, newXferVecType, xferOp.getBase(), xferIndices,
1343 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
1344 xferOp.getPadding(), Value(), inBoundsAttr);
1345 maybeAssignMask(b, xferOp, newXferOp, i);
1346 return b.create<vector::InsertOp>(loc, newXferOp, vec,
1347 insertionIndices);
1348 },
1349 /*outOfBoundsCase=*/
1350 [&](OpBuilder &b, Location loc) {
1351 // Loop through original (unmodified) vector.
1352 return vec;
1353 });
1354 }
1355
1356 if (insertOp) {
1357 // Rewrite single user of the old TransferReadOp, which was an InsertOp.
1358 rewriter.replaceOp(insertOp, vec);
      rewriter.eraseOp(xferOp);
1360 } else {
1361 rewriter.replaceOp(xferOp, vec);
1362 }
1363
1364 return success();
1365 }
1366};
1367
1368/// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
1369/// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
1370/// memref buffer is allocated and the SCF loop is fully unrolled.
1371///
/// E.g.:
/// ```
1375/// vector.transfer_write %vec, %A[%a, %b, %c]
1376/// : vector<5x4xf32>, memref<?x?x?xf32>
1377/// ```
1378/// is rewritten to IR such as (simplified):
1379/// ```
1380/// %v0 = vector.extract %vec[0] : vector<4xf32> from vector<5x4xf32>
1381/// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
1382/// %v1 = vector.extract %vec[1] : vector<4xf32> from vector<5x4xf32>
1383/// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
1384/// ...
1385/// %v4 = vector.extract %vec[4] : vector<4xf32> from vector<5x4xf32>
1386/// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
1387/// ```
1388///
1389/// Note: As an optimization, if the vector of the original TransferWriteOp
1390/// was directly extracted from another vector via an ExtractOp `a`, extract
1391/// the vectors for the newly generated TransferWriteOps from `a`'s input. By
1392/// doing so, `a` may become dead, and the number of ExtractOps generated during
1393/// recursive application of this pattern will be minimal.
1394struct UnrollTransferWriteConversion
1395 : public VectorToSCFPattern<TransferWriteOp> {
1396 using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
1397
1398 void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
1401 setHasBoundedRewriteRecursion();
1402 }
1403
  /// Return the vector from which newly generated ExtractOps will extract.
1405 Value getDataVector(TransferWriteOp xferOp) const {
1406 if (auto extractOp = getExtractOp(xferOp))
1407 return extractOp.getVector();
1408 return xferOp.getVector();
1409 }
1410
1411 /// If the input of the given TransferWriteOp is an ExtractOp, return it.
1412 vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
1413 if (auto *op = xferOp.getVector().getDefiningOp())
1414 return dyn_cast<vector::ExtractOp>(op);
1415 return vector::ExtractOp();
1416 }
1417
1418 /// If the input of the given TransferWriteOp is an ExtractOp, return its
1419 /// indices.
1420 void getExtractionIndices(TransferWriteOp xferOp,
1421 SmallVectorImpl<OpFoldResult> &indices) const {
1422 if (auto extractOp = getExtractOp(xferOp)) {
1423 auto pos = extractOp.getMixedPosition();
1424 indices.append(pos.begin(), pos.end());
1425 }
1426 }
1427
1428 /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
1429 /// accesses, and broadcasts and transposes in permutation maps.
1430 LogicalResult matchAndRewrite(TransferWriteOp xferOp,
1431 PatternRewriter &rewriter) const override {
1432 VectorType inputVectorTy = xferOp.getVectorType();
1433
1434 if (inputVectorTy.getRank() <= options.targetRank)
1435 return failure();
1436
1437 if (failed(checkLowerTensors(xferOp, rewriter)))
1438 return failure();
1439 // Transfer ops that modify the element type are not supported atm.
1440 if (inputVectorTy.getElementType() !=
1441 xferOp.getShapedType().getElementType())
1442 return failure();
1443
1444 auto vec = getDataVector(xferOp: xferOp);
1445 if (inputVectorTy.getScalableDims()[0]) {
1446 // Cannot unroll a scalable dimension at compile time.
1447 return failure();
1448 }
1449
1450 int64_t dimSize = inputVectorTy.getShape()[0];
1451 Value source = xferOp.getBase(); // memref or tensor to be written to.
1452 auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
1453
1454 // Generate fully unrolled loop of transfer ops.
1455 Location loc = xferOp.getLoc();
1456 for (int64_t i = 0; i < dimSize; ++i) {
1457 Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
1458
1459 auto updatedSource = generateInBoundsCheck(
1460 rewriter, xferOp, iv, unpackedDim(xferOp),
1461 isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
1462 /*inBoundsCase=*/
1463 [&](OpBuilder &b, Location loc) {
1464 // Indices for the new transfer op.
1465 SmallVector<Value, 8> xferIndices;
1466 getXferIndices(b, xferOp, iv, xferIndices);
1467
1468 // Indices for the new vector.extract op.
1469 SmallVector<OpFoldResult, 8> extractionIndices;
1470 getExtractionIndices(xferOp: xferOp, indices&: extractionIndices);
1471 extractionIndices.push_back(b.getI64IntegerAttr(i));
1472
1473 auto extracted =
1474 b.create<vector::ExtractOp>(loc, vec, extractionIndices);
1475 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1476 Value xferVec;
1477 if (inputVectorTy.getRank() == 1) {
1478 // When target-rank=0, unrolling would causes the vector input
1479 // argument into `transfer_write` to become a scalar. We solve
1480 // this by broadcasting the scalar to a 0D vector.
1481 xferVec = b.create<vector::BroadcastOp>(
1482 loc, VectorType::get({}, extracted.getType()), extracted);
1483 } else {
1484 xferVec = extracted;
1485 }
1486 auto newXferOp = b.create<vector::TransferWriteOp>(
1487 loc, sourceType, xferVec, source, xferIndices,
1488 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
1489 inBoundsAttr);
1490
1491 maybeAssignMask(b, xferOp, newXferOp, i);
1492
1493 return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
1494 },
1495 /*outOfBoundsCase=*/
1496 [&](OpBuilder &b, Location loc) {
1497 return isTensorOp(xferOp) ? source : Value();
1498 });
1499
1500 if (isTensorOp(xferOp))
1501 source = updatedSource;
1502 }
1503
1504 if (isTensorOp(xferOp))
1505 rewriter.replaceOp(xferOp, source);
1506 else
1507 rewriter.eraseOp(op: xferOp);
1508
1509 return success();
1510 }
1511};

} // namespace lowering_n_d_unrolled

namespace lowering_1_d {

/// Compute the indices into the memref for the LoadOp/StoreOp generated as
/// part of TransferOp1dConversion. Return the memref dimension on which
/// the transfer is operating. A return value of std::nullopt indicates a
/// broadcast.
template <typename OpTy>
static std::optional<int64_t>
get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
                   SmallVector<Value, 8> &memrefIndices) {
  auto indices = xferOp.getIndices();
  auto map = xferOp.getPermutationMap();
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");

  memrefIndices.append(indices.begin(), indices.end());
  assert(map.getNumResults() == 1 &&
         "Expected 1 permutation map result for 1D transfer");
  if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
    Location loc = xferOp.getLoc();
    auto dim = expr.getPosition();
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = memrefIndices[dim];
    memrefIndices[dim] =
        affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
    return dim;
  }

  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return std::nullopt;
}

/// Codegen strategy for TransferOp1dConversion, depending on the
/// operation.
template <typename OpTy>
struct Strategy1d;

/// Codegen strategy for TransferReadOp.
template <>
struct Strategy1d<TransferReadOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferReadOp xferOp, Value iv,
                                  ValueRange loopState) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    auto vec = loopState[0];

    // In case of out-of-bounds access, leave `vec` as is (was initialized with
    // padding value).
    auto nextVec = generateInBoundsCheck(
        b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
        /*inBoundsCase=*/
        [&](OpBuilder &b, Location loc) {
          Value val = b.create<memref::LoadOp>(loc, xferOp.getBase(), indices);
          return b.create<vector::InsertElementOp>(loc, val, vec, iv);
        },
        /*outOfBoundsCase=*/
        [&](OpBuilder & /*b*/, Location loc) { return vec; });
    b.create<scf::YieldOp>(loc, nextVec);
  }

  static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
    // Initialize vector with padding value.
    Location loc = xferOp.getLoc();
    return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
                                     xferOp.getPadding());
  }
};

/// Codegen strategy for TransferWriteOp.
template <>
struct Strategy1d<TransferWriteOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferWriteOp xferOp, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);

    // Nothing to do in case of out-of-bounds access.
    generateInBoundsCheck(
        b, xferOp, iv, dim,
        /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
          auto val =
              b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
          b.create<memref::StoreOp>(loc, val, xferOp.getBase(), indices);
        });
    b.create<scf::YieldOp>(loc);
  }

  static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
    return Value();
  }
};

/// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
/// necessary in cases where a 1D vector transfer op cannot be lowered into
/// vector load/stores due to non-unit strides or broadcasts:
///
/// * Transfer dimension is not the last memref dimension
/// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
/// * Memref has a layout map with non-unit stride on the last dimension
///
/// This pattern generates IR as follows:
///
/// 1. Generate a for loop iterating over each vector element.
/// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
///    depending on OpTy.
///
/// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
///       can be generated instead of TransferOp1dConversion. Add such a pattern
///       to ConvertVectorToLLVM.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b]
///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
///    : vector<9xf32>, memref<?x?xf32>
/// ```
/// Is rewritten to approximately the following pseudo-IR:
/// ```
/// for i = 0 to 9 {
///   %t = vector.extractelement %vec[i] : vector<9xf32>
///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
/// }
/// ```
template <typename OpTy>
struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();
    auto map = xferOp.getPermutationMap();
    auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());

    if (!memRefType)
      return failure();
    if (xferOp.getVectorType().getRank() != 1)
      return failure();
    if (map.isMinorIdentity() && memRefType.isLastDimUnitStride())
      return failure(); // Handled by ConvertVectorToLLVM

    // Loop bounds, step, state...
    Location loc = xferOp.getLoc();
    auto vecType = xferOp.getVectorType();
    auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value ub =
        rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
    if (vecType.isScalable()) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      ub = rewriter.create<arith::MulIOp>(loc, ub, vscale);
    }
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);

    // Generate for loop.
    rewriter.replaceOpWithNewOp<scf::ForOp>(
        xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
        });

    return success();
  }
};

} // namespace lowering_1_d
} // namespace

void mlir::populateVectorToSCFConversionPatterns(
    RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
  if (options.unroll) {
    patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
                 lowering_n_d_unrolled::UnrollTransferWriteConversion>(
        patterns.getContext(), options);
  } else {
    patterns.add<lowering_n_d::PrepareTransferReadConversion,
                 lowering_n_d::PrepareTransferWriteConversion,
                 lowering_n_d::TransferOpConversion<TransferReadOp>,
                 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
  if (options.lowerScalable) {
    patterns.add<lowering_n_d::ScalableTransposeTransferWriteConversion>(
        patterns.getContext(), options);
  }
  if (options.targetRank == 1) {
    patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
                 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
  patterns.add<lowering_n_d::DecomposePrintOpConversion>(patterns.getContext(),
                                                         options);
}

namespace {

struct ConvertVectorToSCFPass
    : public impl::ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
  ConvertVectorToSCFPass() = default;
  ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
    this->fullUnroll = options.unroll;
    this->targetRank = options.targetRank;
    this->lowerTensors = options.lowerTensors;
    this->lowerScalable = options.lowerScalable;
  }

  void runOnOperation() override {
    VectorTransferToSCFOptions options;
    options.unroll = fullUnroll;
    options.targetRank = targetRank;
    options.lowerTensors = lowerTensors;
    options.lowerScalable = lowerScalable;

    // Lower permutation maps first.
    RewritePatternSet lowerTransferPatterns(&getContext());
    mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
        lowerTransferPatterns);
    (void)applyPatternsGreedily(getOperation(),
                                std::move(lowerTransferPatterns));

    RewritePatternSet patterns(&getContext());
    populateVectorToSCFConversionPatterns(patterns, options);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

} // namespace

std::unique_ptr<Pass>
mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
  return std::make_unique<ConvertVectorToSCFPass>(options);
}