//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent patterns to rewrite a vector.transfer
// op into a fully in-bounds part and a partial part.
//
//===----------------------------------------------------------------------===//

#include <optional>
#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-transfer-split"

using namespace mlir;
using namespace mlir::vector;

/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds.
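///
/// As an illustration (a sketch, not verbatim output): for a hypothetical
/// `vector.transfer_read %A[%i, %j]` of a `vector<4x8xf32>` from a
/// `memref<?x?xf32>` with no dimension marked in-bounds, the condition built
/// here resembles:
/// ```
///    %d0 = memref.dim %A, %c0 : memref<?x?xf32>
///    %e0 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i)
///    %ok0 = arith.cmpi sle, %e0, %d0 : index
///    %d1 = memref.dim %A, %c1 : memref<?x?xf32>
///    %e1 = affine.apply affine_map<(d0) -> (d0 + 8)>(%j)
///    %ok1 = arith.cmpi sle, %e1, %d1 : index
///    %inBounds = arith.andi %ok0, %ok1 : i1
/// ```
/// Dimensions whose bound is statically known to hold are folded away and do
/// not contribute to the conjunction.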
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.getPermutationMap().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Zip over the resulting vector shape and memref indices.
    // If the dimension is known to be in-bounds, it does not participate in
    // the construction of `inBoundsCond`.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Fold or create the check that `index + vector_size` <= `memref_size`.
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    OpFoldResult sum = affine::makeComposedFoldedAffineApply(
        b, loc, b.getAffineDimExpr(0) + b.getAffineConstantExpr(vectorSize),
        {xferOp.getIndices()[indicesIdx]});
    OpFoldResult dimSz =
        memref::getMixedSize(b, loc, xferOp.getBase(), indicesIdx);
    auto maybeCstSum = getConstantIntValue(sum);
    auto maybeCstDimSz = getConstantIntValue(dimSz);
    if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
      return;
    Value cond =
        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle,
                                getValueOrCreateConstantIndexOp(b, loc, sum),
                                getValueOrCreateConstantIndexOp(b, loc, dimSz));
    // Conjunction over all dims for which we are in-bounds.
    if (inBoundsCond)
      inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fast path, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slow path, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer of one vector, alloca'ed at the top of the
/// function.
///
/// Preconditions:
///  1. `xferOp.getPermutationMap()` must be a minor identity map
///  2. the rank of the `xferOp.getBase()` and the rank of the
///     `xferOp.getVector()` must be equal. This will be relaxed in the future
///     but requires rank-reducing subviews.
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // TODO: support 0-d corner case.
  if (xferOp.getTransferRank() == 0)
    return failure();

  // TODO: expand support to these 2 cases.
  if (!xferOp.getPermutationMap().isMinorIdentity())
    return failure();
  // Must have some out-of-bounds dimension to be a candidate for splitting.
  if (!xferOp.hasOutOfBoundsDim())
    return failure();
  // Don't split transfer operations directly under IfOp, this avoids applying
  // the pattern recursively.
  // TODO: improve the filtering condition to make it more applicable.
  if (isa<scf::IfOp>(xferOp->getParentOp()))
    return failure();
  return success();
}

/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. If the MemRefTypes don't have the same rank or are not strided,
/// return null; otherwise:
///   1. if `aT` and `bT` are cast-compatible, return `aT`.
///   2. else return a new MemRefType obtained by iterating over the shape and
///      strides and:
///        a. keeping the ones that are static and equal across `aT` and `bT`.
///        b. using a dynamic shape and/or stride for the dimensions that don't
///           agree.
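///
/// For instance (an illustrative sketch, not taken from a test): given
/// `memref<4x8xf32, strided<[8, 1]>>` and `memref<4x16xf32, strided<[16, 1]>>`,
/// which are not cast-compatible, the shared size of dimension 0 and the unit
/// stride of dimension 1 stay static while the rest becomes dynamic, yielding
/// `memref<4x?xf32, strided<[?, 1]>>`.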
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
      failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();

  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
    resStrides[idx] =
        (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
  }
  resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
  return MemRefType::get(
      resShape, aT.getElementType(),
      StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
}

/// Casts the given memref to a compatible memref type. If the source memref
/// has a different address space than the target type, a
/// `memref.memory_space_cast` is first inserted, followed by a `memref.cast`.
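///
/// For example (a sketch): casting a value of type `memref<4x8xf32, 3>` to
/// `memref<?x?xf32>` first emits a `memref.memory_space_cast` to
/// `memref<4x8xf32>` and then a `memref.cast` to `memref<?x?xf32>`; when only
/// the address space differs, the trailing `memref.cast` is skipped.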
static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
                                        MemRefType compatibleMemRefType) {
  MemRefType sourceType = cast<MemRefType>(memref.getType());
  Value res = memref;
  if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
    sourceType = MemRefType::get(
        sourceType.getShape(), sourceType.getElementType(),
        sourceType.getLayout(), compatibleMemRefType.getMemorySpace());
    res = b.create<memref::MemorySpaceCastOp>(memref.getLoc(), sourceType, res);
  }
  if (sourceType == compatibleMemRefType)
    return res;
  return b.create<memref::CastOp>(memref.getLoc(), compatibleMemRefType, res);
}

/// Operates under a scoped context to build the intersection between the
/// view `xferOp.getBase()` @ `xferOp.getIndices()` and the view `alloc`.
// TODO: view intersection/union/differences should be a proper std op.
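///
/// As an illustrative sketch (names are hypothetical, not from a test): for a
/// 2-D transfer_read of `vector<4x8xf32>` at `[%i, %j]` from
/// `%A : memref<?x?xf32>` with `alloc : memref<4x8xf32>`, the two subviews
/// resemble:
/// ```
///    %d0 = memref.dim %A, %c0 : memref<?x?xf32>
///    %sz0 = affine.min affine_map<(d0, d1, d2) -> (d0 - d1, d2)>(%d0, %i, %c4)
///    %d1 = memref.dim %A, %c1 : memref<?x?xf32>
///    %sz1 = affine.min affine_map<(d0, d1, d2) -> (d0 - d1, d2)>(%d1, %j, %c8)
///    %src  = memref.subview %A[%i, %j] [%sz0, %sz1] [1, 1]
///    %dest = memref.subview %alloc[0, 0] [%sz0, %sz1] [1, 1]
/// ```
/// (`%c4`/`%c8` being the folded sizes of `alloc`). For a transfer_write, the
/// `alloc` and `%A` operands of the two subviews are exchanged.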
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == cast<MemRefType>(alloc.getType()).getRank() &&
         "Expected memref rank to match the alloc rank");
  ValueRange leadingIndices =
      xferOp.getIndices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult, 4> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef =
        b.create<memref::DimOp>(xferOp.getLoc(), xferOp.getBase(), indicesIdx);
    Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
    Value index = xferOp.getIndices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}}, b.getContext());
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin = b.create<affine::AffineMinOp>(
        loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
      xferOp.getIndices(), [](Value idx) -> OpFoldResult { return idx; }));
  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
  auto copySrc = b.create<memref::SubViewOp>(
      loc, isaWrite ? alloc : xferOp.getBase(), srcIndices, sizes, strides);
  auto copyDest = b.create<memref::SubViewOp>(
      loc, isaWrite ? xferOp.getBase() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = linalg.fill(%pad, %alloc)
///      %3 = subview %view [...][...][...]
///      %4 = subview %alloc [0, 0] [...] [...]
///      linalg.copy(%3, %4)
///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %5, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getBase();
  return b.create<scf::IfOp>(
      loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        llvm::append_range(viewAndIndices, xferOp.getIndices());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        b.create<linalg::FillOp>(loc, ValueRange{xferOp.getPadding()},
                                 ValueRange{alloc});
        // Take partial subview of memref which guarantees no dimension
        // overflows.
        IRRewriter rewriter(b);
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
            alloc);
        b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
        Value casted =
            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      store %2, %3[] : memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  scf::IfOp fullPartialIfOp;
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getBase();
  return b.create<scf::IfOp>(
      loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        llvm::append_range(viewAndIndices, xferOp.getIndices());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        Operation *newXfer = b.clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).getVector();
        b.create<memref::StoreOp>(
            loc, vector,
            b.create<vector::TypeCastOp>(
                loc, MemRefType::get({}, vector.getType()), alloc));

        Value casted =
            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getBase();
  return b
      .create<scf::IfOp>(
          loc, inBoundsCond,
          [&](OpBuilder &b, Location loc) {
            Value res =
                castToCompatibleMemRefType(b, memref, compatibleMemRefType);
            scf::ValueVector viewAndIndices{res};
            llvm::append_range(viewAndIndices, xferOp.getIndices());
            b.create<scf::YieldOp>(loc, viewAndIndices);
          },
          [&](OpBuilder &b, Location loc) {
            Value casted =
                castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
            scf::ValueVector viewAndIndices{casted};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getTransferRank(), zero);
            b.create<scf::YieldOp>(loc, viewAndIndices);
          })
      ->getResults();
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of a single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %3 = subview %alloc [...][...][...]
///      %4 = subview %view [0, 0][...][...]
///      linalg.copy(%3, %4)
///    }
/// ```
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRRewriter rewriter(b);
    std::pair<Value, Value> copyArgs = createSubViewIntersection(
        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
        alloc);
    b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of a single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %2 = load %alloc : memref<vector<...>>
///      vector.transfer_write %2, %view[...] : memref<A...>, vector<...>
///    }
/// ```
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRMapping mapping;
    Value load = b.create<memref::LoadOp>(
        loc,
        b.create<vector::TypeCastOp>(
            loc, MemRefType::get({}, xferOp.getVector().getType()), alloc),
        ValueRange());
    mapping.map(xferOp.getVector(), load);
    b.clone(*xferOp.getOperation(), mapping);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
static Operation *getAutomaticAllocationScope(Operation *op) {
  // Find the closest surrounding allocation scope that is not a known looping
  // construct (putting alloca's in loops doesn't always lower to deallocation
  // until the end of the loop).
  Operation *scope = nullptr;
  for (Operation *parent = op->getParentOp(); parent != nullptr;
       parent = parent->getParentOp()) {
    if (parent->hasTrait<OpTrait::AutomaticAllocationScope>())
      scope = parent;
    if (!isa<scf::ForOp, affine::AffineForOp>(parent))
      break;
  }
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
///
/// For vector.transfer_read:
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer of one vector, alloca'ed at the top of the
/// function.
///
/// For vector.transfer_write:
/// There are 2 conditional blocks. First a block to decide which memref and
/// indices to use for an unmasked, in-bounds write. Then a conditional block to
/// further copy a partial buffer into the final result in the slow path case.
///
/// Example (a 2-D vector.transfer_write):
/// ```
///    vector.transfer_write %arg, %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_write %arg, %1#0[%1#1, %1#2]
///        {in_bounds = [true ... true]}
///    scf.if (%notInBounds) {
///      // slowpath: not in-bounds vector.transfer or linalg.copy.
///    }
/// ```
/// where `alloc` is a buffer of one vector, alloca'ed at the top of the
/// function.
///
/// Preconditions:
///  1. `xferOp.getPermutationMap()` must be a minor identity map
///  2. the rank of the `xferOp.getBase()` and the rank of the
///     `xferOp.getVector()` must be equal. This will be relaxed in the future
///     but requires rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });
    return success();
  }

  // Assert preconditions. Additionally, keep the variables in an inner scope to
  // ensure they aren't used in the wrong scopes further down.
  {
    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
           "Expected splitFullAndPartialTransferPrecondition to hold");

    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());

    if (!(xferReadOp || xferWriteOp))
      return failure();
    if (xferWriteOp && xferWriteOp.getMask())
      return failure();
    if (xferReadOp && xferReadOp.getMask())
      return failure();
  }

  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(xferOp);
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Top of the function `alloc` for transient storage.
  Value alloc;
  {
    RewriterBase::InsertionGuard guard(b);
    Operation *scope = getAutomaticAllocationScope(xferOp);
    assert(scope->getNumRegions() == 1 &&
           "AutomaticAllocationScope with >1 regions");
    b.setInsertionPointToStart(&scope->getRegion(0).front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
    alloc = b.create<memref::AllocaOp>(scope->getLoc(),
                                       MemRefType::get(shape, elementType),
                                       ValueRange{}, b.getI64IntegerAttr(32));
  }

  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(cast<MemRefType>(xferOp.getShapedType()),
                                  cast<MemRefType>(alloc.getType()));
  if (!compatibleMemRefType)
    return failure();

  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Set existing read op to in-bounds, it always reads from a full buffer.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));

    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });

    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());

  // Decide which location to write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in-bounds write to either the output or the extra allocated buffer.
  // The operation is cloned to prevent deleting information needed for the
  // later IR creation.
  IRMapping mapping;
  mapping.map(xferWriteOp.getBase(), memrefAndIndices.front());
  mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp, mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // Create a potential copy from the allocated buffer to the final output in
  // the slow path case.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  b.eraseOp(xferOp);

  return success();
}

namespace {
/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
/// may take an extra filter to perform selection at a finer granularity.
struct VectorTransferFullPartialRewriter : public RewritePattern {
  using FilterConstraintType =
      std::function<LogicalResult(VectorTransferOpInterface op)>;

  explicit VectorTransferFullPartialRewriter(
      MLIRContext *context,
      VectorTransformsOptions options = VectorTransformsOptions(),
      FilterConstraintType filter =
          [](VectorTransferOpInterface op) { return success(); },
      PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
        filter(std::move(filter)) {}

  /// Performs the rewrite.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  VectorTransformsOptions options;
  FilterConstraintType filter;
};

} // namespace

LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
  return splitFullAndPartialTransfer(rewriter, xferOp, options);
}

void mlir::vector::populateVectorTransferFullPartialPatterns(
    RewritePatternSet &patterns, const VectorTransformsOptions &options) {
  patterns.add<VectorTransferFullPartialRewriter>(patterns.getContext(),
                                                  options);
}
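
// Usage sketch (not part of this file; assumes a standard pass driving the
// greedy rewriter from "mlir/Transforms/GreedyPatternRewriteDriver.h", with
// `MyPass` being a hypothetical pass name):
//
//   void MyPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     vector::VectorTransformsOptions options;
//     // Pick VectorTransfer (or LinalgCopy / ForceInBounds) as the strategy.
//     options.vectorTransferSplit = vector::VectorTransferSplit::VectorTransfer;
//     vector::populateVectorTransferFullPartialPatterns(patterns, options);
//     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//       signalPassFailure();
//   }
//
// For finer-grained selection, construct VectorTransferFullPartialRewriter
// directly with a custom FilterConstraintType callback instead of using the
// populate helper.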