//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;

/// Return the ancestor op in the region or nullptr if the region is not
/// an ancestor of the op.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  return op;
}

namespace {

class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      rewriter.eraseOp(op);
    opToErase.clear();
  }

private:
  RewriterBase &rewriter;
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

} // namespace
/// Return true if there is a path from the start operation to the dest
/// operation, otherwise return false. The operations have to be in the same
/// region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  Block *startBlock = start->getBlock();
  Block *destBlock = dest->getBlock();
  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
                                    startBlock->succ_end());
  SmallPtrSet<Block *, 32> visited;
  while (!worklist.empty()) {
    Block *bb = worklist.pop_back_val();
    if (!visited.insert(bb).second)
      continue;
    if (dominators.dominates(bb, destBlock))
      return true;
    worklist.append(bb->succ_begin(), bb->succ_end());
  }
  return false;
}

/// For a transfer_write to fully overwrite another transfer_write, it must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they are all post-dominating the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates, as
/// this will be the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we find such an overwriting transfer_write, we know that the original
/// transfer_write is dead if all reads that can be reached from the potentially
/// dead transfer_write are dominated by the overwriting transfer_write.
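///
/// As an illustrative sketch (hypothetical %m, %v0 and %v1), the first write
/// below is dead because the second write overwrites the same elements before
/// any read can observe them:
///   vector.transfer_write %v0, %m[%c0, %c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<4x4xf32>
///   vector.transfer_write %v1, %m[%c0, %c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<4x4xf32>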
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = write.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user))
      continue;
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check if this candidate can overwrite the store.
      if (write.getSource() == nextWrite.getSource() &&
          checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Don't need to consider disjoint accesses.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    // TODO: if the access and write have the same ancestor we could recurse in
    // the region to know if the access is reachable with more precision.
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}

/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
///    transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the others
/// since they are all dominating the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates, as this will be the
/// last transfer_write executed before the transfer_read.
/// If we find such a candidate, we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are post-dominated
/// by the transfer_write.
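///
/// As an illustrative sketch (hypothetical %m, %v and %pad), the read below
/// can be replaced by %v directly, since the preceding write stores to the
/// same indices with the same vector type and no aliasing access intervenes:
///   vector.transfer_write %v, %m[%c0, %c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true]}
///       : memref<4x4xf32>, vector<4xf32>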
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = read.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (auto collapsed = dyn_cast<memref::CollapseShapeOp>(user)) {
      users.append(collapsed->getUsers().begin(), collapsed->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint, we can
      // ignore the write.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
      if (write.getSource() == read.getSource() &&
          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}

/// Converts OpFoldResults to an int64_t shape without unit dims.
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<int64_t> reducedShape;
  for (const auto size : mixedSizes) {
    if (llvm::dyn_cast_if_present<Value>(size)) {
      reducedShape.push_back(ShapedType::kDynamic);
      continue;
    }

    auto value = cast<IntegerAttr>(size.get<Attribute>()).getValue();
    if (value == 1)
      continue;
    reducedShape.push_back(value.getSExtValue());
  }
  return reducedShape;
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes,
                               ArrayRef<OpFoldResult> strides) {
  auto targetShape = getReducedShape(sizes);
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or just returns the input if it already has no unit dims.
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = cast<MemRefType>(input.getType());
  SmallVector<OpFoldResult> offsets(inputType.getRank(),
                                    rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
  SmallVector<OpFoldResult> strides(inputType.getRank(),
                                    rewriter.getIndexAttr(1));
  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);

  if (canonicalizeStridedLayout(resultType) ==
      canonicalizeStridedLayout(inputType))
    return input;
  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
                                            sizes, strides);
}

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Trims non-scalable unit dimensions from `oldType` and returns the resulting
/// type.
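/// For example, `vector<1x[4]x1xf32>` is trimmed to `vector<[4]xf32>`; a
/// scalable unit dim such as `[1]` would be kept.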
static VectorType trimNonScalableUnitDims(VectorType oldType) {
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
      continue;
    newShape.push_back(dimSize);
    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}

// Rewrites vector.create_mask 'op' to drop non-scalable unit dimensions.
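// As an illustrative sketch (hypothetical %dim), assuming the operands of the
// dropped unit dims are constant 1:
//   %m = vector.create_mask %c1, %dim, %c1 : vector<1x[4]x1xi1>
// is rewritten to:
//   %m = vector.create_mask %dim : vector<[4]xi1>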
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
                                  vector::CreateMaskOp op) {
  auto type = op.getType();
  VectorType reducedType = trimNonScalableUnitDims(type);
  if (reducedType.getRank() == type.getRank())
    return failure();

  SmallVector<Value> reducedOperands;
  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
           type.getShape(), type.getScalableDims(), op.getOperands())) {
    if (dim == 1 && !dimIsScalable) {
      // If the mask for the unit dim is not a constant of 1, do nothing.
      auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
      if (!constant || (constant.value() != 1))
        return failure();
      continue;
    }
    reducedOperands.push_back(operand);
  }
  return rewriter
      .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
      .getResult();
}

namespace {

/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims. The vector shapes are
/// also reduced accordingly.
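///
/// As an illustrative sketch (hypothetical %m and %pad), a read such as:
///   %v = vector.transfer_read %m[%c0, %c0, %c0], %pad
///       {in_bounds = [true, true, true]}
///       : memref<1x8x1xf32>, vector<1x8x1xf32>
/// becomes a rank-reducing memref.subview of type memref<8xf32>, a
/// vector.transfer_read of vector<8xf32>, and a vector.shape_cast back to
/// vector<1x8x1xf32>.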
class TransferReadDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // Check if the reduced vector shape matches the reduced source shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
                            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }

    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
        transferReadOp.getPadding(), maskOp,
        rewriter.getBoolArrayAttr(inBounds));
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp);
    rewriter.replaceOp(transferReadOp, shapeCast);

    return success();
  }
};

/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
/// vector shapes are also reduced accordingly.
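///
/// As an illustrative sketch (hypothetical %m and %v), a write such as:
///   vector.transfer_write %v, %m[%c0, %c0, %c0]
///       {in_bounds = [true, true, true]}
///       : vector<1x8x1xf32>, memref<1x8x1xf32>
/// becomes a vector.shape_cast to vector<8xf32> written through a
/// rank-reducing memref.subview of type memref<8xf32>.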
class TransferWriteDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor type.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the destination shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // Check if the reduced vector shape matches the reduced destination shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferWriteOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferWriteOp,
            "unsupported mask op, only 'vector.create_mask' is "
            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
        identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));

    return success();
  }
};

} // namespace

/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
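/// For example, with `firstDimToCollapse == 1`, `memref<2x3x4xf32>` is
/// collapsed to `memref<2x12xf32>` using the reassociation `[[0], [1, 2]]`.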
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = cast<ShapedType>(input.getType());
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}

/// Checks that the indices corresponding to dimensions starting at
/// `firstDimToCollapse` are constant 0, and writes to `outIndices`
/// the truncated indices where `firstDimToCollapse` is now the innermost dim.
/// TODO: Extract the logic that writes to outIndices so that this method
/// simply checks one pre-condition.
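/// For example, with `firstDimToCollapse == 1`, indices `(%i, %c0, %c0)` are
/// truncated to `(%i, %c0)`, while `(%i, %c1, %c0)` is rejected.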
static LogicalResult
checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                 SmallVector<Value> &outIndices) {
  int64_t rank = indices.size();
  if (firstDimToCollapse >= rank)
    return failure();
  for (int64_t i = firstDimToCollapse; i < rank; ++i) {
    std::optional<int64_t> cst = getConstantIntValue(indices[i]);
    if (!cst || cst.value() != 0)
      return failure();
  }
  outIndices = indices;
  outIndices.resize(firstDimToCollapse + 1);
  return success();
}

namespace {

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
/// If `targetVectorBitwidth` is provided, the flattening only happens if the
/// bitwidth of the trailing vector dimension is smaller than the provided
/// bitwidth.
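///
/// As an illustrative sketch (hypothetical %m and %pad, and assuming 256 bits
/// is below the configured target bitwidth), a read such as:
///   %v = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<4x8xf32>, vector<4x8xf32>
/// becomes a memref.collapse_shape to memref<32xf32>, a 1-D
/// vector.transfer_read of vector<32xf32>, and a vector.shape_cast back to
/// vector<4x8xf32>.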
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    auto source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions
    // The contiguity check below is only valid for memrefs; bail out on
    // tensor sources.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();

    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // 1. Collapse the source memref
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        dyn_cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_read that will read
    // from the collapsed memref.
    // 2.1. New dim exprs + affine map
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2 New indices
    // If all the collapsed indices are zero then no extra logic is needed.
    // Otherwise, a new offset/index has to be computed.
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
                                                firstDimToCollapse,
                                                collapsedIndices))) {
      // Copy all the leading indices.
      SmallVector<Value> indices = transferReadOp.getIndices();
      collapsedIndices.append(indices.begin(),
                              indices.begin() + firstDimToCollapse);

      // Compute the remaining trailing index/offset required for reading from
      // the collapsed memref:
      //
      //    offset = 0
      //    for (i = firstDimToCollapse; i < outputRank; ++i)
      //      offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
      //
      // For this example:
      //   %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
      //     memref<1x43x2xi32>, vector<1x2xi32>
      // which would be collapsed to:
      //   %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
      //     memref<1x86xi32>, vector<2xi32>
      // one would get the following offset:
      //    %offset = %arg0 * 43
      OpFoldResult collapsedOffset =
          rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();

      auto sourceShape = sourceType.getShape();
      auto collapsedStrides = computeSuffixProduct(ArrayRef<int64_t>(
          sourceShape.begin() + firstDimToCollapse, sourceShape.end()));

      // Compute the collapsed offset.
      ArrayRef<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
                                        indices.end());
      auto &&[collapsedExpr, collapsedVals] = computeLinearIndex(
          collapsedOffset, collapsedStrides, indicesToCollapse);
      collapsedOffset = affine::makeComposedFoldedAffineApply(
          rewriter, loc, collapsedExpr, collapsedVals);

      if (collapsedOffset.is<Value>()) {
        collapsedIndices.push_back(collapsedOffset.get<Value>());
      } else {
        collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
            loc, *getConstantIntValue(collapsedOffset)));
      }
    }

    // 3. Create new vector.transfer_read that reads from the collapsed memref
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Replace the old transfer_read with the new one reading from the
    // collapsed shape
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
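///
/// As an illustrative sketch (hypothetical %m and %v, and assuming the target
/// bitwidth allows flattening), a write such as:
///   vector.transfer_write %v, %m[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>
/// becomes a vector.shape_cast to vector<32xf32> written into the
/// memref.collapse_shape of the destination, i.e. memref<32xf32>.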
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
public:
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
                                                PatternBenefit benefit)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // The contiguity check below is only valid for memrefs; bail out on
    // tensor destinations.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    int64_t firstContiguousInnerDim =
        sourceType.getRank() - vectorType.getRank();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();

    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.eraseOp(transferWriteOp);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

/// Base class for `vector.extract/vector.extractelement(vector.transfer_read)`
/// to `memref.load` patterns. The `match` method is shared for both
/// `vector.extract` and `vector.extractelement`.
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
    : public OpRewritePattern<VectorExtractOp> {
  using Base = OpRewritePattern<VectorExtractOp>;

public:
  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
                                         PatternBenefit benefit,
                                         bool allowMultipleUses)
      : Base::OpRewritePattern(context, benefit),
        allowMultipleUses(allowMultipleUses) {}

  LogicalResult match(VectorExtractOp extractOp) const override {
    auto xferOp =
        extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Check that we are extracting a scalar and not a sub-vector.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    // If multiple uses are not allowed, check if xfer has a single use.
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // If multiple uses are allowed, check if all the xfer uses are extract ops.
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp, vector::ExtractElementOp>(
              use.getOwner());
        }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds.
    if (xferOp.hasOutOfBoundsDim())
      return failure();
    return success();
  }

private:
  bool allowMultipleUses;
};

/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
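///
/// As an illustrative sketch (hypothetical %m, %pad, %i and %j):
///   %v = vector.transfer_read %m[%i], %pad {in_bounds = [true]}
///       : memref<?xf32>, vector<8xf32>
///   %e = vector.extractelement %v[%j : index] : vector<8xf32>
/// is rewritten, roughly, to:
///   %idx = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%i, %j]
///   %e = memref.load %m[%idx] : memref<?xf32>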
class RewriteScalarExtractElementOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractElementOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
    auto loc = extractOp.getLoc();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    if (extractOp.getPosition()) {
      AffineExpr sym0, sym1;
      bindSymbols(extractOp.getContext(), sym0, sym1);
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, sym0 + sym1,
          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
      if (ofr.is<Value>()) {
        newIndices[newIndices.size() - 1] = ofr.get<Value>();
      } else {
        newIndices[newIndices.size() - 1] =
            rewriter.create<arith::ConstantIndexOp>(loc,
                                                    *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};

/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
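///
/// As an illustrative sketch (hypothetical %m, %pad and %i):
///   %v = vector.transfer_read %m[%i, %c0], %pad {in_bounds = [true]}
///       : memref<?x8xf32>, vector<8xf32>
///   %e = vector.extract %v[4] : f32 from vector<8xf32>
/// is rewritten, roughly, to:
///   %e = memref.load %m[%i, %c4] : memref<?x8xf32>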
class RewriteScalarExtractOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      assert(pos.is<Attribute>() && "Unexpected non-constant index");
      int64_t offset = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(),
          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      if (ofr.is<Value>()) {
        newIndices[idx] = ofr.get<Value>();
      } else {
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            extractOp.getLoc(), *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};

/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
/// to memref.store.
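///
/// As an illustrative sketch (hypothetical %m, %v, %i and %j):
///   vector.transfer_write %v, %m[%i, %j] {in_bounds = [true, true]}
///       : vector<1x1xf32>, memref<?x?xf32>
/// is rewritten to:
///   %s = vector.extract %v[0, 0] : f32 from vector<1x1xf32>
///   memref.store %s, %m[%i, %j] : memref<?x?xf32>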
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Only float and integer element types are supported.
    Value scalar;
    if (vecType.getRank() == 0) {
      // vector.extract does not support vector<f32> etc., so use
      // vector.extractelement instead.
      scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
                                                         xferOp.getVector());
    } else {
      SmallVector<int64_t> pos(vecType.getRank(), 0);
      scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
                                                  xferOp.getVector(), pos);
    }
    // Construct a scalar store.
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    }
    return success();
  }
};

} // namespace

void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  TransferOptimization opt(rewriter, rootOp);
  // Run store to load forwarding first since it can expose more dead store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType()))
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType()))
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}

void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) {
  patterns.add<RewriteScalarExtractElementOfTransferRead,
               RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                   benefit, allowMultipleUses);
  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
    PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), targetVectorBitwidth, benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
}
967 | |