1//===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering of vector transfer operations to SCF.
10//
11//===----------------------------------------------------------------------===//
12
13#include <numeric>
14#include <optional>
15
16#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
17
18#include "mlir/Dialect/Affine/IR/AffineOps.h"
19#include "mlir/Dialect/Arith/IR/Arith.h"
20#include "mlir/Dialect/MemRef/IR/MemRef.h"
21#include "mlir/Dialect/SCF/IR/SCF.h"
22#include "mlir/Dialect/Vector/IR/VectorOps.h"
23#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
24#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
25#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
26#include "mlir/IR/Builders.h"
27#include "mlir/IR/ImplicitLocOpBuilder.h"
28#include "mlir/Pass/Pass.h"
29#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30
31namespace mlir {
32#define GEN_PASS_DEF_CONVERTVECTORTOSCF
33#include "mlir/Conversion/Passes.h.inc"
34} // namespace mlir
35
36using namespace mlir;
37using vector::TransferReadOp;
38using vector::TransferWriteOp;
39
40namespace {
41
/// Attribute name used for labeling transfer ops during progressive lowering.
/// Ops carrying this attribute are "in flight": they have already been
/// prepared and still need further unpacking until they reach the target rank.
static const char kPassLabel[] = "__vector_to_scf_lowering__";
44
45/// Return true if this transfer op operates on a source tensor.
46static bool isTensorOp(VectorTransferOpInterface xferOp) {
47 if (isa<RankedTensorType>(Val: xferOp.getShapedType())) {
48 if (isa<vector::TransferWriteOp>(Val: xferOp)) {
49 // TransferWriteOps on tensors have a result.
50 assert(xferOp->getNumResults() > 0);
51 }
52 return true;
53 }
54 return false;
55}
56
57/// Patterns that inherit from this struct have access to
58/// VectorTransferToSCFOptions.
59template <typename OpTy>
60struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
61 explicit VectorToSCFPattern(MLIRContext *context,
62 VectorTransferToSCFOptions opt)
63 : OpRewritePattern<OpTy>(context), options(opt) {}
64
65 LogicalResult checkLowerTensors(VectorTransferOpInterface xferOp,
66 PatternRewriter &rewriter) const {
67 if (isTensorOp(xferOp) && !options.lowerTensors) {
68 return rewriter.notifyMatchFailure(
69 arg&: xferOp, msg: "lowering tensor transfers is disabled");
70 }
71 return success();
72 }
73
74 VectorTransferToSCFOptions options;
75};
76
77/// Given a vector transfer op, calculate which dimension of the `source`
78/// memref should be unpacked in the next application of TransferOpConversion.
79/// A return value of std::nullopt indicates a broadcast.
80template <typename OpTy>
81static std::optional<int64_t> unpackedDim(OpTy xferOp) {
82 // TODO: support 0-d corner case.
83 assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
84 auto map = xferOp.getPermutationMap();
85 if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
86 return expr.getPosition();
87 }
88 assert(xferOp.isBroadcastDim(0) &&
89 "Expected AffineDimExpr or AffineConstantExpr");
90 return std::nullopt;
91}
92
93/// Compute the permutation map for the new (N-1)-D vector transfer op. This
94/// map is identical to the current permutation map, but the first result is
95/// omitted.
96template <typename OpTy>
97static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
98 // TODO: support 0-d corner case.
99 assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
100 auto map = xferOp.getPermutationMap();
101 return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
102 b.getContext());
103}
104
105/// Calculate the indices for the new vector transfer op.
106///
107/// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
108/// --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3f32>
109/// ^^^^^^
110/// `iv` is the iteration variable of the (new) surrounding loop.
111template <typename OpTy>
112static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
113 SmallVector<Value, 8> &indices) {
114 typename OpTy::Adaptor adaptor(xferOp);
115 // Corresponding memref dim of the vector dim that is unpacked.
116 auto dim = unpackedDim(xferOp);
117 auto prevIndices = adaptor.getIndices();
118 indices.append(prevIndices.begin(), prevIndices.end());
119
120 Location loc = xferOp.getLoc();
121 bool isBroadcast = !dim.has_value();
122 if (!isBroadcast) {
123 AffineExpr d0, d1;
124 bindDims(xferOp.getContext(), d0, d1);
125 Value offset = adaptor.getIndices()[*dim];
126 indices[*dim] =
127 affine::makeComposedAffineApply(b, loc, e: d0 + d1, operands: {offset, iv});
128 }
129}
130
131static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
132 Value value) {
133 if (hasRetVal) {
134 assert(value && "Expected non-empty value");
135 b.create<scf::YieldOp>(location: loc, args&: value);
136 } else {
137 b.create<scf::YieldOp>(location: loc);
138 }
139}
140
141/// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
142/// is set to true. No such check is generated under following circumstances:
143/// * xferOp does not have a mask.
144/// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
145/// computed and attached to the new transfer op in the pattern.)
146/// * The to-be-unpacked dim of xferOp is a broadcast.
147template <typename OpTy>
148static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
149 if (!xferOp.getMask())
150 return Value();
151 if (xferOp.getMaskType().getRank() != 1)
152 return Value();
153 if (xferOp.isBroadcastDim(0))
154 return Value();
155
156 Location loc = xferOp.getLoc();
157 return b.create<vector::ExtractOp>(loc, xferOp.getMask(), iv);
158}
159
/// Helper function TransferOpConversion and TransferOp1dConversion.
/// Generate an in-bounds check if the transfer op may go out-of-bounds on the
/// specified dimension `dim` with the loop iteration variable `iv`.
/// E.g., when unpacking dimension 0 from:
/// ```
/// %vec = vector.transfer_read %A[%a, %b] %cst
///     : vector<5x4xf32>, memref<?x?xf32>
/// ```
/// An if check similar to this will be generated inside the loop:
/// ```
/// %d = memref.dim %A, %c0 : memref<?x?xf32>
/// if (%a + iv < %d) {
///   (in-bounds case)
/// } else {
///   (out-of-bounds case)
/// }
/// ```
///
/// If the transfer is 1D and has a mask, this function generates a more complex
/// check also accounts for potentially masked out elements.
///
/// This function variant returns the value returned by `inBoundsCase` or
/// `outOfBoundsCase`. The MLIR type of the return value must be specified in
/// `resultTypes`.
template <typename OpTy>
static Value generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
    TypeRange resultTypes,
    function_ref<Value(OpBuilder &, Location)> inBoundsCase,
    function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  bool hasRetVal = !resultTypes.empty();
  Value cond; // Condition to be built...

  // Condition check 1: Access in-bounds?
  bool isBroadcast = !dim; // No in-bounds check for broadcasts.
  Location loc = xferOp.getLoc();
  ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
  if (!xferOp.isDimInBounds(0) && !isBroadcast) {
    // Compare the dynamic memref extent of the unpacked dim against the
    // accessed index (original index + loop iv).
    Value memrefDim = vector::createOrFoldDimOp(b, loc, xferOp.getBase(), *dim);
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value base = xferOp.getIndices()[*dim];
    Value memrefIdx =
        affine::makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
    cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
                                    memrefIdx);
  }

  // Condition check 2: Masked in?
  if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
    // Combine with the bounds check (if any) via logical AND.
    if (cond)
      cond = lb.create<arith::AndIOp>(cond, maskCond);
    else
      cond = maskCond;
  }

  // If the condition is non-empty, generate an SCF::IfOp.
  if (cond) {
    auto check = lb.create<scf::IfOp>(
        cond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          if (outOfBoundsCase) {
            maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
          } else {
            b.create<scf::YieldOp>(loc);
          }
        });

    return hasRetVal ? check.getResult(0) : Value();
  }

  // Condition is empty, no need for an SCF::IfOp.
  return inBoundsCase(b, loc);
}
239
240/// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
241/// a return value. Consequently, this function does not have a return value.
242template <typename OpTy>
243static void generateInBoundsCheck(
244 OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
245 function_ref<void(OpBuilder &, Location)> inBoundsCase,
246 function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
247 generateInBoundsCheck(
248 b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
249 /*inBoundsCase=*/
250 [&](OpBuilder &b, Location loc) {
251 inBoundsCase(b, loc);
252 return Value();
253 },
254 /*outOfBoundsCase=*/
255 [&](OpBuilder &b, Location loc) {
256 if (outOfBoundsCase)
257 outOfBoundsCase(b, loc);
258 return Value();
259 });
260}
261
262/// Given an ArrayAttr, return a copy where the first element is dropped.
263static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
264 if (!attr)
265 return attr;
266 return ArrayAttr::get(context: b.getContext(), value: attr.getValue().drop_front());
267}
268
269/// Add the pass label to a vector transfer op if its rank is not the target
270/// rank.
271template <typename OpTy>
272static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
273 unsigned targetRank) {
274 if (newXferOp.getVectorType().getRank() > targetRank)
275 newXferOp->setAttr(kPassLabel, b.getUnitAttr());
276}
277
278namespace lowering_n_d {
279
/// Helper data structure for data and mask buffers.
struct BufferAllocs {
  // memref.alloca holding the transferred vector data.
  Value dataBuffer;
  // Mask value reloaded from its temporary alloca (result of a memref.load);
  // patterns recover the underlying buffer via the defining LoadOp
  // (see getMaskBuffer). Null if the transfer op has no mask.
  Value maskBuffer;
};
285
286// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
287static Operation *getAutomaticAllocationScope(Operation *op) {
288 Operation *scope =
289 op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
290 assert(scope && "Expected op to be inside automatic allocation scope");
291 return scope;
292}
293
/// Allocate temporary buffers for data (vector) and mask (if present).
template <typename OpTy>
static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
  Location loc = xferOp.getLoc();
  OpBuilder::InsertionGuard guard(b);
  Operation *scope = getAutomaticAllocationScope(xferOp);
  assert(scope->getNumRegions() == 1 &&
         "AutomaticAllocationScope with >1 regions");
  // Hoist the allocas to the start of the enclosing allocation scope so they
  // are not re-created inside the loops generated by the lowering.
  b.setInsertionPointToStart(&scope->getRegion(0).front());

  BufferAllocs result;
  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
  result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);

  if (xferOp.getMask()) {
    auto maskType = MemRefType::get({}, xferOp.getMask().getType());
    auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
    // Store and reload the mask right before the transfer op: later patterns
    // locate the mask buffer by walking to this LoadOp (see getMaskBuffer).
    b.setInsertionPoint(xferOp);
    b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer);
    result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer, ValueRange());
  }

  return result;
}
318
319/// Given a MemRefType with VectorType element type, unpack one dimension from
320/// the VectorType into the MemRefType.
321///
322/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
323static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
324 auto vectorType = dyn_cast<VectorType>(Val: type.getElementType());
325 // Vectors with leading scalable dims are not supported.
326 // It may be possible to support these in future by using dynamic memref dims.
327 if (vectorType.getScalableDims().front())
328 return failure();
329 auto memrefShape = type.getShape();
330 SmallVector<int64_t, 8> newMemrefShape;
331 newMemrefShape.append(in_start: memrefShape.begin(), in_end: memrefShape.end());
332 newMemrefShape.push_back(Elt: vectorType.getDimSize(idx: 0));
333 return MemRefType::get(shape: newMemrefShape,
334 elementType: VectorType::Builder(vectorType).dropDim(pos: 0));
335}
336
337/// Given a transfer op, find the memref from which the mask is loaded. This
338/// is similar to Strategy<TransferWriteOp>::getBuffer.
339template <typename OpTy>
340static Value getMaskBuffer(OpTy xferOp) {
341 assert(xferOp.getMask() && "Expected that transfer op has mask");
342 auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
343 assert(loadOp && "Expected transfer op mask produced by LoadOp");
344 return loadOp.getMemRef();
345}
346
/// Codegen strategy, depending on the operation. Specialized below for
/// TransferReadOp and TransferWriteOp.
template <typename OpTy>
struct Strategy;
350
/// Code strategy for vector TransferReadOp.
template <>
struct Strategy<TransferReadOp> {
  /// Find the StoreOp that is used for writing the current TransferReadOp's
  /// result to the temporary buffer allocation.
  static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
    assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
    auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
    assert(storeOp && "Expected TransferReadOp result used by StoreOp");
    return storeOp;
  }

  /// Find the temporary buffer allocation. All labeled TransferReadOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
  /// memref.store %vec, %buf[...] ...
  /// ```
  static Value getBuffer(TransferReadOp xferOp) {
    return getStoreOp(xferOp).getMemRef();
  }

  /// Retrieve the indices of the current StoreOp that stores into the buffer.
  static void getBufferIndices(TransferReadOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto storeOp = getStoreOp(xferOp);
    auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
  ///    variable `iv`.
  /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
  ///
  /// E.g.:
  /// ```
  /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
  ///     : memref<?x?x?xf32>, vector<4x3xf32>
  /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
  /// ```
  /// Is rewritten to:
  /// ```
  /// %casted = vector.type_cast %buf
  ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
  /// for %j = 0 to 4 {
  ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
  ///       : memref<?x?x?xf32>, vector<3xf32>
  ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
  /// }
  /// ```
  ///
  /// Note: The loop and type cast are generated in TransferOpConversion.
  ///       The original TransferReadOp and store op are deleted in `cleanup`.
  /// Note: The `mask` operand is set in TransferOpConversion.
  static TransferReadOp rewriteOp(OpBuilder &b,
                                  VectorTransferToSCFOptions options,
                                  TransferReadOp xferOp, Value buffer, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto bufferType = dyn_cast<ShapedType>(buffer.getType());
    auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
    auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
    // The mask operand is intentionally left empty (Value()); it is attached
    // later in TransferOpConversion.
    auto newXferOp = b.create<vector::TransferReadOp>(
        loc, vecType, xferOp.getBase(), xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
        xferOp.getPadding(), Value(), inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices);
    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
  /// padding value to the temporary buffer.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    Location loc = xferOp.getLoc();
    auto bufferType = dyn_cast<ShapedType>(buffer.getType());
    auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
    // Splat the padding value into a vector and store it at the current slot.
    auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
    b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);

    return Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
                      scf::ForOp /*forOp*/) {
    rewriter.eraseOp(getStoreOp(xferOp));
    rewriter.eraseOp(xferOp);
  }

  /// Return the initial loop state for the generated scf.for loop.
  /// Reads have no loop-carried state.
  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
};
463
/// Codegen strategy for vector TransferWriteOp.
template <>
struct Strategy<TransferWriteOp> {
  /// Find the temporary buffer allocation. All labeled TransferWriteOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = memref.load %buf[...] ...
  /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
  /// ```
  static Value getBuffer(TransferWriteOp xferOp) {
    auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
    assert(loadOp && "Expected transfer op vector produced by LoadOp");
    return loadOp.getMemRef();
  }

