1//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent patterns to rewrite a vector.transfer
10// op into a fully in-bounds part and a partial part.
11//
12//===----------------------------------------------------------------------===//
13
14#include <optional>
15
16#include "mlir/Dialect/Affine/IR/AffineOps.h"
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/Linalg/IR/Linalg.h"
19#include "mlir/Dialect/MemRef/IR/MemRef.h"
20#include "mlir/Dialect/SCF/IR/SCF.h"
21#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
22
23#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
24#include "mlir/IR/PatternMatch.h"
25#include "mlir/Interfaces/VectorInterfaces.h"
26
27#include "llvm/ADT/STLExtras.h"
28
29#define DEBUG_TYPE "vector-transfer-split"
30
31using namespace mlir;
32using namespace mlir::vector;
33
34/// Build the condition to ensure that a particular VectorTransferOpInterface
35/// is in-bounds.
/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds.
/// Returns a null Value when every out-of-bounds-capable dimension folds to a
/// statically in-bounds access (i.e. no runtime check is needed).
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.getPermutationMap().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing(fun: [&](int64_t resultIdx, int64_t indicesIdx) {
    // Zip over the resulting vector shape and memref indices.
    // If the dimension is known to be in-bounds, it does not participate in
    // the construction of `inBoundsCond`.
    if (xferOp.isDimInBounds(dim: resultIdx))
      return;
    // Fold or create the check that `index + vector_size` <= `memref_size`.
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(idx: resultIdx);
    // sum = index + vectorSize, composed/folded as an affine apply so that
    // statically known indices produce a constant.
    OpFoldResult sum = affine::makeComposedFoldedAffineApply(
        b, loc, expr: b.getAffineDimExpr(position: 0) + b.getAffineConstantExpr(constant: vectorSize),
        operands: {xferOp.getIndices()[indicesIdx]});
    OpFoldResult dimSz =
        memref::getMixedSize(builder&: b, loc, value: xferOp.getBase(), dim: indicesIdx);
    // If both sides are compile-time constants and the access provably fits,
    // skip emitting a runtime comparison for this dimension.
    auto maybeCstSum = getConstantIntValue(ofr: sum);
    auto maybeCstDimSz = getConstantIntValue(ofr: dimSz);
    if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
      return;
    Value cond =
        b.create<arith::CmpIOp>(location: loc, args: arith::CmpIPredicate::sle,
                                args: getValueOrCreateConstantIndexOp(b, loc, ofr: sum),
                                args: getValueOrCreateConstantIndexOp(b, loc, ofr: dimSz));
    // Conjunction over all dims for which we are in-bounds.
    if (inBoundsCond)
      inBoundsCond = b.create<arith::AndIOp>(location: loc, args&: inBoundsCond, args&: cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}
71
72/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
73/// masking) fast path and a slow path.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
75/// newly created conditional upon function return.
76/// To accommodate for the fact that the original vector.transfer indexing may
77/// be arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
78/// scf.if op returns a view and values of type index.
79/// At this time, only vector.transfer_read case is implemented.
80///
81/// Example (a 2-D vector.transfer_read):
82/// ```
83/// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
84/// ```
85/// is transformed into:
86/// ```
87/// %1:3 = scf.if (%inBounds) {
88/// // fast path, direct cast
89/// memref.cast %A: memref<A...> to compatibleMemRefType
90/// scf.yield %view : compatibleMemRefType, index, index
91/// } else {
92/// // slow path, not in-bounds vector.transfer or linalg.copy.
93/// memref.cast %alloc: memref<B...> to compatibleMemRefType
94/// scf.yield %4 : compatibleMemRefType, index, index
95// }
96/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
97/// ```
98/// where `alloc` is a top of the function alloca'ed buffer of one vector.
99///
100/// Preconditions:
101/// 1. `xferOp.getPermutationMap()` must be a minor identity map
102/// 2. the rank of the `xferOp.memref()` and the rank of the
103/// `xferOp.getVector()` must be equal. This will be relaxed in the future
104/// but requires rank-reducing subviews.
105static LogicalResult
106splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
107 // TODO: support 0-d corner case.
108 if (xferOp.getTransferRank() == 0)
109 return failure();
110
111 // TODO: expand support to these 2 cases.
112 if (!xferOp.getPermutationMap().isMinorIdentity())
113 return failure();
114 // Must have some out-of-bounds dimension to be a candidate for splitting.
115 if (!xferOp.hasOutOfBoundsDim())
116 return failure();
117 // Don't split transfer operations directly under IfOp, this avoids applying
118 // the pattern recursively.
119 // TODO: improve the filtering condition to make it more applicable.
120 if (isa<scf::IfOp>(Val: xferOp->getParentOp()))
121 return failure();
122 return success();
123}
124
125/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
126/// be cast. If the MemRefTypes don't have the same rank or are not strided,
127/// return null; otherwise:
128/// 1. if `aT` and `bT` are cast-compatible, return `aT`.
129/// 2. else return a new MemRefType obtained by iterating over the shape and
130/// strides and:
131/// a. keeping the ones that are static and equal across `aT` and `bT`.
132/// b. using a dynamic shape and/or stride for the dimensions that don't
133/// agree.
134static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
135 if (memref::CastOp::areCastCompatible(inputs: aT, outputs: bT))
136 return aT;
137 if (aT.getRank() != bT.getRank())
138 return MemRefType();
139 int64_t aOffset, bOffset;
140 SmallVector<int64_t, 4> aStrides, bStrides;
141 if (failed(Result: aT.getStridesAndOffset(strides&: aStrides, offset&: aOffset)) ||
142 failed(Result: bT.getStridesAndOffset(strides&: bStrides, offset&: bOffset)) ||
143 aStrides.size() != bStrides.size())
144 return MemRefType();
145
146 ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
147 int64_t resOffset;
148 SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
149 resStrides(bT.getRank(), 0);
150 for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
151 resShape[idx] =
152 (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
153 resStrides[idx] =
154 (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
155 }
156 resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
157 return MemRefType::get(
158 shape: resShape, elementType: aT.getElementType(),
159 layout: StridedLayoutAttr::get(context: aT.getContext(), offset: resOffset, strides: resStrides));
160}
161
162/// Casts the given memref to a compatible memref type. If the source memref has
163/// a different address space than the target type, a `memref.memory_space_cast`
164/// is first inserted, followed by a `memref.cast`.
165static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
166 MemRefType compatibleMemRefType) {
167 MemRefType sourceType = cast<MemRefType>(Val: memref.getType());
168 Value res = memref;
169 if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
170 sourceType = MemRefType::get(
171 shape: sourceType.getShape(), elementType: sourceType.getElementType(),
172 layout: sourceType.getLayout(), memorySpace: compatibleMemRefType.getMemorySpace());
173 res = b.create<memref::MemorySpaceCastOp>(location: memref.getLoc(), args&: sourceType, args&: res);
174 }
175 if (sourceType == compatibleMemRefType)
176 return res;
177 return b.create<memref::CastOp>(location: memref.getLoc(), args&: compatibleMemRefType, args&: res);
178}
179
/// Operates under a scoped context to build the intersection between the
/// view `xferOp.getBase()` @ `xferOp.getIndices()` and the view `alloc`.
182// TODO: view intersection/union/differences should be a proper std op.
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == cast<MemRefType>(alloc.getType()).getRank() &&
         "Expected memref rank to match the alloc rank");
  // Seed `sizes` with the leading (non-transferred) indices.
  // NOTE(review): the indices themselves are used as subview sizes for the
  // leading dims here — confirm this matches the intended leading-dim
  // contract of the callers.
  ValueRange leadingIndices =
      xferOp.getIndices().take_front(n: xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult, 4> sizes;
  sizes.append(in_start: leadingIndices.begin(), in_end: leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(Val: xferOp);
  // For each transferred dimension, the intersected size is
  // min(dimMemRef - index, dimAlloc): what remains of the source past the
  // start index, clamped to the staging buffer's extent.
  xferOp.zipResultAndIndexing(fun: [&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef =
        b.create<memref::DimOp>(location: xferOp.getLoc(), args: xferOp.getBase(), args&: indicesIdx);
    Value dimAlloc = b.create<memref::DimOp>(location: loc, args&: alloc, args&: resultIdx);
    Value index = xferOp.getIndices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(ctx: xferOp.getContext(), exprs&: i, exprs&: j, exprs&: k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(exprsList: MapList{{i - j, k}}, context: b.getContext());
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin = b.create<affine::AffineMinOp>(
        location: loc, args: index.getType(), args&: maps[0], args: ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(Elt: affineMin);
  });

  // Source side starts at the transfer indices; staging side starts at the
  // origin. For a write, the roles of `alloc` and the original memref swap.
  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(Range: llvm::map_range(
      C: xferOp.getIndices(), F: [](Value idx) -> OpFoldResult { return idx; }));
  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(value: 0));
  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(value: 1));
  auto copySrc = b.create<memref::SubViewOp>(
      location: loc, args: isaWrite ? alloc : xferOp.getBase(), args&: srcIndices, args&: sizes, args&: strides);
  auto copyDest = b.create<memref::SubViewOp>(
      location: loc, args: isaWrite ? xferOp.getBase() : alloc, args&: destIndices, args&: sizes, args&: strides);
  return std::make_pair(x&: copySrc, y&: copyDest);
}
222
223/// Given an `xferOp` for which:
224/// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
225/// 2. a memref of single vector `alloc` has been allocated.
226/// Produce IR resembling:
227/// ```
228/// %1:3 = scf.if (%inBounds) {
229/// (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
230/// %view = memref.cast %A: memref<A...> to compatibleMemRefType
231/// scf.yield %view, ... : compatibleMemRefType, index, index
232/// } else {
233/// %2 = linalg.fill(%pad, %alloc)
234/// %3 = subview %view [...][...][...]
235/// %4 = subview %alloc [0, 0] [...] [...]
236/// linalg.copy(%3, %4)
237/// %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
238/// scf.yield %5, ... : compatibleMemRefType, index, index
239/// }
240/// ```
241/// Return the produced scf::IfOp.
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(location: loc, args: 0);
  Value memref = xferOp.getBase();
  return b.create<scf::IfOp>(
      location: loc, args&: inBoundsCond,
      args: [&](OpBuilder &b, Location loc) {
        // Fast path: yield the original memref (cast to the common type)
        // together with the original transfer indices.
        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        llvm::append_range(C&: viewAndIndices, R: xferOp.getIndices());
        b.create<scf::YieldOp>(location: loc, args&: viewAndIndices);
      },
      args: [&](OpBuilder &b, Location loc) {
        // Slow path: fill the staging buffer with the padding value, then
        // copy the in-bounds intersection of the source into it.
        b.create<linalg::FillOp>(location: loc, args: ValueRange{xferOp.getPadding()},
                                 args: ValueRange{alloc});
        // Take partial subview of memref which guarantees no dimension
        // overflows.
        IRRewriter rewriter(b);
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            b&: rewriter, xferOp: cast<VectorTransferOpInterface>(Val: xferOp.getOperation()),
            alloc);
        b.create<memref::CopyOp>(location: loc, args&: copyArgs.first, args&: copyArgs.second);
        // Yield the staging buffer, addressed from the origin (all-zero
        // indices, one per transferred dimension).
        Value casted =
            castToCompatibleMemRefType(b, memref: alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(I: viewAndIndices.end(), NumToInsert: xferOp.getTransferRank(),
                              Elt: zero);
        b.create<scf::YieldOp>(location: loc, args&: viewAndIndices);
      });
}
275
276/// Given an `xferOp` for which:
277/// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
278/// 2. a memref of single vector `alloc` has been allocated.
279/// Produce IR resembling:
280/// ```
281/// %1:3 = scf.if (%inBounds) {
282/// (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
283/// memref.cast %A: memref<A...> to compatibleMemRefType
284/// scf.yield %view, ... : compatibleMemRefType, index, index
285/// } else {
286/// %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
287/// %3 = vector.type_cast %extra_alloc :
288/// memref<...> to memref<vector<...>>
289/// store %2, %3[] : memref<vector<...>>
290/// %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
291/// scf.yield %4, ... : compatibleMemRefType, index, index
292/// }
293/// ```
294/// Return the produced scf::IfOp.
295static scf::IfOp createFullPartialVectorTransferRead(
296 RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
297 Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
298 Location loc = xferOp.getLoc();
299 scf::IfOp fullPartialIfOp;
300 Value zero = b.create<arith::ConstantIndexOp>(location: loc, args: 0);
301 Value memref = xferOp.getBase();
302 return b.create<scf::IfOp>(
303 location: loc, args&: inBoundsCond,
304 args: [&](OpBuilder &b, Location loc) {
305 Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
306 scf::ValueVector viewAndIndices{res};
307 llvm::append_range(C&: viewAndIndices, R: xferOp.getIndices());
308 b.create<scf::YieldOp>(location: loc, args&: viewAndIndices);
309 },
310 args: [&](OpBuilder &b, Location loc) {
311 Operation *newXfer = b.clone(op&: *xferOp.getOperation());
312 Value vector = cast<VectorTransferOpInterface>(Val: newXfer).getVector();
313 b.create<memref::StoreOp>(
314 location: loc, args&: vector,
315 args: b.create<vector::TypeCastOp>(
316 location: loc, args: MemRefType::get(shape: {}, elementType: vector.getType()), args&: alloc));
317
318 Value casted =
319 castToCompatibleMemRefType(b, memref: alloc, compatibleMemRefType);
320 scf::ValueVector viewAndIndices{casted};
321 viewAndIndices.insert(I: viewAndIndices.end(), NumToInsert: xferOp.getTransferRank(),
322 Elt: zero);
323 b.create<scf::YieldOp>(location: loc, args&: viewAndIndices);
324 });
325}
326
327/// Given an `xferOp` for which:
328/// 1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
329/// 2. a memref of single vector `alloc` has been allocated.
330/// Produce IR resembling:
331/// ```
332/// %1:3 = scf.if (%inBounds) {
333/// memref.cast %A: memref<A...> to compatibleMemRefType
334/// scf.yield %view, ... : compatibleMemRefType, index, index
335/// } else {
336/// %3 = vector.type_cast %extra_alloc :
337/// memref<...> to memref<vector<...>>
338/// %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
339/// scf.yield %4, ... : compatibleMemRefType, index, index
340/// }
341/// ```
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(location: loc, args: 0);
  Value memref = xferOp.getBase();
  // Returns (memref, indices...) to which the full vector can be written
  // in-bounds: the original destination on the fast path, the staging
  // `alloc` (at the origin) on the slow path.
  return b
      .create<scf::IfOp>(
          location: loc, args&: inBoundsCond,
          args: [&](OpBuilder &b, Location loc) {
            // Fast path: write straight to the original memref at the
            // original indices.
            Value res =
                castToCompatibleMemRefType(b, memref, compatibleMemRefType);
            scf::ValueVector viewAndIndices{res};
            llvm::append_range(C&: viewAndIndices, R: xferOp.getIndices());
            b.create<scf::YieldOp>(location: loc, args&: viewAndIndices);
          },
          args: [&](OpBuilder &b, Location loc) {
            // Slow path: write to the staging buffer at all-zero indices.
            Value casted =
                castToCompatibleMemRefType(b, memref: alloc, compatibleMemRefType);
            scf::ValueVector viewAndIndices{casted};
            viewAndIndices.insert(I: viewAndIndices.end(),
                                  NumToInsert: xferOp.getTransferRank(), Elt: zero);
            b.create<scf::YieldOp>(location: loc, args&: viewAndIndices);
          })
      ->getResults();
}
369
370/// Given an `xferOp` for which:
371/// 1. `inBoundsCond` has been computed.
372/// 2. a memref of single vector `alloc` has been allocated.
373/// 3. it originally wrote to %view
374/// Produce IR resembling:
375/// ```
376/// %notInBounds = arith.xori %inBounds, %true
377/// scf.if (%notInBounds) {
378/// %3 = subview %alloc [...][...][...]
379/// %4 = subview %view [0, 0][...][...]
380/// linalg.copy(%3, %4)
381/// }
382/// ```
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  // notInBounds = inBoundsCond xor true, i.e. the logical negation.
  auto notInBounds = b.create<arith::XOrIOp>(
      location: loc, args&: inBoundsCond, args: b.create<arith::ConstantIntOp>(location: loc, args: true, args: 1));
  b.create<scf::IfOp>(location: loc, args&: notInBounds, args: [&](OpBuilder &b, Location loc) {
    // Slow path only: copy the valid intersection of the staging buffer back
    // into the original destination view.
    IRRewriter rewriter(b);
    std::pair<Value, Value> copyArgs = createSubViewIntersection(
        b&: rewriter, xferOp: cast<VectorTransferOpInterface>(Val: xferOp.getOperation()),
        alloc);
    b.create<memref::CopyOp>(location: loc, args&: copyArgs.first, args&: copyArgs.second);
    b.create<scf::YieldOp>(location: loc, args: ValueRange{});
  });
}
398
399/// Given an `xferOp` for which:
400/// 1. `inBoundsCond` has been computed.
401/// 2. a memref of single vector `alloc` has been allocated.
402/// 3. it originally wrote to %view
403/// Produce IR resembling:
404/// ```
405/// %notInBounds = arith.xori %inBounds, %true
406/// scf.if (%notInBounds) {
407/// %2 = load %alloc : memref<vector<...>>
408/// vector.transfer_write %2, %view[...] : memref<A...>, vector<...>
409/// }
410/// ```
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  // notInBounds = inBoundsCond xor true, i.e. the logical negation.
  auto notInBounds = b.create<arith::XOrIOp>(
      location: loc, args&: inBoundsCond, args: b.create<arith::ConstantIntOp>(location: loc, args: true, args: 1));
  b.create<scf::IfOp>(location: loc, args&: notInBounds, args: [&](OpBuilder &b, Location loc) {
    IRMapping mapping;
    // Load the full vector back out of the staging alloca (viewed as a
    // memref<vector<...>> through vector.type_cast).
    Value load = b.create<memref::LoadOp>(
        location: loc,
        args: b.create<vector::TypeCastOp>(
            location: loc, args: MemRefType::get(shape: {}, elementType: xferOp.getVector().getType()), args&: alloc),
        args: ValueRange());
    // Clone the original transfer_write, substituting the loaded vector for
    // its value operand, so it writes back to the original view.
    mapping.map(from: xferOp.getVector(), to: load);
    b.clone(op&: *xferOp.getOperation(), mapper&: mapping);
    b.create<scf::YieldOp>(location: loc, args: ValueRange{});
  });
}
430
431// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
432static Operation *getAutomaticAllocationScope(Operation *op) {
433 // Find the closest surrounding allocation scope that is not a known looping
434 // construct (putting alloca's in loops doesn't always lower to deallocation
435 // until the end of the loop).
436 Operation *scope = nullptr;
437 for (Operation *parent = op->getParentOp(); parent != nullptr;
438 parent = parent->getParentOp()) {
439 if (parent->hasTrait<OpTrait::AutomaticAllocationScope>())
440 scope = parent;
441 if (!isa<scf::ForOp, affine::AffineForOp>(Val: parent))
442 break;
443 }
444 assert(scope && "Expected op to be inside automatic allocation scope");
445 return scope;
446}
447
448/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
449/// masking) fastpath and a slowpath.
450///
451/// For vector.transfer_read:
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate for the fact that the original vector.transfer indexing may
455/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
456/// scf.if op returns a view and values of type index.
457///
458/// Example (a 2-D vector.transfer_read):
459/// ```
460/// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
461/// ```
462/// is transformed into:
463/// ```
464/// %1:3 = scf.if (%inBounds) {
465/// // fastpath, direct cast
466/// memref.cast %A: memref<A...> to compatibleMemRefType
467/// scf.yield %view : compatibleMemRefType, index, index
468/// } else {
469/// // slowpath, not in-bounds vector.transfer or linalg.copy.
470/// memref.cast %alloc: memref<B...> to compatibleMemRefType
471/// scf.yield %4 : compatibleMemRefType, index, index
472// }
473/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
474/// ```
475/// where `alloc` is a top of the function alloca'ed buffer of one vector.
476///
477/// For vector.transfer_write:
478/// There are 2 conditional blocks. First a block to decide which memref and
479/// indices to use for an unmasked, inbounds write. Then a conditional block to
480/// further copy a partial buffer into the final result in the slow path case.
481///
482/// Example (a 2-D vector.transfer_write):
483/// ```
484/// vector.transfer_write %arg, %0[...], %pad : memref<A...>, vector<...>
485/// ```
486/// is transformed into:
487/// ```
488/// %1:3 = scf.if (%inBounds) {
489/// memref.cast %A: memref<A...> to compatibleMemRefType
490/// scf.yield %view : compatibleMemRefType, index, index
491/// } else {
492/// memref.cast %alloc: memref<B...> to compatibleMemRefType
493/// scf.yield %4 : compatibleMemRefType, index, index
494/// }
495/// %0 = vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ...
496/// true]}
497/// scf.if (%notInBounds) {
498/// // slowpath: not in-bounds vector.transfer or linalg.copy.
499/// }
500/// ```
501/// where `alloc` is a top of the function alloca'ed buffer of one vector.
502///
503/// Preconditions:
504/// 1. `xferOp.getPermutationMap()` must be a minor identity map
505/// 2. the rank of the `xferOp.getBase()` and the rank of the
506/// `xferOp.getVector()` must be equal. This will be relaxed in the future
507/// but requires rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  // An all-true in_bounds attribute, reused for every rewritten transfer.
  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(values: bools);
  // ForceInBounds: just stamp the attribute, no splitting.
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    b.modifyOpInPlace(root: xferOp, callable: [&]() {
      xferOp->setAttr(name: xferOp.getInBoundsAttrName(), value: inBoundsAttr);
    });
    return success();
  }

  // Assert preconditions. Additionally, keep the variables in an inner scope to
  // ensure they aren't used in the wrong scopes further down.
  {
    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
           "Expected splitFullAndPartialTransferPrecondition to hold");

    auto xferReadOp = dyn_cast<vector::TransferReadOp>(Val: xferOp.getOperation());
    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(Val: xferOp.getOperation());

    // Only transfer_read/transfer_write are supported; masked transfers are
    // not.
    if (!(xferReadOp || xferWriteOp))
      return failure();
    if (xferWriteOp && xferWriteOp.getMask())
      return failure();
    if (xferReadOp && xferReadOp.getMask())
      return failure();
  }

  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(xferOp);
  // A null condition means the access is statically in-bounds: nothing to do.
  Value inBoundsCond = createInBoundsCond(
      b, xferOp: cast<VectorTransferOpInterface>(Val: xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Top of the function `alloc` for transient storage.
  Value alloc;
  {
    RewriterBase::InsertionGuard guard(b);
    Operation *scope = getAutomaticAllocationScope(op: xferOp);
    assert(scope->getNumRegions() == 1 &&
           "AutomaticAllocationScope with >1 regions");
    b.setInsertionPointToStart(&scope->getRegion(index: 0).front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
    // One vector worth of storage, 32-byte aligned.
    alloc = b.create<memref::AllocaOp>(location: scope->getLoc(),
                                       args: MemRefType::get(shape, elementType),
                                       args: ValueRange{}, args: b.getI64IntegerAttr(value: 32));
  }

  // Common type both the original memref and `alloc` can be cast to; bail if
  // none exists.
  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(aT: cast<MemRefType>(Val: xferOp.getShapedType()),
                                  bT: cast<MemRefType>(Val: alloc.getType()));
  if (!compatibleMemRefType)
    return failure();

  // The scf.if yields (memref, index...), one index per transferred dim.
  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(Val: xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferOp: xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferOp: xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Set existing read op to in-bounds, it always reads from a full buffer.
    // Rewire the source memref and indices to the scf.if results.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, value: fullPartialIfOp.getResult(i));

    b.modifyOpInPlace(root: xferOp, callable: [&]() {
      xferOp->setAttr(name: xferOp.getInBoundsAttrName(), value: inBoundsAttr);
    });

    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(Val: xferOp.getOperation());

  // Decide which location to write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferOp: xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in bounds write to either the output or the extra allocated buffer.
  // The operation is cloned to prevent deleting information needed for the
  // later IR creation.
  IRMapping mapping;
  mapping.map(from: xferWriteOp.getBase(), to: memrefAndIndices.front());
  mapping.map(from: xferWriteOp.getIndices(), to: memrefAndIndices.drop_front());
  auto *clone = b.clone(op&: *xferWriteOp, mapper&: mapping);
  clone->setAttr(name: xferWriteOp.getInBoundsAttrName(), value: inBoundsAttr);

  // Create a potential copy from the allocated buffer to the final output in
  // the slow path case.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferOp: xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferOp: xferWriteOp, inBoundsCond, alloc);

  // The original (possibly out-of-bounds) write has been fully replaced.
  b.eraseOp(op: xferOp);

  return success();
}
623
624namespace {
625/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
626/// may take an extra filter to perform selection at a finer granularity.
/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
/// may take an extra filter to perform selection at a finer granularity.
struct VectorTransferFullPartialRewriter : public RewritePattern {
  /// Predicate deciding whether a given transfer op should be split.
  using FilterConstraintType =
      std::function<LogicalResult(VectorTransferOpInterface op)>;

  /// `filter` defaults to accepting every candidate transfer op.
  explicit VectorTransferFullPartialRewriter(
      MLIRContext *context,
      VectorTransformsOptions options = VectorTransformsOptions(),
      FilterConstraintType filter =
          [](VectorTransferOpInterface op) { return success(); },
      PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
        filter(std::move(filter)) {}

  /// Performs the rewrite.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  // Split strategy (None / ForceInBounds / VectorTransfer / LinalgCopy).
  VectorTransformsOptions options;
  // User-provided selection predicate, applied after the preconditions.
  FilterConstraintType filter;
};
648
649} // namespace
650
651LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
652 Operation *op, PatternRewriter &rewriter) const {
653 auto xferOp = dyn_cast<VectorTransferOpInterface>(Val: op);
654 if (!xferOp || failed(Result: splitFullAndPartialTransferPrecondition(xferOp)) ||
655 failed(Result: filter(xferOp)))
656 return failure();
657 return splitFullAndPartialTransfer(b&: rewriter, xferOp, options);
658}
659
660void mlir::vector::populateVectorTransferFullPartialPatterns(
661 RewritePatternSet &patterns, const VectorTransformsOptions &options) {
662 patterns.add<VectorTransferFullPartialRewriter>(arg: patterns.getContext(),
663 args: options);
664}
665

source code of mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp