//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent patterns to rewrite a vector.transfer
// op into a fully in-bounds part and a partial part.
//
//===----------------------------------------------------------------------===//

#include <optional>
#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-transfer-split"

using namespace mlir;
using namespace mlir::vector;

/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds.
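///
/// For a single potentially out-of-bounds dimension with vector size 4, the
/// generated check resembles the following (an illustrative sketch; when both
/// operands are static the check folds away entirely):
/// ```
///    %end  = affine.apply affine_map<(d0) -> (d0 + 4)>(%i)
///    %dim  = memref.dim %source, %c0 : memref<?xf32>
///    %cond = arith.cmpi sle, %end, %dim : index
/// ```
/// Checks for multiple dimensions are combined with `arith.andi`.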
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.getPermutationMap().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Zip over the resulting vector shape and memref indices.
    // If the dimension is known to be in-bounds, it does not participate in
    // the construction of `inBoundsCond`.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Fold or create the check that `index + vector_size` <= `memref_size`.
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    OpFoldResult sum = affine::makeComposedFoldedAffineApply(
        b, loc, b.getAffineDimExpr(0) + b.getAffineConstantExpr(vectorSize),
        {xferOp.getIndices()[indicesIdx]});
    OpFoldResult dimSz =
        memref::getMixedSize(b, loc, xferOp.getSource(), indicesIdx);
    auto maybeCstSum = getConstantIntValue(sum);
    auto maybeCstDimSz = getConstantIntValue(dimSz);
    if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
      return;
    Value cond =
        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle,
                                getValueOrCreateConstantIndexOp(b, loc, sum),
                                getValueOrCreateConstantIndexOp(b, loc, dimSz));
    // Conjunction over all dims for which we are in-bounds.
    if (inBoundsCond)
      inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate for the fact that the original vector.transfer indexing may
/// be arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fast path, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slow path, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top of the function alloca'ed buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.getPermutationMap()` must be a minor identity map
///  2. the rank of the `xferOp.getSource()` and the rank of the
///     `xferOp.getVector()` must be equal. This will be relaxed in the future
///     but requires rank-reducing subviews.
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // TODO: support 0-d corner case.
  if (xferOp.getTransferRank() == 0)
    return failure();

  // TODO: expand support to these 2 cases.
  if (!xferOp.getPermutationMap().isMinorIdentity())
    return failure();
  // Must have some out-of-bounds dimension to be a candidate for splitting.
  if (!xferOp.hasOutOfBoundsDim())
    return failure();
  // Don't split transfer operations directly under IfOp; this avoids applying
  // the pattern recursively.
  // TODO: improve the filtering condition to make it more applicable.
  if (isa<scf::IfOp>(xferOp->getParentOp()))
    return failure();
  return success();
}

/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. If the MemRefTypes don't have the same rank or are not strided,
/// return null; otherwise:
///   1. if `aT` and `bT` are cast-compatible, return `aT`.
///   2. else return a new MemRefType obtained by iterating over the shape and
///      strides and:
///        a. keeping the ones that are static and equal across `aT` and `bT`.
///        b. using a dynamic shape and/or stride for the dimensions that don't
///           agree.
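///
/// For instance (an illustrative sketch; the exact result depends on the
/// layouts involved): given `memref<4x8xf32, strided<[8, 1]>>` and
/// `memref<4x16xf32, strided<[16, 1]>>`, the result is
/// `memref<4x?xf32, strided<[?, 1]>>`: the leading size and the innermost
/// stride agree and stay static, while the mismatching size and stride become
/// dynamic.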
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
      failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();

  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
    resStrides[idx] =
        (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
  }
  resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
  return MemRefType::get(
      resShape, aT.getElementType(),
      StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
}

/// Casts the given memref to a compatible memref type. If the source memref
/// has a different address space than the target type, a
/// `memref.memory_space_cast` is first inserted, followed by a `memref.cast`.
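///
/// For a source in address space 3, the emitted ops resemble (an illustrative
/// sketch; the concrete types shown are made up):
/// ```
///    %0 = memref.memory_space_cast %src : memref<4x?xf32, 3> to memref<4x?xf32>
///    %1 = memref.cast %0
///           : memref<4x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
/// ```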
static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
                                        MemRefType compatibleMemRefType) {
  MemRefType sourceType = cast<MemRefType>(memref.getType());
  Value res = memref;
  if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
    sourceType = MemRefType::get(
        sourceType.getShape(), sourceType.getElementType(),
        sourceType.getLayout(), compatibleMemRefType.getMemorySpace());
    res = b.create<memref::MemorySpaceCastOp>(memref.getLoc(), sourceType, res);
  }
  if (sourceType == compatibleMemRefType)
    return res;
  return b.create<memref::CastOp>(memref.getLoc(), compatibleMemRefType, res);
}

/// Operates under a scoped context to build the intersection between the
/// view `xferOp.getSource()` @ `xferOp.getIndices()` and the view `alloc`.
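///
/// For a 1-D transfer, the computed intersection resembles the following (an
/// illustrative sketch; `%i` is the transfer index and names are made up):
/// ```
///    %dimSrc = memref.dim %source, %c0 : memref<?xf32>
///    %dimAlloc = memref.dim %alloc, %c0 : memref<4xf32>
///    // min(%dimSrc - %i, %dimAlloc)
///    %sz = affine.min affine_map<(d0, d1, d2) -> (d0 - d1, d2)>
///            (%dimSrc, %i, %dimAlloc)
///    %src = memref.subview %source[%i] [%sz] [1]
///    %dst = memref.subview %alloc[0] [%sz] [1]
/// ```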
// TODO: view intersection/union/differences should be a proper std op.
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == cast<MemRefType>(alloc.getType()).getRank() &&
         "Expected memref rank to match the alloc rank");
  ValueRange leadingIndices =
      xferOp.getIndices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult, 4> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef = b.create<memref::DimOp>(xferOp.getLoc(),
                                              xferOp.getSource(), indicesIdx);
    Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
    Value index = xferOp.getIndices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}}, b.getContext());
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin = b.create<affine::AffineMinOp>(
        loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
      xferOp.getIndices(), [](Value idx) -> OpFoldResult { return idx; }));
  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
  auto copySrc = b.create<memref::SubViewOp>(
      loc, isaWrite ? alloc : xferOp.getSource(), srcIndices, sizes, strides);
  auto copyDest = b.create<memref::SubViewOp>(
      loc, isaWrite ? xferOp.getSource() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = linalg.fill(%pad, %alloc)
///      %3 = subview %view [...][...][...]
///      %4 = subview %alloc [0, 0] [...] [...]
///      linalg.copy(%3, %4)
///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %5, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
                              xferOp.getIndices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        b.create<linalg::FillOp>(loc, ValueRange{xferOp.getPadding()},
                                 ValueRange{alloc});
        // Take partial subview of memref which guarantees no dimension
        // overflows.
        IRRewriter rewriter(b);
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
            alloc);
        b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
        Value casted =
            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      store %2, %3[] : memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  scf::IfOp fullPartialIfOp;
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
                              xferOp.getIndices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        Operation *newXfer = b.clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).getVector();
        b.create<memref::StoreOp>(
            loc, vector,
            b.create<vector::TypeCastOp>(
                loc, MemRefType::get({}, vector.getType()), alloc));

        Value casted =
            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b
      .create<scf::IfOp>(
          loc, inBoundsCond,
          [&](OpBuilder &b, Location loc) {
            Value res =
                castToCompatibleMemRefType(b, memref, compatibleMemRefType);
            scf::ValueVector viewAndIndices{res};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
            b.create<scf::YieldOp>(loc, viewAndIndices);
          },
          [&](OpBuilder &b, Location loc) {
            Value casted =
                castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
            scf::ValueVector viewAndIndices{casted};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getTransferRank(), zero);
            b.create<scf::YieldOp>(loc, viewAndIndices);
          })
      ->getResults();
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %3 = subview %alloc [...][...][...]
///      %4 = subview %view [0, 0][...][...]
///      linalg.copy(%3, %4)
///    }
/// ```
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRRewriter rewriter(b);
    std::pair<Value, Value> copyArgs = createSubViewIntersection(
        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
        alloc);
    b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %2 = load %alloc : memref<vector<...>>
///      vector.transfer_write %2, %view[...] : memref<A...>, vector<...>
///    }
/// ```
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRMapping mapping;
    Value load = b.create<memref::LoadOp>(
        loc,
        b.create<vector::TypeCastOp>(
            loc, MemRefType::get({}, xferOp.getVector().getType()), alloc),
        ValueRange());
    mapping.map(xferOp.getVector(), load);
    b.clone(*xferOp.getOperation(), mapping);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
static Operation *getAutomaticAllocationScope(Operation *op) {
  // Find the closest surrounding allocation scope that is not a known looping
  // construct (an alloca placed in a loop is not necessarily deallocated
  // until the loop terminates).
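  // For example, for an op nested as `func.func -> scf.for -> op`, the walk
  // below steps past the scf.for and returns the enclosing func.func, which
  // carries the AutomaticAllocationScope trait.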
  Operation *scope = nullptr;
  for (Operation *parent = op->getParentOp(); parent != nullptr;
       parent = parent->getParentOp()) {
    if (parent->hasTrait<OpTrait::AutomaticAllocationScope>())
      scope = parent;
    if (!isa<scf::ForOp, affine::AffineForOp>(parent))
      break;
  }
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
///
/// For vector.transfer_read:
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate for the fact that the original vector.transfer indexing may
/// be arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fast path, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slow path, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a top of the function alloca'ed buffer of one vector.
///
/// For vector.transfer_write:
/// There are 2 conditional blocks. First a block to decide which memref and
/// indices to use for an unmasked, in-bounds write. Then a conditional block to
/// further copy a partial buffer into the final result in the slow path case.
///
/// Example (a 2-D vector.transfer_write):
/// ```
///    vector.transfer_write %arg, %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ...
///                                                                    true]}
///    scf.if (%notInBounds) {
///      // slow path: not in-bounds vector.transfer or linalg.copy.
///    }
/// ```
/// where `alloc` is a top of the function alloca'ed buffer of one vector.
///
/// Preconditions:
///  1. `xferOp.getPermutationMap()` must be a minor identity map
///  2. the rank of the `xferOp.getSource()` and the rank of the
///     `xferOp.getVector()` must be equal. This will be relaxed in the future
///     but requires rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });
    return success();
  }

  // Assert preconditions. Additionally, keep the variables in an inner scope
  // to ensure they aren't used in the wrong scopes further down.
  {
    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
           "Expected splitFullAndPartialTransferPrecondition to hold");

    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());

    if (!(xferReadOp || xferWriteOp))
      return failure();
    if (xferWriteOp && xferWriteOp.getMask())
      return failure();
    if (xferReadOp && xferReadOp.getMask())
      return failure();
  }

  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(xferOp);
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Top of the function `alloc` for transient storage.
  Value alloc;
  {
    RewriterBase::InsertionGuard guard(b);
    Operation *scope = getAutomaticAllocationScope(xferOp);
    assert(scope->getNumRegions() == 1 &&
           "AutomaticAllocationScope with >1 regions");
    b.setInsertionPointToStart(&scope->getRegion(0).front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
    alloc = b.create<memref::AllocaOp>(scope->getLoc(),
                                       MemRefType::get(shape, elementType),
                                       ValueRange{}, b.getI64IntegerAttr(32));
  }

  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(cast<MemRefType>(xferOp.getShapedType()),
                                  cast<MemRefType>(alloc.getType()));
  if (!compatibleMemRefType)
    return failure();

  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Set the existing read op to in-bounds; it always reads from a full
    // buffer.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));

    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });

    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());

  // Decide which location to write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in-bounds write to either the output or the extra allocated buffer.
  // The operation is cloned to prevent deleting information needed for the
  // later IR creation.
  IRMapping mapping;
  mapping.map(xferWriteOp.getSource(), memrefAndIndices.front());
  mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp, mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // Create a potential copy from the allocated buffer to the final output in
  // the slow path case.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  b.eraseOp(xferOp);

  return success();
}

namespace {
/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
/// may take an extra filter to perform selection at a finer granularity.
struct VectorTransferFullPartialRewriter : public RewritePattern {
  using FilterConstraintType =
      std::function<LogicalResult(VectorTransferOpInterface op)>;

  explicit VectorTransferFullPartialRewriter(
      MLIRContext *context,
      VectorTransformsOptions options = VectorTransformsOptions(),
      FilterConstraintType filter =
          [](VectorTransferOpInterface op) { return success(); },
      PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
        options(options), filter(std::move(filter)) {}

  /// Performs the rewrite.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  VectorTransformsOptions options;
  FilterConstraintType filter;
};

} // namespace

LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
  return splitFullAndPartialTransfer(rewriter, xferOp, options);
}

void mlir::vector::populateVectorTransferFullPartialPatterns(
    RewritePatternSet &patterns, const VectorTransformsOptions &options) {
  patterns.add<VectorTransferFullPartialRewriter>(patterns.getContext(),
                                                  options);
}
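
// Typical driver usage (a sketch, not part of this file; it assumes a
// `func::FuncOp funcOp` and the greedy pattern driver declared in
// "mlir/Transforms/GreedyPatternRewriteDriver.h"):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   vector::VectorTransformsOptions options;
//   options.vectorTransferSplit = vector::VectorTransferSplit::VectorTransfer;
//   vector::populateVectorTransferFullPartialPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));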