  /// Retrieve the indices of the current LoadOp that loads from the buffer.
  static void getBufferIndices(TransferWriteOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
    auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
  ///    using the loop iteration variable `iv`.
  /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
  ///    to memory.
  ///
  /// Note: For more details, see comments on Strategy<TransferReadOp>.
  static TransferWriteOp rewriteOp(OpBuilder &b,
                                   VectorTransferToSCFOptions options,
                                   TransferWriteOp xferOp, Value buffer,
                                   Value iv, ValueRange loopState) {
    SmallVector<Value, 8> loadIndices;
    getBufferIndices(xferOp, loadIndices);
    loadIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
    auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
    // For tensor transfers, write into the loop-carried tensor value; for
    // memrefs, write directly into the original base.
    auto source = loopState.empty() ? xferOp.getBase() : loopState[0];
    // Tensor writes produce a result (the updated tensor); memref writes do
    // not, which is signaled by a null result type.
    Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
    auto newXferOp = b.create<vector::TransferWriteOp>(
        loc, type, vec, source, xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
        inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
  /// Nothing is written; for tensor transfers the unchanged loop-carried
  /// tensor is forwarded.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange loopState) {
    return isTensorOp(xferOp) ? loopState[0] : Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
                      scf::ForOp forOp) {
    if (isTensorOp(xferOp)) {
      // The loop yields the final tensor; it replaces the original op.
      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
      rewriter.replaceOp(xferOp, forOp->getResult(0));
    } else {
      rewriter.eraseOp(xferOp);
    }
  }

  /// Return the initial loop state for the generated scf.for loop.
  /// Tensor writes thread the destination tensor through the loop.
  static Value initialLoopState(TransferWriteOp xferOp) {
    return isTensorOp(xferOp) ? xferOp.getBase() : Value();
  }
};
546
/// Check whether `xferOp` is eligible for the "prepare" step of progressive
/// lowering: not already labeled, vector rank above the target rank, no
/// scalable leading vector dim, tensor lowering enabled (if applicable), and
/// matching vector/source element types.
template <typename OpTy>
static LogicalResult checkPrepareXferOp(OpTy xferOp, PatternRewriter &rewriter,
                                        VectorTransferToSCFOptions options) {
  if (xferOp->hasAttr(kPassLabel))
    return rewriter.notifyMatchFailure(
        xferOp, "kPassLabel is present (vector-to-scf lowering in progress)");
  if (xferOp.getVectorType().getRank() <= options.targetRank)
    return rewriter.notifyMatchFailure(
        xferOp, "xferOp vector rank <= transformation target rank");
  if (xferOp.getVectorType().getScalableDims().front())
    return rewriter.notifyMatchFailure(
        xferOp, "Unpacking of the leading dimension into the memref is not yet "
                "supported for scalable dims");
  if (isTensorOp(xferOp) && !options.lowerTensors)
    return rewriter.notifyMatchFailure(
        xferOp, "Unpacking for tensors has been disabled.");
  if (xferOp.getVectorType().getElementType() !=
      xferOp.getShapedType().getElementType())
    return rewriter.notifyMatchFailure(
        xferOp, "Mismatching source and destination element types.");

  return success();
}
570
571/// Prepare a TransferReadOp for progressive lowering.
572///
573/// 1. Allocate a temporary buffer.
574/// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
575/// 3. Store the result of the TransferReadOp into the temporary buffer.
576/// 4. Load the result from the temporary buffer and replace all uses of the
577/// original TransferReadOp with this load.
578///
579/// E.g.:
580/// ```
581/// %vec = vector.transfer_read %A[%a, %b, %c], %cst
582/// : vector<5x4xf32>, memref<?x?x?xf32>
583/// ```
584/// is rewritten to:
585/// ```
586/// %0 = memref.alloca() : memref<vector<5x4xf32>>
587/// %1 = vector.transfer_read %A[%a, %b, %c], %cst
588/// { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32>
589/// memref.store %1, %0[] : memref<vector<5x4xf32>>
590/// %vec = memref.load %0[] : memref<vector<5x4xf32>>
591/// ```
592///
593/// Note: A second temporary buffer may be allocated for the `mask` operand.
594struct PrepareTransferReadConversion
595 : public VectorToSCFPattern<TransferReadOp> {
596 using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
597
598 LogicalResult matchAndRewrite(TransferReadOp xferOp,
599 PatternRewriter &rewriter) const override {
600 if (checkPrepareXferOp(xferOp, rewriter, options).failed())
601 return rewriter.notifyMatchFailure(
602 arg&: xferOp, msg: "checkPrepareXferOp conditions not met!");
603
604 auto buffers = allocBuffers(b&: rewriter, xferOp);
605 auto *newXfer = rewriter.clone(op&: *xferOp.getOperation());
606 newXfer->setAttr(name: kPassLabel, value: rewriter.getUnitAttr());
607 if (xferOp.getMask()) {
608 dyn_cast<TransferReadOp>(Val: newXfer).getMaskMutable().assign(
609 value: buffers.maskBuffer);
610 }
611
612 Location loc = xferOp.getLoc();
613 rewriter.create<memref::StoreOp>(location: loc, args: newXfer->getResult(idx: 0),
614 args&: buffers.dataBuffer);
615 rewriter.replaceOpWithNewOp<memref::LoadOp>(op: xferOp, args&: buffers.dataBuffer);
616
617 return success();
618 }
619};
620
/// Prepare a TransferWriteOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Store the vector into the buffer.
/// 3. Load the vector from the buffer again.
/// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
///    marking it eligible for progressive lowering via TransferOpConversion.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// memref.store %vec, %0[] : memref<vector<5x4xf32>>
/// %1 = memref.load %0[] : memref<vector<5x4xf32>>
/// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, rewriter, options).failed())
      return rewriter.notifyMatchFailure(
          xferOp, "checkPrepareXferOp conditions not met!");

    Location loc = xferOp.getLoc();
    auto buffers = allocBuffers(rewriter, xferOp);
    rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
                                     buffers.dataBuffer);
    auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
    // Rewire the op in place: its vector operand now comes from the buffer
    // load, and the pass label marks it for TransferOpConversion.
    rewriter.modifyOpInPlace(xferOp, [&]() {
      xferOp.getValueToStoreMutable().assign(loadedVec);
      xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
    });

