//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;

/// Return the ancestor op in the region or nullptr if the region is not
/// an ancestor of the op.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  return op;
}

namespace {

class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      rewriter.eraseOp(op);
    opToErase.clear();
  }

private:
  RewriterBase &rewriter;
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

} // namespace

/// Return true if there is a path from start operation to dest operation,
/// otherwise return false. The operations have to be in the same region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  return start->getBlock()->isReachable(dest->getBlock());
}

/// For a transfer_write to fully overwrite another transfer_write, it must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they all post-dominate the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates, as
/// this will be the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we find such an overwriting transfer_write, we know that the original
/// transfer_write is dead if all reads that can be reached from the potentially
/// dead transfer_write are dominated by the overwriting transfer_write.
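///
/// Illustrative example (a sketch, not taken from a test): assuming no other
/// user of %m in between, the first write below is dead because the second one
/// overwrites the exact same elements before any read can observe them:
///   vector.transfer_write %v0, %m[%c0, %c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<4x4xf32>
///   vector.transfer_write %v1, %m[%c0, %c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<4x4xf32>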
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getBase()));
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed skip.
    if (!processed.insert(user).second)
      continue;
    if (isa<ViewLikeOpInterface>(user)) {
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user))
      continue;
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check candidates that can overwrite the store.
      if (memref::isSameViewOrTrivialAlias(
              cast<MemrefValue>(nextWrite.getBase()),
              cast<MemrefValue>(write.getBase())) &&
          checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Don't need to consider disjoint accesses.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    // TODO: if the access and write have the same ancestor we could recurse in
    // the region to know if the access is reachable with more precision.
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}

/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
///    transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the others
/// since they all dominate the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates, as this will be the
/// last transfer_write executed before the transfer_read.
/// If we find such a candidate, we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are post-dominated
/// by the transfer_write.
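///
/// Illustrative example (a sketch): assuming no aliasing write in between, the
/// read below is forwarded from the dominating write:
///   vector.transfer_write %v, %m[%c0, %c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true]}
///       : memref<4x4xf32>, vector<4xf32>
/// After forwarding, all uses of %r are replaced by %v and the transfer_read
/// becomes dead.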
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getBase()));
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed skip.
    if (!processed.insert(user).second)
      continue;
    if (isa<ViewLikeOpInterface>(user)) {
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint we can ignore
      // the write.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
      if (memref::isSameViewOrTrivialAlias(
              cast<MemrefValue>(read.getBase()),
              cast<MemrefValue>(write.getBase())) &&
          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}

/// Converts OpFoldResults to int64_t shape without unit dims.
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<int64_t> reducedShape;
  for (const auto size : mixedSizes) {
    if (llvm::dyn_cast_if_present<Value>(size)) {
      reducedShape.push_back(ShapedType::kDynamic);
      continue;
    }

    auto value = cast<IntegerAttr>(cast<Attribute>(size)).getValue();
    if (value == 1)
      continue;
    reducedShape.push_back(value.getSExtValue());
  }
  return reducedShape;
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes,
                               ArrayRef<OpFoldResult> strides) {
  auto targetShape = getReducedShape(sizes);
  MemRefType rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return rankReducedType.canonicalizeStridedLayout();
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or simply returns the input if it already has no unit dims.
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = cast<MemRefType>(input.getType());
  SmallVector<OpFoldResult> offsets(inputType.getRank(),
                                    rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
  SmallVector<OpFoldResult> strides(inputType.getRank(),
                                    rewriter.getIndexAttr(1));
  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);

  if (resultType.canonicalizeStridedLayout() ==
      inputType.canonicalizeStridedLayout())
    return input;
  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
                                            sizes, strides);
}

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Trims non-scalable unit dimensions from `oldType` and returns the resulting
/// type.
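/// For example, vector<1x[4]x1xf32> is trimmed to vector<[4]xf32>, while a
/// scalable unit dim such as the leading dim of vector<[1]x4xf32> is kept.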
static VectorType trimNonScalableUnitDims(VectorType oldType) {
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
      continue;
    newShape.push_back(dimSize);
    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}

// Rewrites vector.create_mask 'op' to drop non-scalable unit dimensions.
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
                                  vector::CreateMaskOp op) {
  auto type = op.getType();
  VectorType reducedType = trimNonScalableUnitDims(type);
  if (reducedType.getRank() == type.getRank())
    return failure();

  SmallVector<Value> reducedOperands;
  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
           type.getShape(), type.getScalableDims(), op.getOperands())) {
    if (dim == 1 && !dimIsScalable) {
      // If the mask for the unit dim is not a constant of 1, do nothing.
      auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
      if (!constant || (constant.value() != 1))
        return failure();
      continue;
    }
    reducedOperands.push_back(operand);
  }
  return rewriter
      .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
      .getResult();
}

namespace {

/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims. The vector shapes are
/// also reduced accordingly.
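///
/// Illustrative rewrite (a sketch; exact layout attributes are omitted):
///   %v = vector.transfer_read %src[%c0, %c0, %c0], %pad
///       {in_bounds = [true, true, true]}
///       : memref<1x8x1xf32>, vector<1x8x1xf32>
/// becomes, roughly:
///   %sv = memref.subview %src[0, 0, 0] [1, 8, 1] [1, 1, 1]
///       : memref<1x8x1xf32> to memref<8xf32>
///   %r = vector.transfer_read %sv[%c0], %pad {in_bounds = [true]}
///       : memref<8xf32>, vector<8xf32>
///   %v = vector.shape_cast %r : vector<8xf32> to vector<1x8x1xf32>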
class TransferReadDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
    // out.
    if (reducedRank == 0 && maskingOp)
      return failure();
    // Check if the reduced vector shape matches the reduced source shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
                            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }

    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
        transferReadOp.getPadding(), maskOp,
        rewriter.getBoolArrayAttr(inBounds));

    if (maskingOp) {
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newTransferReadOp = mlir::vector::maskOperation(
          rewriter, newTransferReadOp, shapeCastMask);
    }

    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp->getResults()[0]);

    return shapeCast;
  }
};

/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
/// vector shapes are also reduced accordingly.
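///
/// Illustrative rewrite (a sketch; exact layout attributes are omitted):
///   vector.transfer_write %v, %dst[%c0, %c0, %c0]
///       {in_bounds = [true, true, true]}
///       : vector<1x8x1xf32>, memref<1x8x1xf32>
/// becomes a vector.shape_cast of %v to vector<8xf32> followed by a
/// transfer_write into a rank-reduced memref.subview of type memref<8xf32>.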
class TransferWriteDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor type.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the destination shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
    // out.
    if (reducedRank == 0 && maskingOp)
      return failure();
    // Check if the reduced vector shape matches the reduced destination shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferWriteOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferWriteOp,
            "unsupported mask op, only 'vector.create_mask' is "
            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
        loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
        maskOp, rewriter.getBoolArrayAttr(inBounds));

    if (maskingOp) {
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newXferWrite =
          mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
    }

    if (transferWriteOp.hasPureTensorSemantics())
      return newXferWrite->getResults()[0];

    // With Memref semantics, there's no return value. Use empty value to signal
    // success.
    return Value();
  }
};

} // namespace

/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
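/// For example, with firstDimToCollapse = 1, memref<2x3x4xf32> is collapsed
/// with reassociation [[0], [1, 2]] into memref<2x12xf32>.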
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = cast<ShapedType>(input.getType());
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}

/// Returns the new indices that collapse the inner dimensions starting from
/// the `firstDimToCollapse` dimension.
static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
                                              Location loc,
                                              ArrayRef<int64_t> shape,
                                              ValueRange indices,
                                              int64_t firstDimToCollapse) {
  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));

  // If all the collapsed indices are zero then no extra logic is needed.
  // Otherwise, a new offset/index has to be computed.
  SmallVector<Value> indicesAfterCollapsing(
      indices.begin(), indices.begin() + firstDimToCollapse);
  SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
                                       indices.end());
  if (llvm::all_of(indicesToCollapse, isZeroInteger)) {
    indicesAfterCollapsing.push_back(indicesToCollapse[0]);
    return indicesAfterCollapsing;
  }

  // Compute the remaining trailing index/offset required for reading from
  // the collapsed memref:
  //
  //    offset = 0
  //    for (i = firstDimToCollapse; i < outputRank; ++i)
  //      offset = offset * sourceType.getDimSize(i) + transferReadOp.indices[i]
  //
  // For this example:
  //   %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) :
  //      memref<1x43x2xi32>, vector<1x2xi32>
  // which would be collapsed to:
  //   %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) :
  //      memref<1x86xi32>, vector<2xi32>
  // one would get the following offset:
  //    %offset = %arg0 * 2
  OpFoldResult collapsedOffset =
      rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();

  auto collapsedStrides = computeSuffixProduct(
      ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));

  // Compute the collapsed offset.
  auto &&[collapsedExpr, collapsedVals] =
      computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
  collapsedOffset = affine::makeComposedFoldedAffineApply(
      rewriter, loc, collapsedExpr, collapsedVals);

  if (auto value = dyn_cast<Value>(collapsedOffset)) {
    indicesAfterCollapsing.push_back(value);
  } else {
    indicesAfterCollapsing.push_back(rewriter.create<arith::ConstantIndexOp>(
        loc, *getConstantIntValue(collapsedOffset)));
  }

  return indicesAfterCollapsing;
}

namespace {

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
///
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the bitwidth of the trailing dimension of the vector read is smaller than
/// the provided bitwidth.
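///
/// Illustrative rewrite (a sketch; assumes the source is contiguous and the
/// bitwidth precondition holds):
///   %v = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<4x8xf32>, vector<4x8xf32>
/// becomes, roughly:
///   %c = memref.collapse_shape %src [[0, 1]]
///       : memref<4x8xf32> into memref<32xf32>
///   %f = vector.transfer_read %c[%c0], %pad {in_bounds = [true]}
///       : memref<32xf32>, vector<32xf32>
///   %v = vector.shape_cast %f : vector<32xf32> to vector<4x8xf32>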
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    auto source = transferReadOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions.
    // The contiguity check below is only valid for memrefs; bail out on
    // tensor sources.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();

    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // 1. Collapse the source memref
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_read that will read
    // from the collapsed memref.
    // 2.1. New dim exprs + affine map
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2 New indices
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferReadOp.getIndices(), firstDimToCollapse);

    // 3. Create new vector.transfer_read that reads from the collapsed memref
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Replace the old transfer_read with the new one reading from the
    // collapsed shape
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
///
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the bitwidth of the trailing dimension of the vector to be written is
/// smaller than the provided bitwidth.
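///
/// Illustrative rewrite (a sketch): a write of vector<4x8xf32> into a
/// contiguous memref<4x8xf32> becomes a vector.shape_cast to vector<32xf32>
/// followed by a 1D transfer_write into the collapsed memref<32xf32>,
/// analogous to the read pattern above.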
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
public:
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
                                                PatternBenefit benefit)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions.
    // The contiguity check below is only valid for memrefs; bail out on
    // tensor destinations.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();

    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // 1. Collapse the source memref
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_write that will write
    // to the collapsed memref.
    // 2.1. New dim exprs + affine map
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2 New indices
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferWriteOp.getIndices(), firstDimToCollapse);

    // 3. Create new vector.transfer_write that writes to the collapsed memref
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Replace the old transfer_write with the new one writing the
    // collapsed shape
    rewriter.eraseOp(transferWriteOp);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

/// Base class for `vector.extract/vector.extract_element(vector.transfer_read)`
/// to `memref.load` patterns. The `match` method is shared for both
/// `vector.extract` and `vector.extract_element`.
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
    : public OpRewritePattern<VectorExtractOp> {
  using Base = OpRewritePattern<VectorExtractOp>;

public:
  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
                                         PatternBenefit benefit,
                                         bool allowMultipleUses)
      : Base(context, benefit), allowMultipleUses(allowMultipleUses) {}

  LogicalResult match(VectorExtractOp extractOp) const {
    auto xferOp =
        extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Check that we are extracting a scalar and not a sub-vector.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    // If multiple uses are not allowed, check if xfer has a single use.
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // If multiple uses are allowed, check if all the xfer uses are extract ops.
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp, vector::ExtractElementOp>(
              use.getOwner());
        }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds.
    if (xferOp.hasOutOfBoundsDim())
      return failure();
    return success();
  }

private:
  bool allowMultipleUses;
};

/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
class RewriteScalarExtractElementOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp,
                                PatternRewriter &rewriter) const override {
    if (failed(match(extractOp)))
      return failure();

    // Construct scalar load.
    auto loc = extractOp.getLoc();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    if (extractOp.getPosition()) {
      AffineExpr sym0, sym1;
      bindSymbols(extractOp.getContext(), sym0, sym1);
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, sym0 + sym1,
          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
      if (auto value = dyn_cast<Value>(ofr)) {
        newIndices[newIndices.size() - 1] = value;
      } else {
        newIndices[newIndices.size() - 1] =
            rewriter.create<arith::ConstantIndexOp>(loc,
                                                    *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getBase().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getBase(), newIndices);
    }

    return success();
  }
};

/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
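///
/// Illustrative rewrite (a sketch; IR syntax approximate), assuming the
/// extract is the only user of the read:
///   %v = vector.transfer_read %m[%i], %pad {in_bounds = [true]}
///       : memref<?xf32>, vector<4xf32>
///   %e = vector.extract %v[1] : f32 from vector<4xf32>
/// becomes a memref.load from %m at index %i + 1, with the index built via a
/// folded affine.apply.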
class RewriteScalarExtractOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    if (failed(match(extractOp)))
      return failure();

    // Construct scalar load.
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      assert(isa<Attribute>(pos) && "Unexpected non-constant index");
      int64_t offset = cast<IntegerAttr>(cast<Attribute>(pos)).getInt();
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(),
          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      if (auto value = dyn_cast<Value>(ofr)) {
        newIndices[idx] = value;
      } else {
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            extractOp.getLoc(), *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getBase().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getBase(), newIndices);
    }

    return success();
  }
};

/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
/// to memref.store.
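///
/// Illustrative rewrite (a sketch):
///   vector.transfer_write %v, %m[%i, %j] {in_bounds = [true, true]}
///       : vector<1x1xf32>, memref<8x8xf32>
/// becomes a vector.extract of the single scalar element of %v followed by a
/// memref.store of that scalar at [%i, %j] (or a tensor.insert for tensor
/// destinations).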
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Only float and integer element types are supported.
    Value scalar =
        rewriter.create<vector::ExtractOp>(xferOp.getLoc(), xferOp.getVector());
    // Construct a scalar store.
    if (isa<MemRefType>(xferOp.getBase().getType())) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
    }
    return success();
  }
};

} // namespace

void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  TransferOptimization opt(rewriter, rootOp);
  // Run store to load forwarding first since it can expose more dead store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType()))
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType()))
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}

964 | |
965 | void mlir::vector::populateScalarVectorTransferLoweringPatterns( |
966 | RewritePatternSet &patterns, PatternBenefit benefit, |
967 | bool allowMultipleUses) { |
968 | patterns.add<RewriteScalarExtractElementOfTransferRead, |
969 | RewriteScalarExtractOfTransferRead>(arg: patterns.getContext(), |
970 | args&: benefit, args&: allowMultipleUses); |
971 | patterns.add<RewriteScalarWrite>(arg: patterns.getContext(), args&: benefit); |
972 | } |
973 | |
974 | void mlir::vector::populateVectorTransferDropUnitDimsPatterns( |
975 | RewritePatternSet &patterns, PatternBenefit benefit) { |
976 | patterns |
977 | .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>( |
978 | arg: patterns.getContext(), args&: benefit); |
979 | } |
980 | |
981 | void mlir::vector::populateFlattenVectorTransferPatterns( |
982 | RewritePatternSet &patterns, unsigned targetVectorBitwidth, |
983 | PatternBenefit benefit) { |
984 | patterns.add<FlattenContiguousRowMajorTransferReadPattern, |
985 | FlattenContiguousRowMajorTransferWritePattern>( |
986 | arg: patterns.getContext(), args&: targetVectorBitwidth, args&: benefit); |
987 | populateDropUnitDimWithShapeCastPatterns(patterns, benefit); |
988 | } |
989 |