    if (xferOp.getMask()) {
      rewriter.modifyOpInPlace(xferOp, [&]() {
        xferOp.getMaskMutable().assign(buffers.maskBuffer);
      });
    }

    return success();
  }
};
673
674/// Decompose a n-D PrintOp into a loop of elementary/scalar prints. This allows
675/// printing both 1D scalable vectors and n-D fixed size vectors.
676///
677/// E.g.:
678/// ```
679/// vector.print %v : vector<[4]xi32>
680/// ```
681/// is rewritten to:
682/// ```
683/// %c0 = arith.constant 0 : index
684/// %c4 = arith.constant 4 : index
685/// %c1 = arith.constant 1 : index
686/// %vscale = vector.vscale
687/// %length = arith.muli %vscale, %c4 : index
688/// %lastIndex = arith.subi %length, %c1 : index
689/// vector.print punctuation <open>
690/// scf.for %i = %c0 to %length step %c1 {
691/// %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
692/// vector.print %el : i32 punctuation <no_punctuation>
693/// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
694/// scf.if %notLastIndex {
695/// vector.print punctuation <comma>
696/// }
697/// }
698/// vector.print punctuation <close>
699/// vector.print
700/// ```
701struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
702 using VectorToSCFPattern<vector::PrintOp>::VectorToSCFPattern;
  /// Decompose an n-D or scalable vector.print into scf loops of scalar
  /// prints. Fails (no match) for non-vector prints and for >= 2-D scalable
  /// vectors.
  LogicalResult matchAndRewrite(vector::PrintOp printOp,
                                PatternRewriter &rewriter) const override {
    if (!printOp.getSource())
      return failure();

    VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
    if (!vectorType)
      return failure();

    // Currently >= 2D scalable vectors are not supported.
    // These can't be lowered to LLVM (as LLVM does not support scalable vectors
    // of scalable vectors), and due to limitations of current ops can't be
    // indexed with SSA values or flattened. This may change after
    // https://reviews.llvm.org/D155034, though there still needs to be a path
    // for lowering to LLVM.
    if (vectorType.getRank() > 1 && vectorType.isScalable())
      return failure();

    auto loc = printOp.getLoc();
    auto value = printOp.getSource();

    if (auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) {
      // Oddly sized integers are (somewhat) buggy on a lot of backends, so to
      // avoid issues extend them to a more standard size.
      // https://github.com/llvm/llvm-project/issues/30613
      auto width = intTy.getWidth();
      // Next power of two, with a floor of 8 bits.
      auto legalWidth = llvm::NextPowerOf2(std::max(8u, width) - 1);
      auto legalIntTy = IntegerType::get(rewriter.getContext(), legalWidth,
                                         intTy.getSignedness());
      // arith can only take signless integers, so we must cast back and forth.
      auto signlessSourceVectorType =
          vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy));
      auto signlessTargetVectorType =
          vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
      auto targetVectorType = vectorType.cloneWith({}, legalIntTy);
      value = rewriter.create<vector::BitCastOp>(loc, signlessSourceVectorType,
                                                 value);
      if (value.getType() != signlessTargetVectorType) {
        // i1 and unsigned values are zero-extended; signed values are
        // sign-extended.
        if (width == 1 || intTy.isUnsigned())
          value = rewriter.create<arith::ExtUIOp>(loc, signlessTargetVectorType,
                                                  value);
        else
          value = rewriter.create<arith::ExtSIOp>(loc, signlessTargetVectorType,
                                                  value);
      }
      value = rewriter.create<vector::BitCastOp>(loc, targetVectorType, value);
      vectorType = targetVectorType;
    }

    auto scalableDimensions = vectorType.getScalableDims();
    auto shape = vectorType.getShape();
    // 0-D vectors are printed as a single-element 1-D vector.
    constexpr int64_t singletonShape[] = {1};
    if (vectorType.getRank() == 0)
      shape = singletonShape;

    if (vectorType.getRank() != 1) {
      // Flatten n-D vectors to 1D. This is done to allow indexing with a
      // non-constant value.
      auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
                                        std::multiplies<int64_t>());
      auto flatVectorType =
          VectorType::get({flatLength}, vectorType.getElementType());
      value = rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, value);
    }

    // `firstClose` marks where printing of the outermost dimension ends; the
    // final punctuation is inserted after it.
    vector::PrintOp firstClose;
    SmallVector<Value, 8> loopIndices;
    for (unsigned d = 0; d < shape.size(); d++) {
      // Setup loop bounds and step.
      Value lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
      Value upperBound = rewriter.create<arith::ConstantIndexOp>(loc, shape[d]);
      Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
      if (!scalableDimensions.empty() && scalableDimensions[d]) {
        // Scalable dim: runtime length is static size * vscale.
        auto vscale = rewriter.create<vector::VectorScaleOp>(
            loc, rewriter.getIndexType());
        upperBound = rewriter.create<arith::MulIOp>(loc, upperBound, vscale);
      }
      auto lastIndex = rewriter.create<arith::SubIOp>(loc, upperBound, step);

      // Create a loop to print the elements surrounded by parentheses.
      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
      auto loop =
          rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
      auto printClose = rewriter.create<vector::PrintOp>(
          loc, vector::PrintPunctuation::Close);
      if (!firstClose)
        firstClose = printClose;

      auto loopIdx = loop.getInductionVar();
      loopIndices.push_back(loopIdx);

      // Print a comma after all but the last element.
      rewriter.setInsertionPointToStart(loop.getBody());
      auto notLastIndex = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ult, loopIdx, lastIndex);
      rewriter.create<scf::IfOp>(loc, notLastIndex,
                                 [&](OpBuilder &builder, Location loc) {
                                   builder.create<vector::PrintOp>(
                                       loc, vector::PrintPunctuation::Comma);
                                   builder.create<scf::YieldOp>(loc);
                                 });

      // Continue building the next (inner) loop inside this loop's body.
      rewriter.setInsertionPointToStart(loop.getBody());
    }

    // Compute the flattened index.
    // Note: For the > rank 1 vectors this assumes non-scalable.
    Value flatIndex;
    auto currentStride = 1;
    for (int d = shape.size() - 1; d >= 0; d--) {
      auto stride = rewriter.create<arith::ConstantIndexOp>(loc, currentStride);
      auto index = rewriter.create<arith::MulIOp>(loc, stride, loopIndices[d]);
      if (flatIndex)
        flatIndex = rewriter.create<arith::AddIOp>(loc, flatIndex, index);
      else
        flatIndex = index;
      currentStride *= shape[d];
    }

    // Print the scalar elements in the inner most loop.
    auto element = rewriter.create<vector::ExtractOp>(loc, value, flatIndex);
    rewriter.create<vector::PrintOp>(loc, element,
                                     vector::PrintPunctuation::NoPunctuation);

    rewriter.setInsertionPointAfter(firstClose);
    rewriter.create<vector::PrintOp>(loc, printOp.getPunctuation());
    rewriter.eraseOp(printOp);
    return success();
  }
832
833 static IntegerType getIntTypeWithSignlessSemantics(IntegerType intTy) {
834 return IntegerType::get(context: intTy.getContext(), width: intTy.getWidth(),
835 signedness: IntegerType::Signless);
836 };
837};
838
839/// Progressive lowering of vector transfer ops: Unpack one dimension.
840///
841/// 1. Unpack one dimension from the current buffer type and cast the buffer
842/// to that new type. E.g.:
843/// ```
844/// %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
845/// vector.transfer_write %vec ...
846/// ```
847/// The following cast is generated:
848/// ```
849/// %casted = vector.type_cast %0
850/// : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
851/// ```
852/// 2. Generate a for loop and rewrite the transfer op according to the
853/// corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
854/// out-of-bounds, generate an if-check and handle both cases separately.
855/// 3. Clean up according to the corresponding Strategy<OpTy>.
856///
857/// Note: If the transfer op is a TransferWriteOp and operates on a tensor
858/// source (as opposed to a memref source), then each iteration of the generated
859/// scf.for loop yields the new tensor value. E.g.:
860/// ```
861/// %result = scf.for i = 0 to 5 {
862/// %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
863/// %1 = vector.transfer_write %0, %source[...]
864/// : vector<4x3xf32>, tensor<5x4x3xf32>
865/// scf.yield %1 : tensor<5x4x3xf32>
866/// }
867/// ```
868template <typename OpTy>
869struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
870 using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
871
872 void initialize() {
873 // This pattern recursively unpacks one dimension at a time. The recursion
874 // bounded as the rank is strictly decreasing.
875 this->setHasBoundedRewriteRecursion();
876 }
877
878 static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
879 SmallVectorImpl<Value> &loadIndices,
880 Value iv) {
881 assert(xferOp.getMask() && "Expected transfer op to have mask");
882
883 // Add load indices from the previous iteration.
884 // The mask buffer depends on the permutation map, which makes determining
885 // the indices quite complex, so this is why we need to "look back" to the
886 // previous iteration to find the right indices.
887 Value maskBuffer = getMaskBuffer(xferOp);
888 for (Operation *user : maskBuffer.getUsers()) {
889 // If there is no previous load op, then the indices are empty.
890 if (auto loadOp = dyn_cast<memref::LoadOp>(Val: user)) {
891 Operation::operand_range prevIndices = loadOp.getIndices();
892 loadIndices.append(in_start: prevIndices.begin(), in_end: prevIndices.end());
893 break;
894 }
895 }
896
897 // In case of broadcast: Use same indices to load from memref
898 // as before.
899 if (!xferOp.isBroadcastDim(0))
900 loadIndices.push_back(Elt: iv);
901 }
902
903 LogicalResult matchAndRewrite(OpTy xferOp,
904 PatternRewriter &rewriter) const override {
905 if (!xferOp->hasAttr(kPassLabel))
906 return rewriter.notifyMatchFailure(
907 xferOp, "kPassLabel is present (progressing lowering in progress)");
908
909 // Find and cast data buffer. How the buffer can be found depends on OpTy.
910 ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
911 Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
912 auto dataBufferType = dyn_cast<MemRefType>(Val: dataBuffer.getType());
913 FailureOr<MemRefType> castedDataType = unpackOneDim(type: dataBufferType);
914 if (failed(Result: castedDataType))
915 return rewriter.notifyMatchFailure(xferOp,
916 "Failed to unpack one vector dim.");
917
918 auto castedDataBuffer =
919 locB.create<vector::TypeCastOp>(args&: *castedDataType, args&: dataBuffer);
920
921 // If the xferOp has a mask: Find and cast mask buffer.
922 Value castedMaskBuffer;
923 if (xferOp.getMask()) {
924 Value maskBuffer = getMaskBuffer(xferOp);
925 if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
926 // Do not unpack a dimension of the mask, if:
927 // * To-be-unpacked transfer op dimension is a broadcast.
928 // * Mask is 1D, i.e., the mask cannot be further unpacked.
929 // (That means that all remaining dimensions of the transfer op must
930 // be broadcasted.)
931 castedMaskBuffer = maskBuffer;
932 } else {
933 // It's safe to assume the mask buffer can be unpacked if the data
934 // buffer was unpacked.
935 auto maskBufferType = cast<MemRefType>(Val: maskBuffer.getType());
936 MemRefType castedMaskType = *unpackOneDim(type: maskBufferType);
937 castedMaskBuffer =
938 locB.create<vector::TypeCastOp>(args&: castedMaskType, args&: maskBuffer);
939 }
940 }
941
942 // Loop bounds and step.
943 auto lb = locB.create<arith::ConstantIndexOp>(args: 0);
944 auto ub = locB.create<arith::ConstantIndexOp>(
945 args: castedDataType->getDimSize(idx: castedDataType->getRank() - 1));
946 auto step = locB.create<arith::ConstantIndexOp>(args: 1);
947 // TransferWriteOps that operate on tensors return the modified tensor and
948 // require a loop state.
949 auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
950
951 // Generate for loop.
952 auto result = locB.create<scf::ForOp>(
953 lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
954 [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
955 Type stateType = loopState.empty() ? Type() : loopState[0].getType();
956
957 auto result = generateInBoundsCheck(
958 b, xferOp, iv, unpackedDim(xferOp),
959 stateType ? TypeRange(stateType) : TypeRange(),
960 /*inBoundsCase=*/
961 [&](OpBuilder &b, Location loc) {
962 // Create new transfer op.
963 OpTy newXfer = Strategy<OpTy>::rewriteOp(
964 b, this->options, xferOp, castedDataBuffer, iv, loopState);
965
966 // If old transfer op has a mask: Set mask on new transfer op.
967 // Special case: If the mask of the old transfer op is 1D and
968 // the unpacked dim is not a broadcast, no mask is needed on
969 // the new transfer op.
970 if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
971 xferOp.getMaskType().getRank() > 1)) {
972 OpBuilder::InsertionGuard guard(b);
973 b.setInsertionPoint(newXfer); // Insert load before newXfer.
974
975 SmallVector<Value, 8> loadIndices;
976 getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
977 loadIndices, iv);
978 auto mask = b.create<memref::LoadOp>(location: loc, args&: castedMaskBuffer,
979 args&: loadIndices);
980 rewriter.modifyOpInPlace(newXfer, [&]() {
981 newXfer.getMaskMutable().assign(mask);
982 });
983 }
984
985 return loopState.empty() ? Value() : newXfer->getResult(0);
986 },
987 /*outOfBoundsCase=*/
988 [&](OpBuilder &b, Location /*loc*/) {
989 return Strategy<OpTy>::handleOutOfBoundsDim(
990 b, xferOp, castedDataBuffer, iv, loopState);
991 });
992
993 maybeYieldValue(b, loc, !loopState.empty(), result);
994 });
995
996 Strategy<OpTy>::cleanup(rewriter, xferOp, result);
997 return success();
998 }
999};
1000
1001/// Retrieves the dimensions sizes of a mask. Currently supports CreateMaskOp
1002/// and ConstantMaskOp.
1003template <typename VscaleConstantBuilder>
1004static FailureOr<SmallVector<OpFoldResult>>
1005getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
1006 if (!mask)
1007 return SmallVector<OpFoldResult>{};
1008 if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
1009 return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
1010 return OpFoldResult(dimSize);
1011 });
1012 }
1013 if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
1014 int dimIdx = 0;
1015 VectorType maskType = constantMask.getVectorType();
1016 auto indexType = IndexType::get(context: mask.getContext());
1017 return llvm::map_to_vector(
1018 constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
1019 // A scalable dim in a constant_mask means vscale x dimSize.
1020 if (maskType.getScalableDims()[dimIdx++])
1021 return OpFoldResult(createVscaleMultiple(dimSize));
1022 return OpFoldResult(IntegerAttr::get(type: indexType, value: dimSize));
1023 });
1024 }
1025 return failure();
1026}
1027
1028/// Scalable vector lowering of transfer_write(transpose). This lowering only
1029/// supports rank 2 (scalable) vectors, but can be used in conjunction with
1030/// `UnrollTransferWriteConversion` to support n-D cases. The unroll conversion
1031/// unrolls until the first scalable dimension.
1032///
1033/// Example:
1034///
1035/// BEFORE:
1036/// ```mlir
1037/// %transpose = vector.transpose %vec, [1, 0]
1038/// : vector<4x[4]xf32> to vector<[4]x4xf32>
1039/// vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]}
1040/// : vector<[4]x4xf32>, memref<?x?xf32>
1041/// ```
1042///
1043/// AFTER:
1044/// ```mlir
1045/// %c1 = arith.constant 1 : index
1046/// %c4 = arith.constant 4 : index
1047/// %c0 = arith.constant 0 : index
1048/// %0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32>
1049/// %1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32>
1050/// %2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32>
1051/// %3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32>
1052/// %vscale = vector.vscale
1053/// %c4_vscale = arith.muli %vscale, %c4 : index
1054/// scf.for %idx = %c0 to %c4_vscale step %c1 {
1055/// %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
1056/// %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
1057/// %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
1058/// %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
1059/// %slice_i = affine.apply #map(%idx)[%i]
1060/// %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
1061/// vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
1062/// : vector<4xf32>, memref<?x?xf32>
1063/// }
1064/// ```
struct ScalableTransposeTransferWriteConversion
    : VectorToSCFPattern<vector::TransferWriteOp> {
  using VectorToSCFPattern::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // Defer to `checkLowerTensors` for whether tensor destinations are
    // supported under the current options.
    if (failed(Result: checkLowerTensors(xferOp: writeOp, rewriter)))
      return failure();

    VectorType vectorType = writeOp.getVectorType();

    // Note: By comparing the scalable dims to an ArrayRef of length two this
    // implicitly checks the rank (is also two).
    ArrayRef<bool> scalableFlags = vectorType.getScalableDims();
    if (scalableFlags != ArrayRef<bool>{true, false}) {
      return rewriter.notifyMatchFailure(
          arg&: writeOp, msg: "expected vector of the form vector<[N]xMxty>");
    }

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isIdentity()) {
      return rewriter.notifyMatchFailure(
          arg&: writeOp, msg: "non-identity permutations are unsupported (lower first)");
    }

    // Note: This pattern is only lowering the leading dimension (to a loop),
    // so we only check if the leading dimension is in bounds. The in-bounds
    // attribute for the trailing dimension will be propagated.
    if (!writeOp.isDimInBounds(dim: 0)) {
      return rewriter.notifyMatchFailure(
          arg&: writeOp, msg: "out-of-bounds dims are unsupported (use masking)");
    }

    // The written vector must come directly from a [1, 0] transpose.
    Value vector = writeOp.getVector();
    auto transposeOp = vector.getDefiningOp<vector::TransposeOp>();
    if (!transposeOp ||
        transposeOp.getPermutation() != ArrayRef<int64_t>{1, 0}) {
      return rewriter.notifyMatchFailure(arg&: writeOp, msg: "source not transpose");
    }

    auto loc = writeOp.getLoc();
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    // Resolve the mask (if any) to per-dimension sizes; bail out on
    // unsupported mask-producing ops.
    auto maskDims = getMaskDimSizes(mask: writeOp.getMask(), createVscaleMultiple);
    if (failed(Result: maskDims)) {
      return rewriter.notifyMatchFailure(arg&: writeOp,
                                         msg: "failed to resolve mask dims");
    }

    int64_t fixedDimSize = vectorType.getDimSize(idx: 1);
    auto fixedDimOffsets = llvm::seq(Size: fixedDimSize);

    // Extract all slices from the source of the transpose.
    auto transposeSource = transposeOp.getVector();
    SmallVector<Value> transposeSourceSlices =
        llvm::map_to_vector(C&: fixedDimOffsets, F: [&](int64_t idx) -> Value {
          return rewriter.create<vector::ExtractOp>(location: loc, args&: transposeSource, args&: idx);
        });

    // Loop bounds and step. The upper bound is the (masked) scalable dim
    // size: either vscale * dimSize, or the leading mask dim.
    auto lb = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
    auto ub =
        maskDims->empty()
            ? Value(createVscaleMultiple(vectorType.getDimSize(idx: 0)))
            : vector::getAsValues(builder&: rewriter, loc, foldResults: maskDims->front()).front();
    auto step = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1);

    // Generate a new mask for the slice.
    VectorType sliceType = VectorType::Builder(vectorType).dropDim(pos: 0);
    Value sliceMask = nullptr;
    if (!maskDims->empty()) {
      sliceMask = rewriter.create<vector::CreateMaskOp>(
          location: loc, args: sliceType.clone(elementType: rewriter.getI1Type()),
          args: ArrayRef<OpFoldResult>(*maskDims).drop_front());
    }

    // Tensor destinations are threaded through the loop as an iter_arg;
    // memref destinations are written in place (no loop state).
    Value initDest = isTensorOp(xferOp: writeOp) ? writeOp.getBase() : Value{};
    ValueRange initLoopArgs = initDest ? initDest : ValueRange{};
    auto result = rewriter.create<scf::ForOp>(
        location: loc, args&: lb, args&: ub, args&: step, args&: initLoopArgs,
        args: [&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) {
          // Indices for the new transfer op.
          SmallVector<Value, 8> xferIndices;
          getXferIndices(b, xferOp: writeOp, iv, indices&: xferIndices);

          // Extract a transposed slice from the source vector.
          SmallVector<Value> transposeElements =
              llvm::map_to_vector(C&: fixedDimOffsets, F: [&](int64_t idx) -> Value {
                return b.create<vector::ExtractOp>(
                    location: loc, args&: transposeSourceSlices[idx], args&: iv);
              });
          auto sliceVec = b.create<vector::FromElementsOp>(location: loc, args&: sliceType,
                                                           args&: transposeElements);

          // Create the transfer_write for the slice.
          Value dest =
              loopIterArgs.empty() ? writeOp.getBase() : loopIterArgs.front();
          auto newWriteOp = b.create<vector::TransferWriteOp>(
              location: loc, args&: sliceVec, args&: dest, args&: xferIndices,
              args: ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front());
          if (sliceMask)
            newWriteOp.getMaskMutable().assign(value: sliceMask);

          // Yield from the loop.
          b.create<scf::YieldOp>(location: loc, args: loopIterArgs.empty()
                                          ? ValueRange{}
                                          : newWriteOp.getResult());
        });

    if (isTensorOp(xferOp: writeOp))
      rewriter.replaceOp(op: writeOp, newOp: result);
    else
      rewriter.eraseOp(op: writeOp);

    return success();
  }
};
1183
1184} // namespace lowering_n_d
1185
1186namespace lowering_n_d_unrolled {
1187
1188/// If the original transfer op has a mask, compute the mask of the new transfer
1189/// op (for the current iteration `i`) and assign it.
1190template <typename OpTy>
1191static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
1192 int64_t i) {
1193 if (!xferOp.getMask())
1194 return;
1195
1196 if (xferOp.isBroadcastDim(0)) {
1197 // To-be-unpacked dimension is a broadcast, which does not have a
1198 // corresponding mask dimension. Mask attribute remains unchanged.
1199 newXferOp.getMaskMutable().assign(xferOp.getMask());
1200 return;
1201 }
1202
1203 if (xferOp.getMaskType().getRank() > 1) {
1204 // Unpack one dimension of the mask.
1205 OpBuilder::InsertionGuard guard(b);
1206 b.setInsertionPoint(newXferOp); // Insert load before newXfer.
1207
1208 llvm::SmallVector<int64_t, 1> indices({i});
1209 Location loc = xferOp.getLoc();
1210 auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices);
1211 newXferOp.getMaskMutable().assign(newMask);
1212 }
1213
1214 // If we end up here: The mask of the old transfer op is 1D and the unpacked
1215 // dim is not a broadcast, so no mask is needed on the new transfer op.
1216 // `generateInBoundsCheck` will have evaluated the mask already.
1217}
1218
1219/// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
1220/// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
1221/// memref buffer is allocated and the SCF loop is fully unrolled.
1222///
1223/// ```
1224/// E.g.:
1225/// ```
1226/// %vec = vector.transfer_read %A[%a, %b, %c], %padding
1227/// : memref<?x?x?xf32>, vector<5x4xf32>
1228/// ```
1229/// is rewritten to IR such as (simplified):
1230/// ```
1231/// %v_init = splat %padding : vector<5x4xf32>
1232/// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
1233/// : memref<?x?x?xf32>, vector<4xf32>
1234/// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
1235/// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
1236/// : memref<?x?x?xf32>, vector<4xf32>
1237/// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
1238/// ...
1239/// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
1240/// : memref<?x?x?xf32>, vector<4xf32>
/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
1242/// ```
1243///
1244/// Note: As an optimization, if the result of the original TransferReadOp
1245/// was directly inserted into another vector, no new %v_init vector is created.
1246/// Instead, the new TransferReadOp results are inserted into that vector.
struct UnrollTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Get or build the vector into which the newly created TransferReadOp
  /// results are inserted. Reuses the destination of a single
  /// vector.insert user if present; otherwise builds a vector filled with
  /// the padding value.
  Value buildResultVector(PatternRewriter &rewriter,
                          TransferReadOp xferOp) const {
    if (auto insertOp = getInsertOp(xferOp))
      return insertOp.getDest();
    Location loc = xferOp.getLoc();
    return rewriter.create<vector::SplatOp>(location: loc, args: xferOp.getVectorType(),
                                            args: xferOp.getPadding());
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation.
  vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
    if (xferOp->hasOneUse()) {
      Operation *xferOpUser = *xferOp->getUsers().begin();
      if (auto insertOp = dyn_cast<vector::InsertOp>(Val: xferOpUser))
        return insertOp;
    }

    // No single InsertOp user: return a null op.
    return vector::InsertOp();
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation's indices.
  void getInsertionIndices(TransferReadOp xferOp,
                           SmallVectorImpl<OpFoldResult> &indices) const {
    if (auto insertOp = getInsertOp(xferOp)) {
      auto pos = insertOp.getMixedPosition();
      indices.append(in_start: pos.begin(), in_end: pos.end());
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return rewriter.notifyMatchFailure(
          arg&: xferOp, msg: "vector rank is less or equal to target rank");
    if (failed(Result: checkLowerTensors(xferOp, rewriter)))
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return rewriter.notifyMatchFailure(
          arg&: xferOp, msg: "not yet supported: element type mismatch");
    auto xferVecType = xferOp.getVectorType();
    if (xferVecType.getScalableDims()[0]) {
      return rewriter.notifyMatchFailure(
          arg&: xferOp, msg: "scalable dimensions cannot be unrolled at compile time");
    }

    auto insertOp = getInsertOp(xferOp);
    auto vec = buildResultVector(rewriter, xferOp);
    auto vecType = dyn_cast<VectorType>(Val: vec.getType());

    // The new (smaller) transfer ops read vectors with the leading
    // dimension dropped.
    VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(pos: 0);

    int64_t dimSize = xferVecType.getShape()[0];

    // Generate fully unrolled loop of transfer ops: one iteration per index
    // along the unpacked (leading) dimension.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<arith::ConstantIndexOp>(location: loc, args&: i);

      // FIXME: Rename this lambda - it does much more than just
      // in-bounds-check generation.
      vec = generateInBoundsCheck(
          b&: rewriter, xferOp, iv, dim: unpackedDim(xferOp), resultTypes: TypeRange(vecType),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, indices&: xferIndices);

            // Indices for the new vector.insert op.
            SmallVector<OpFoldResult, 8> insertionIndices;
            getInsertionIndices(xferOp, indices&: insertionIndices);
            insertionIndices.push_back(Elt: rewriter.getIndexAttr(value: i));

            auto inBoundsAttr = dropFirstElem(b, attr: xferOp.getInBoundsAttr());

            auto newXferOp = b.create<vector::TransferReadOp>(
                location: loc, args&: newXferVecType, args: xferOp.getBase(), args&: xferIndices,
                args: AffineMapAttr::get(value: unpackedPermutationMap(b, xferOp)),
                args: xferOp.getPadding(), args: Value(), args&: inBoundsAttr);
            maybeAssignMask(b, xferOp, newXferOp, i);

            Value valToInser = newXferOp.getResult();
            if (newXferVecType.getRank() == 0) {
              // vector.insert does not accept rank-0 as the non-indexed
              // argument. Extract the scalar before inserting.
              valToInser = b.create<vector::ExtractOp>(location: loc, args&: valToInser,
                                                       args: SmallVector<int64_t>());
            }
            return b.create<vector::InsertOp>(location: loc, args&: valToInser, args&: vec,
                                              args&: insertionIndices);
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Loop through original (unmodified) vector.
            return vec;
          });
    }

    if (insertOp) {
      // Rewrite single user of the old TransferReadOp, which was an InsertOp.
      rewriter.replaceOp(op: insertOp, newValues: vec);
      rewriter.eraseOp(op: xferOp);
    } else {
      rewriter.replaceOp(op: xferOp, newValues: vec);
    }

    return success();
  }
};
1373
1374/// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
1375/// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
1376/// memref buffer is allocated and the SCF loop is fully unrolled.
1377///
1378/// ```
1379/// E.g.:
1380/// ```
1381/// vector.transfer_write %vec, %A[%a, %b, %c]
1382/// : vector<5x4xf32>, memref<?x?x?xf32>
1383/// ```
1384/// is rewritten to IR such as (simplified):
1385/// ```
1386/// %v0 = vector.extract %vec[0] : vector<4xf32> from vector<5x4xf32>
1387/// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
1388/// %v1 = vector.extract %vec[1] : vector<4xf32> from vector<5x4xf32>
1389/// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
1390/// ...
1391/// %v4 = vector.extract %vec[4] : vector<4xf32> from vector<5x4xf32>
1392/// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
1393/// ```
1394///
1395/// Note: As an optimization, if the vector of the original TransferWriteOp
1396/// was directly extracted from another vector via an ExtractOp `a`, extract
1397/// the vectors for the newly generated TransferWriteOps from `a`'s input. By
1398/// doing so, `a` may become dead, and the number of ExtractOps generated during
1399/// recursive application of this pattern will be minimal.
struct UnrollTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector from which newly generated ExtractOps will extract.
  /// If the written vector itself comes from an ExtractOp, extract from that
  /// op's source instead (which may render the original ExtractOp dead).
  Value getDataVector(TransferWriteOp xferOp) const {
    if (auto extractOp = getExtractOp(xferOp))
      return extractOp.getVector();
    return xferOp.getVector();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return it.
  vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
    if (auto *op = xferOp.getVector().getDefiningOp())
      return dyn_cast<vector::ExtractOp>(Val: op);
    // Block arguments have no defining op: return a null op.
    return vector::ExtractOp();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return its
  /// indices.
  void getExtractionIndices(TransferWriteOp xferOp,
                            SmallVectorImpl<OpFoldResult> &indices) const {
    if (auto extractOp = getExtractOp(xferOp)) {
      auto pos = extractOp.getMixedPosition();
      indices.append(in_start: pos.begin(), in_end: pos.end());
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    VectorType inputVectorTy = xferOp.getVectorType();

    if (inputVectorTy.getRank() <= options.targetRank)
      return failure();

    if (failed(Result: checkLowerTensors(xferOp, rewriter)))
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (inputVectorTy.getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto vec = getDataVector(xferOp);
    if (inputVectorTy.getScalableDims()[0]) {
      // Cannot unroll a scalable dimension at compile time.
      return failure();
    }

    int64_t dimSize = inputVectorTy.getShape()[0];
    Value source = xferOp.getBase(); // memref or tensor to be written to.
    auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();

    // Generate fully unrolled loop of transfer ops: one iteration per index
    // along the unpacked (leading) dimension.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<arith::ConstantIndexOp>(location: loc, args&: i);

      auto updatedSource = generateInBoundsCheck(
          b&: rewriter, xferOp, iv, dim: unpackedDim(xferOp),
          resultTypes: isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, indices&: xferIndices);

            // Indices for the new vector.extract op.
            SmallVector<OpFoldResult, 8> extractionIndices;
            getExtractionIndices(xferOp, indices&: extractionIndices);
            extractionIndices.push_back(Elt: b.getI64IntegerAttr(value: i));

            auto extracted =
                b.create<vector::ExtractOp>(location: loc, args&: vec, args&: extractionIndices);
            auto inBoundsAttr = dropFirstElem(b, attr: xferOp.getInBoundsAttr());
            Value xferVec;
            if (inputVectorTy.getRank() == 1) {
              // When target-rank=0, unrolling would cause the vector input
              // argument into `transfer_write` to become a scalar. We solve
              // this by broadcasting the scalar to a 0D vector.
              xferVec = b.create<vector::BroadcastOp>(
                  location: loc, args: VectorType::get(shape: {}, elementType: extracted.getType()), args&: extracted);
            } else {
              xferVec = extracted;
            }
            auto newXferOp = b.create<vector::TransferWriteOp>(
                location: loc, args&: sourceType, args&: xferVec, args&: source, args&: xferIndices,
                args: AffineMapAttr::get(value: unpackedPermutationMap(b, xferOp)), args: Value(),
                args&: inBoundsAttr);

            maybeAssignMask(b, xferOp, newXferOp, i);

            // On tensors, yield the updated tensor for the next iteration.
            return isTensorOp(xferOp) ? newXferOp->getResult(idx: 0) : Value();
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            return isTensorOp(xferOp) ? source : Value();
          });

      if (isTensorOp(xferOp))
        source = updatedSource;
    }

    if (isTensorOp(xferOp))
      rewriter.replaceOp(op: xferOp, newValues: source);
    else
      rewriter.eraseOp(op: xferOp);

    return success();
  }
};
1518
1519} // namespace lowering_n_d_unrolled
1520
1521namespace lowering_1_d {
1522
1523/// Compute the indices into the memref for the LoadOp/StoreOp generated as
1524/// part of TransferOp1dConversion. Return the memref dimension on which
1525/// the transfer is operating. A return value of std::nullopt indicates a
1526/// broadcast.
1527template <typename OpTy>
1528static std::optional<int64_t>
1529get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
1530 SmallVector<Value, 8> &memrefIndices) {
1531 auto indices = xferOp.getIndices();
1532 auto map = xferOp.getPermutationMap();
1533 assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
1534
1535 memrefIndices.append(indices.begin(), indices.end());
1536 assert(map.getNumResults() == 1 &&
1537 "Expected 1 permutation map result for 1D transfer");
1538 if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
1539 Location loc = xferOp.getLoc();
1540 auto dim = expr.getPosition();
1541 AffineExpr d0, d1;
1542 bindDims(xferOp.getContext(), d0, d1);
1543 Value offset = memrefIndices[dim];
1544 memrefIndices[dim] =
1545 affine::makeComposedAffineApply(b, loc, e: d0 + d1, operands: {offset, iv});
1546 return dim;
1547 }
1548
1549 assert(xferOp.isBroadcastDim(0) &&
1550 "Expected AffineDimExpr or AffineConstantExpr");
1551 return std::nullopt;
1552}
1553
1554/// Codegen strategy for TransferOp1dConversion, depending on the
1555/// operation.
1556template <typename OpTy>
1557struct Strategy1d;
1558
/// Codegen strategy for TransferReadOp.
template <>
struct Strategy1d<TransferReadOp> {
  /// Emit the body of the scalar read loop: load one element from the memref
  /// (if in bounds) and insert it into the vector carried as loop state.
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferReadOp xferOp, Value iv,
                                  ValueRange loopState) {
    SmallVector<Value, 8> indices;
    // `dim` is std::nullopt for a broadcast dim; the in-bounds check below
    // then degenerates accordingly.
    auto dim = get1dMemrefIndices(b, xferOp, iv, memrefIndices&: indices);
    auto vec = loopState[0];

    // In case of out-of-bounds access, leave `vec` as is (was initialized with
    // padding value).
    auto nextVec = generateInBoundsCheck(
        b, xferOp, iv, dim, resultTypes: TypeRange(xferOp.getVectorType()),
        /*inBoundsCase=*/
        [&](OpBuilder &b, Location loc) {
          Value val = b.create<memref::LoadOp>(location: loc, args: xferOp.getBase(), args&: indices);
          return b.create<vector::InsertOp>(location: loc, args&: val, args&: vec, args&: iv);
        },
        /*outOfBoundsCase=*/
        [&](OpBuilder & /*b*/, Location loc) { return vec; });
    // Yield the (possibly updated) vector as the next iteration's loop state.
    b.create<scf::YieldOp>(location: loc, args&: nextVec);
  }

  /// Initial loop state: a vector filled with the padding value, so that
  /// lanes never written by the loop keep the padding.
  static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
    // Initialize vector with padding value.
    Location loc = xferOp.getLoc();
    return b.create<vector::SplatOp>(location: loc, args: xferOp.getVectorType(),
                                     args: xferOp.getPadding());
  }
};
1590
1591/// Codegen strategy for TransferWriteOp.
1592template <>
1593struct Strategy1d<TransferWriteOp> {
1594 static void generateForLoopBody(OpBuilder &b, Location loc,
1595 TransferWriteOp xferOp, Value iv,
1596 ValueRange /*loopState*/) {
1597 SmallVector<Value, 8> indices;
1598 auto dim = get1dMemrefIndices(b, xferOp, iv, memrefIndices&: indices);
1599
1600 // Nothing to do in case of out-of-bounds access.
1601 generateInBoundsCheck(
1602 b, xferOp, iv, dim,
1603 /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
1604 auto val = b.create<vector::ExtractOp>(location: loc, args: xferOp.getVector(), args&: iv);
1605 b.create<memref::StoreOp>(location: loc, args&: val, args: xferOp.getBase(), args&: indices);
1606 });
1607 b.create<scf::YieldOp>(location: loc);
1608 }
1609
1610 static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
1611 return Value();
1612 }
1613};
1614
/// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
/// necessary in cases where a 1D vector transfer op cannot be lowered into
/// vector load/stores due to non-unit strides or broadcasts:
///
/// * Transfer dimension is not the last memref dimension
/// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
/// * Memref has a layout map with non-unit stride on the last dimension
///
/// This pattern generates IR as follows:
///
/// 1. Generate a for loop iterating over each vector element.
/// 2. Inside the loop, generate a vector::InsertOp or vector::ExtractOp,
///    depending on OpTy.
///
/// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
///       can be generated instead of TransferOp1dConversion. Add such a pattern
///       to ConvertVectorToLLVM.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b]
///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
///    : vector<9xf32>, memref<?x?xf32>
/// ```
/// Is rewritten to approximately the following pseudo-IR:
/// ```
/// for i = 0 to 9 {
///   %t = vector.extract %vec[i] : vector<9xf32>
///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
/// }
/// ```
template <typename OpTy>
struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();
    auto map = xferOp.getPermutationMap();
    auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());

    // Only memref sources are handled by this pattern.
    if (!memRefType)
      return failure();
    if (xferOp.getVectorType().getRank() != 1)
      return failure();
    // Contiguous, minor-identity transfers lower directly to vector
    // load/store elsewhere; nothing to do here.
    if (map.isMinorIdentity() && memRefType.isLastDimUnitStride())
      return failure(); // Handled by ConvertVectorToLLVM

    // Loop bounds, step, state...
    Location loc = xferOp.getLoc();
    auto vecType = xferOp.getVectorType();
    auto lb = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
    Value ub =
        rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
    // For scalable vectors the runtime trip count is dimSize * vscale.
    if (vecType.isScalable()) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(location: loc, args: rewriter.getIndexType());
      ub = rewriter.create<arith::MulIOp>(location: loc, args&: ub, args&: vscale);
    }
    auto step = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1);
    // Reads carry the accumulated vector as loop state; writes carry none
    // (initialLoopState returns a null Value for writes).
    auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);

    // Generate for loop.
    rewriter.replaceOpWithNewOp<scf::ForOp>(
        xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
        });

    return success();
  }
};
1689
1690} // namespace lowering_1_d
1691} // namespace
1692
1693void mlir::populateVectorToSCFConversionPatterns(
1694 RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
1695 if (options.unroll) {
1696 patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
1697 lowering_n_d_unrolled::UnrollTransferWriteConversion>(
1698 arg: patterns.getContext(), args: options);
1699 } else {
1700 patterns.add<lowering_n_d::PrepareTransferReadConversion,
1701 lowering_n_d::PrepareTransferWriteConversion,
1702 lowering_n_d::TransferOpConversion<TransferReadOp>,
1703 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
1704 arg: patterns.getContext(), args: options);
1705 }
1706 if (options.lowerScalable) {
1707 patterns.add<lowering_n_d::ScalableTransposeTransferWriteConversion>(
1708 arg: patterns.getContext(), args: options);
1709 }
1710 if (options.targetRank == 1) {
1711 patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
1712 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
1713 arg: patterns.getContext(), args: options);
1714 }
1715 patterns.add<lowering_n_d::DecomposePrintOpConversion>(arg: patterns.getContext(),
1716 args: options);
1717}
1718
1719namespace {
1720
1721struct ConvertVectorToSCFPass
1722 : public impl::ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
1723 ConvertVectorToSCFPass() = default;
1724 ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
1725 this->fullUnroll = options.unroll;
1726 this->targetRank = options.targetRank;
1727 this->lowerTensors = options.lowerTensors;
1728 this->lowerScalable = options.lowerScalable;
1729 }
1730
1731 void runOnOperation() override {
1732 VectorTransferToSCFOptions options;
1733 options.unroll = fullUnroll;
1734 options.targetRank = targetRank;
1735 options.lowerTensors = lowerTensors;
1736 options.lowerScalable = lowerScalable;
1737
1738 // Lower permutation maps first.
1739 RewritePatternSet lowerTransferPatterns(&getContext());
1740 mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
1741 patterns&: lowerTransferPatterns);
1742 (void)applyPatternsGreedily(op: getOperation(),
1743 patterns: std::move(lowerTransferPatterns));
1744
1745 RewritePatternSet patterns(&getContext());
1746 populateVectorToSCFConversionPatterns(patterns, options);
1747 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
1748 }
1749};
1750
1751} // namespace
1752
/// Create an instance of the VectorToSCF conversion pass, configured with the
/// given lowering options.
std::unique_ptr<Pass>
mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
  return std::make_unique<ConvertVectorToSCFPass>(args: options);
}
1757

// source code of mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp