1 | //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements target-independent rewrites as 1->N patterns. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" |
14 | |
15 | #include <cassert> |
16 | #include <cstdint> |
17 | #include <functional> |
18 | #include <optional> |
19 | #include <type_traits> |
20 | |
21 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
22 | #include "mlir/Dialect/Arith/IR/Arith.h" |
23 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
24 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
25 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
26 | #include "mlir/Dialect/SCF/IR/SCF.h" |
27 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
28 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
29 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
30 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
31 | #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" |
32 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h" |
33 | #include "mlir/IR/BuiltinAttributeInterfaces.h" |
34 | #include "mlir/IR/BuiltinTypes.h" |
35 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
36 | #include "mlir/IR/Location.h" |
37 | #include "mlir/IR/Matchers.h" |
38 | #include "mlir/IR/PatternMatch.h" |
39 | #include "mlir/IR/TypeUtilities.h" |
40 | #include "mlir/Interfaces/VectorInterfaces.h" |
41 | #include "mlir/Support/LogicalResult.h" |
42 | |
43 | #include "llvm/ADT/DenseSet.h" |
44 | #include "llvm/ADT/MapVector.h" |
45 | #include "llvm/ADT/STLExtras.h" |
46 | #include "llvm/Support/CommandLine.h" |
47 | #include "llvm/Support/Debug.h" |
48 | #include "llvm/Support/FormatVariadic.h" |
49 | #include "llvm/Support/raw_ostream.h" |
50 | |
51 | #define DEBUG_TYPE "vector-to-vector" |
52 | |
53 | using namespace mlir; |
54 | using namespace mlir::vector; |
55 | |
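/// Converts the integer values in `arrayAttr` to a vector of the requested
/// integer type.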
56 | template <typename IntType> |
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
58 | return llvm::to_vector<4>(llvm::map_range( |
59 | arrayAttr.getAsRange<IntegerAttr>(), |
60 | [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); })); |
61 | } |
62 | |
63 | // Helper to find an index in an affine map. |
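// For example, given the map (d0, d1, d2) -> (d2, d0) and index = 0, this
// returns 1 because d0 appears as the second result of the map.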
64 | static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) { |
65 | for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { |
    int64_t idx = map.getDimPosition(i);
67 | if (idx == index) |
68 | return i; |
69 | } |
70 | return std::nullopt; |
71 | } |
72 | |
73 | namespace { |
74 | |
75 | /// ShapeCastOpFolder folds cancelling ShapeCastOps away. |
76 | // |
77 | // Example: |
78 | // |
79 | // The following MLIR with cancelling ShapeCastOps: |
80 | // |
81 | // %0 = source : vector<5x4x2xf32> |
82 | // %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32> |
83 | // %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32> |
84 | // %3 = user %2 : vector<5x4x2xf32> |
85 | // |
86 | // Should canonicalize to the following: |
87 | // |
88 | // %0 = source : vector<5x4x2xf32> |
89 | // %1 = user %0 : vector<5x4x2xf32> |
90 | // |
91 | struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> { |
92 | using OpRewritePattern::OpRewritePattern; |
93 | |
94 | LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, |
95 | PatternRewriter &rewriter) const override { |
96 | // Check if 'shapeCastOp' has vector source/result type. |
97 | auto sourceVectorType = |
98 | dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType()); |
99 | auto resultVectorType = |
100 | dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType()); |
101 | if (!sourceVectorType || !resultVectorType) |
102 | return failure(); |
103 | |
104 | // Check if shape cast op source operand is also a shape cast op. |
105 | auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>( |
106 | shapeCastOp.getSource().getDefiningOp()); |
107 | if (!sourceShapeCastOp) |
108 | return failure(); |
109 | auto operandSourceVectorType = |
110 | cast<VectorType>(sourceShapeCastOp.getSource().getType()); |
111 | auto operandResultVectorType = sourceShapeCastOp.getType(); |
112 | |
113 | // Check if shape cast operations invert each other. |
114 | if (operandSourceVectorType != resultVectorType || |
115 | operandResultVectorType != sourceVectorType) |
116 | return failure(); |
117 | |
118 | rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource()); |
119 | return success(); |
120 | } |
121 | }; |
122 | |
123 | /// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp. |
124 | /// Ex: |
125 | /// ``` |
126 | /// %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32> |
127 | /// %1 = vector.multi_reduction add, %0 [1] |
128 | /// : vector<8x32x16xf32> to vector<8x16xf32> |
129 | /// ``` |
130 | /// Gets converted to: |
131 | /// ``` |
132 | /// %1 = vector.contract {indexing_maps = [ |
133 | /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, |
134 | /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, |
135 | /// affine_map<(d0, d1, d2) -> (d0, d1)>], |
136 | /// iterator_types = ["parallel", "parallel", "reduction"], |
137 | /// kind = add} %0, %arg1, %cst_f0 |
138 | /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> |
139 | /// ``` |
140 | struct MultiReduceToContract |
141 | : public OpRewritePattern<vector::MultiDimReductionOp> { |
142 | using OpRewritePattern::OpRewritePattern; |
143 | |
144 | LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp, |
145 | PatternRewriter &rewriter) const override { |
146 | if (reduceOp.getKind() != vector::CombiningKind::ADD) |
147 | return failure(); |
148 | Operation *mulOp = reduceOp.getSource().getDefiningOp(); |
149 | if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp)) |
150 | return failure(); |
151 | SmallVector<bool> reductionMask = reduceOp.getReductionMask(); |
    auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
153 | SmallVector<AffineExpr> exprs; |
154 | SmallVector<vector::IteratorType> iteratorTypes; |
155 | for (const auto &isReduceDim : llvm::enumerate(reductionMask)) { |
156 | if (!isReduceDim.value()) { |
157 | iteratorTypes.push_back(vector::IteratorType::parallel); |
158 | exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index())); |
159 | } else { |
160 | iteratorTypes.push_back(vector::IteratorType::reduction); |
161 | } |
162 | } |
163 | auto dstMap = |
164 | AffineMap::get(/*dimCount=*/reductionMask.size(), |
165 | /*symbolCount=*/0, exprs, reduceOp.getContext()); |
166 | rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>( |
        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
        rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
        rewriter.getArrayAttr(llvm::to_vector(llvm::map_range(
170 | iteratorTypes, [&](IteratorType t) -> mlir::Attribute { |
171 | return IteratorTypeAttr::get(rewriter.getContext(), t); |
172 | })))); |
173 | return success(); |
174 | } |
175 | }; |
176 | |
177 | /// Merge LHS/RHS (A/B) TransposeOp into ContractionOp user. |
178 | /// Ex: |
179 | /// ``` |
180 | /// %0 = vector.transpose %arg0, [2, 0, 1] |
181 | /// : vector<32x16x8xf32> to vector<8x32x16xf32> |
182 | /// %1 = vector.contract {indexing_maps = [ |
183 | /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, |
184 | /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, |
185 | /// affine_map<(d0, d1, d2) -> (d0, d1)>], |
186 | /// iterator_types = ["parallel", "parallel", "reduction"], |
187 | /// kind = add} %0, %arg1, %cst_f0 |
188 | /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> |
189 | /// ``` |
190 | /// Gets converted to: |
191 | /// ``` |
192 | /// %1 = vector.contract {indexing_maps = [ |
193 | /// affine_map<(d0, d1, d2) -> (d1, d2, d0)>, |
194 | /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, |
195 | /// affine_map<(d0, d1, d2) -> (d0, d1)>], |
196 | /// iterator_types = ["parallel", "parallel", "reduction"], |
197 | /// kind = add} %arg0, %arg1, %cst_f0 |
198 | /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> |
199 | /// ``` |
200 | struct CombineContractABTranspose final |
201 | : public OpRewritePattern<vector::ContractionOp> { |
202 | using OpRewritePattern::OpRewritePattern; |
203 | |
204 | LogicalResult matchAndRewrite(vector::ContractionOp contractOp, |
205 | PatternRewriter &rewriter) const override { |
206 | SmallVector<AffineMap> maps = |
207 | llvm::to_vector<4>(contractOp.getIndexingMapsArray()); |
208 | Value lhs = contractOp.getLhs(); |
209 | Value rhs = contractOp.getRhs(); |
210 | size_t index = 0; |
211 | bool changed = false; |
212 | for (Value *operand : {&lhs, &rhs}) { |
213 | AffineMap &map = maps[index++]; |
214 | auto transposeOp = operand->getDefiningOp<vector::TransposeOp>(); |
215 | if (!transposeOp) |
216 | continue; |
217 | AffineMap permutationMap = AffineMap::getPermutationMap( |
218 | transposeOp.getPermutation(), contractOp.getContext()); |
219 | map = inversePermutation(permutationMap).compose(map); |
220 | *operand = transposeOp.getVector(); |
221 | changed = true; |
222 | } |
223 | if (!changed) |
224 | return failure(); |
225 | rewriter.replaceOpWithNewOp<vector::ContractionOp>( |
226 | contractOp, lhs, rhs, contractOp.getAcc(), |
227 | rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes()); |
228 | return success(); |
229 | } |
230 | }; |
231 | |
232 | /// Merges accumulator and result transposes into contract. |
233 | /// |
234 | /// For example: |
235 | /// ```mlir |
236 | /// %accT = vector.transpose %acc, [0, 2, 1] |
237 | /// : vector<2x8x4xf32> to vector<2x4x8xf32> |
238 | /// %contract = vector.contract { |
239 | /// indexing_maps = [ |
240 | /// affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, |
241 | /// affine_map<(d0, d1, d2, d3) -> (d3, d2)>, |
242 | /// affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> |
243 | /// ], |
244 | /// iterator_types = ["parallel", "parallel", "parallel", "reduction"], |
245 | /// kind = #vector.kind<add> |
246 | /// } %lhs, %rhs, %accT |
247 | /// : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32> |
248 | /// %0 = vector.transpose %contract, [0, 2, 1] |
///    : vector<2x4x8xf32> to vector<2x8x4xf32>
250 | /// ``` |
251 | /// Becomes: |
252 | /// ```mlir |
253 | /// %0 = vector.contract { |
254 | /// indexing_maps = [ |
255 | /// affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, |
256 | /// affine_map<(d0, d1, d2, d3) -> (d3, d2)>, |
257 | /// affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)> |
258 | /// ], |
259 | /// iterator_types = ["parallel", "parallel", "parallel", "reduction"], |
260 | /// kind = #vector.kind<add> |
261 | /// } %lhs, %rhs, %acc |
262 | /// : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x8x4xf32> |
263 | /// ``` |
264 | struct CombineContractResultTranspose final |
265 | : public OpRewritePattern<vector::TransposeOp> { |
266 | using OpRewritePattern::OpRewritePattern; |
267 | |
268 | LogicalResult matchAndRewrite(vector::TransposeOp resTOp, |
269 | PatternRewriter &rewriter) const override { |
270 | auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>(); |
271 | if (!contractOp || !contractOp->hasOneUse()) |
272 | return failure(); |
273 | |
274 | auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>(); |
275 | if (!accTOp) |
276 | return failure(); |
277 | |
278 | MLIRContext *context = contractOp.getContext(); |
279 | auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray()); |
280 | AffineMap contractMap = maps.back(); |
281 | |
282 | // Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B. |
283 | // To index into A in contract, we need revert(f)(g(C)) -> A. |
284 | auto accTMap = |
285 | AffineMap::getPermutationMap(accTOp.getPermutation(), context); |
286 | |
287 | // Contract performs g(C) -> D. Result transpose performs h(D) -> E. |
288 | // To index into E in contract, we need h(g(C)) -> E. |
289 | auto resTMap = |
290 | AffineMap::getPermutationMap(resTOp.getPermutation(), context); |
291 | auto combinedResMap = resTMap.compose(contractMap); |
292 | |
    // The accumulator and result share the same indexing map. So they should
    // be the same to be able to merge. This means combinedResMap is the same
    // as inversePermutation(accTMap).compose(contractMap), i.e.,
    // inversePermutation(accTMap) must equal resTMap.
296 | if (inversePermutation(accTMap) != resTMap) |
297 | return failure(); |
298 | maps.back() = combinedResMap; |
299 | |
300 | rewriter.replaceOpWithNewOp<vector::ContractionOp>( |
301 | resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(), |
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
303 | return success(); |
304 | } |
305 | }; |
306 | |
307 | /// Merge BroadcastOp into ContractionOp user. |
308 | /// Ex: |
309 | /// ``` |
310 | /// %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32> |
311 | /// %1 = vector.contract {indexing_maps = [ |
312 | /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, |
313 | /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, |
314 | /// affine_map<(d0, d1, d2) -> (d0, d1)>], |
315 | /// iterator_types = ["parallel", "parallel", "reduction"], |
316 | /// kind = add} %0, %arg1, %cst_f0 |
317 | /// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> |
318 | /// ``` |
319 | /// Gets converted to: |
320 | /// ``` |
321 | /// %1 = vector.contract {indexing_maps = [ |
322 | /// affine_map<(d0, d1, d2) -> (d1, d2)>, |
323 | /// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, |
324 | /// affine_map<(d0, d1, d2) -> (d0, d1)>], |
325 | /// iterator_types = ["parallel", "parallel", "reduction"], |
326 | /// kind = add} %arg0, %arg1, %cst_f0 |
327 | /// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> |
328 | /// ``` |
329 | struct CombineContractBroadcast |
330 | : public OpRewritePattern<vector::ContractionOp> { |
331 | using OpRewritePattern::OpRewritePattern; |
332 | |
333 | LogicalResult matchAndRewrite(vector::ContractionOp contractOp, |
334 | PatternRewriter &rewriter) const override { |
335 | SmallVector<AffineMap> maps = |
336 | llvm::to_vector<4>(contractOp.getIndexingMapsArray()); |
337 | Value lhs = contractOp.getLhs(); |
338 | Value rhs = contractOp.getRhs(); |
339 | size_t index = 0; |
340 | bool changed = false; |
341 | for (Value *operand : {&lhs, &rhs}) { |
342 | AffineMap &map = maps[index++]; |
343 | auto broadcast = operand->getDefiningOp<vector::BroadcastOp>(); |
344 | if (!broadcast) |
345 | continue; |
      // ContractionOp can only take vectors as operands.
347 | auto srcType = dyn_cast<VectorType>(broadcast.getSourceType()); |
348 | if (!srcType || |
349 | srcType.getRank() == broadcast.getResultVectorType().getRank()) |
350 | continue; |
351 | int64_t rankDiff = |
352 | broadcast.getResultVectorType().getRank() - srcType.getRank(); |
353 | bool innerDimBroadcast = false; |
354 | SmallVector<AffineExpr> originalDims; |
355 | for (const auto &dim : llvm::enumerate(srcType.getShape())) { |
356 | if (dim.value() != broadcast.getResultVectorType().getDimSize( |
357 | rankDiff + dim.index())) { |
358 | innerDimBroadcast = true; |
359 | break; |
360 | } |
361 | originalDims.push_back( |
362 | rewriter.getAffineDimExpr(dim.index() + rankDiff)); |
363 | } |
364 | // Contract doesn't support inner dimension broadcast. Once this is |
365 | // relaxed we can remove this case. |
366 | if (innerDimBroadcast) |
367 | continue; |
368 | |
369 | // It would be incorrect to fold a broadcast onto a reduction dimension |
370 | // of non-unit size. |
371 | bool nonUnitDimReductionBroadcast = false; |
372 | for (int64_t i = 0; i < rankDiff; ++i) { |
373 | if (broadcast.getResultVectorType().getDimSize(i) != 1 && |
374 | isReductionIterator(contractOp.getIteratorTypes() |
375 | .getValue()[map.getDimPosition(i)])) { |
376 | nonUnitDimReductionBroadcast = true; |
377 | break; |
378 | } |
379 | } |
380 | if (nonUnitDimReductionBroadcast) |
381 | continue; |
382 | |
383 | AffineMap broadcastMap = |
384 | AffineMap::get(broadcast.getResultVectorType().getRank(), 0, |
385 | originalDims, contractOp.getContext()); |
386 | map = broadcastMap.compose(map); |
387 | *operand = broadcast.getSource(); |
388 | changed = true; |
389 | } |
390 | |
391 | if (!changed) |
392 | return failure(); |
393 | |
    // Determine which dims are unused, now that the maps have been composed
    // with the broadcast maps.
396 | llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps); |
397 | // Compress unused dims. |
398 | for (auto &m : maps) |
399 | m = compressDims(m, unusedDimsBitVector); |
400 | // Compute the combined iterators. |
401 | SmallVector<Attribute> iterators; |
402 | for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) { |
      if (!unusedDimsBitVector.test(i))
        iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
405 | } |
406 | // Check that compressing unused dims isn't removing all reduction dimension |
407 | // pairs. For example, if the vector.contract had only one reduction |
408 | // iterator and that was a unit-dimension created by a broadcast, |
409 | // then we should bail here, otherwise we would create a contract without |
410 | // a reduction dimension pair. |
411 | bool hasReductionIteratorApplyingOnBothSides = false; |
412 | for (unsigned i = 0; i < iterators.size(); ++i) { |
      if (!isReductionIterator(iterators[i]))
        continue;
      if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
416 | hasReductionIteratorApplyingOnBothSides = true; |
417 | break; |
418 | } |
419 | } |
420 | if (!hasReductionIteratorApplyingOnBothSides) |
421 | return failure(); |
422 | |
423 | // If the compressed maps have a dimension that is not used by either LHS or |
424 | // RHS then the ContractionOp verifier would fail. |
    if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
426 | return failure(); |
427 | rewriter.replaceOpWithNewOp<vector::ContractionOp>( |
428 | contractOp, lhs, rhs, contractOp.getAcc(), |
429 | rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators)); |
430 | return success(); |
431 | } |
432 | }; |
433 | |
/// Reorders cast(broadcast) to broadcast(cast). This moves broadcast ops
/// closer to contraction ops, which lets the CombineContractBroadcast pattern
/// kick in when casting ops sit between these operations.
437 | /// Ex: |
438 | /// ``` |
439 | /// %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8> |
440 | /// %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32> |
441 | /// ``` |
442 | /// Gets converted to: |
443 | /// ``` |
/// %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
/// %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
446 | /// ``` |
447 | struct ReorderCastOpsOnBroadcast |
448 | : public OpInterfaceRewritePattern<CastOpInterface> { |
449 | using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern; |
450 | |
451 | LogicalResult matchAndRewrite(CastOpInterface op, |
452 | PatternRewriter &rewriter) const override { |
453 | if (op->getNumOperands() != 1) |
454 | return failure(); |
455 | auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>(); |
456 | if (!bcastOp) |
457 | return failure(); |
458 | |
459 | Type castResTy = getElementTypeOrSelf(op->getResult(0)); |
460 | if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType())) |
461 | castResTy = vecTy.clone(castResTy); |
462 | auto *castOp = |
463 | rewriter.create(op->getLoc(), op->getName().getIdentifier(), |
464 | bcastOp.getSource(), castResTy, op->getAttrs()); |
465 | rewriter.replaceOpWithNewOp<vector::BroadcastOp>( |
466 | op, op->getResult(0).getType(), castOp->getResult(0)); |
467 | return success(); |
468 | } |
469 | }; |
470 | |
/// Reorders elementwise(transpose) to transpose(elementwise). This moves
/// transpose ops closer to contraction ops, which lets the
/// CombineContractABTranspose pattern kick in when elementwise ops sit between
/// these operations. Ex:
475 | /// ``` |
476 | /// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32> |
477 | /// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32> |
478 | /// %r = arith.addf %at, %bt : vector<2x4xf32> |
479 | /// ``` |
480 | /// Gets converted to: |
481 | /// ``` |
482 | /// %0 = arith.addf %a, %b : vector<4x2xf32> |
483 | /// %r = vector.transpose %0, [1, 0] : vector<2x4xf32> |
484 | /// ``` |
485 | struct ReorderElementwiseOpsOnTranspose final |
486 | : public OpTraitRewritePattern<OpTrait::Elementwise> { |
487 | using OpTraitRewritePattern::OpTraitRewritePattern; |
488 | LogicalResult matchAndRewrite(Operation *op, |
489 | PatternRewriter &rewriter) const override { |
490 | if (op->getNumResults() != 1 || op->getNumRegions() != 0) |
491 | return failure(); |
492 | |
493 | // Make sure all operands are transpose/constant ops and collect their |
494 | // transposition maps. |
495 | SmallVector<ArrayRef<int64_t>> transposeMaps; |
    transposeMaps.reserve(op->getNumOperands());
497 | // Record the initial type before transposition. We'll use its shape later. |
498 | // Any type will do here as we will check all transpose maps are the same. |
499 | VectorType srcType; |
500 | for (Value operand : op->getOperands()) { |
501 | auto transposeOp = operand.getDefiningOp<vector::TransposeOp>(); |
502 | if (transposeOp) { |
        transposeMaps.push_back(transposeOp.getPermutation());
        srcType = transposeOp.getSourceVectorType();
      } else if (!matchPattern(operand, m_Constant())) {
506 | return failure(); |
507 | } |
508 | } |
509 | if (transposeMaps.empty()) |
510 | return failure(); |
511 | // This is an elementwise op, so all transposed operands should have the |
    // same type. We need to additionally check that all transposes use the
    // same map.
    if (!llvm::all_equal(transposeMaps))
      return rewriter.notifyMatchFailure(op, "different transpose map");
516 | |
517 | SmallVector<Value> srcValues; |
    srcValues.reserve(op->getNumOperands());
519 | |
520 | // If there are constant operands, we need to insert inverse transposes for |
521 | // them. Calculate the inverse order first. |
522 | auto order = transposeMaps.front(); |
523 | SmallVector<int64_t> invOrder(order.size()); |
524 | for (int i = 0, e = order.size(); i < e; ++i) |
525 | invOrder[order[i]] = i; |
526 | |
527 | for (Value operand : op->getOperands()) { |
528 | auto transposeOp = operand.getDefiningOp<vector::TransposeOp>(); |
529 | if (transposeOp) { |
        srcValues.push_back(transposeOp.getVector());
531 | } else { |
532 | // This is a constant. Create a reverse transpose op for it. |
533 | auto vectorType = |
534 | srcType.clone(cast<VectorType>(operand.getType()).getElementType()); |
535 | srcValues.push_back(rewriter.create<vector::TransposeOp>( |
536 | operand.getLoc(), vectorType, operand, invOrder)); |
537 | } |
538 | } |
539 | |
540 | auto vectorType = srcType.clone( |
541 | cast<VectorType>(op->getResultTypes()[0]).getElementType()); |
542 | Operation *elementwiseOp = |
543 | rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues, |
544 | vectorType, op->getAttrs()); |
545 | rewriter.replaceOpWithNewOp<vector::TransposeOp>( |
546 | op, op->getResultTypes()[0], elementwiseOp->getResult(0), |
547 | transposeMaps.front()); |
548 | return success(); |
549 | } |
550 | }; |
551 | |
552 | // Returns the values in `arrayAttr` as an integer vector. |
553 | static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) { |
554 | return llvm::to_vector<4>( |
555 | llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(), |
556 | [](IntegerAttr attr) { return attr.getInt(); })); |
557 | } |
558 | |
559 | // Shuffles vector.bitcast op after vector.extract op. |
560 | // |
561 | // This transforms IR like: |
562 | // %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16> |
563 | // %1 = vector.extract %0[3] : f16 from vector<8xf16> |
564 | // Into: |
565 | // %0 = vector.extract %src[1] : f32 from vector<4xf32> |
566 | // %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16> |
567 | // %2 = vector.extract %1[1] : f16 from vector<2xf16> |
struct BubbleDownVectorBitCastForExtract
    : public OpRewritePattern<vector::ExtractOp> {
570 | using OpRewritePattern::OpRewritePattern; |
571 | |
  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
573 | PatternRewriter &rewriter) const override { |
574 | // Only support extracting scalars for now. |
575 | if (extractOp.getSourceVectorType().getRank() != 1) |
576 | return failure(); |
577 | |
578 | auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>(); |
579 | if (!castOp) |
580 | return failure(); |
581 | |
582 | VectorType castSrcType = castOp.getSourceVectorType(); |
583 | VectorType castDstType = castOp.getResultVectorType(); |
584 | assert(castSrcType.getRank() == castDstType.getRank()); |
585 | |
586 | // Fail to match if we only have one element in the cast op source. |
    // This is to avoid an infinite loop given that this pattern can generate
588 | // such cases. |
589 | if (castSrcType.getNumElements() == 1) |
590 | return failure(); |
591 | |
    // Only support casting to a larger number of elements for now.
593 | // E.g., vector<4xf32> -> vector<8xf16>. |
594 | if (castSrcType.getNumElements() > castDstType.getNumElements()) |
595 | return failure(); |
596 | |
597 | unsigned expandRatio = |
598 | castDstType.getNumElements() / castSrcType.getNumElements(); |
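    // E.g., for the vector<4xf32> to vector<8xf16> case above, expandRatio is
    // 2: f16 element `i` lives in f32 element i / 2, at sub-position i % 2.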
599 | |
600 | auto getFirstIntValue = [](ArrayRef<OpFoldResult> values) -> uint64_t { |
      assert(values[0].is<Attribute>() && "Unexpected non-constant index");
602 | return cast<IntegerAttr>(values[0].get<Attribute>()).getInt(); |
603 | }; |
604 | |
605 | uint64_t index = getFirstIntValue(extractOp.getMixedPosition()); |
606 | |
607 | // Get the single scalar (as a vector) in the source value that packs the |
608 | // desired scalar. E.g. extract vector<1xf32> from vector<4xf32> |
609 | Location loc = extractOp.getLoc(); |
610 | Value packedValue = rewriter.create<vector::ExtractOp>( |
611 | loc, castOp.getSource(), index / expandRatio); |
612 | Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType()); |
613 | Value zero = rewriter.create<arith::ConstantOp>( |
614 | loc, packedVecType, rewriter.getZeroAttr(packedVecType)); |
615 | packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero, |
616 | /*position=*/0); |
617 | |
618 | // Cast it to a vector with the desired scalar's type. |
619 | // E.g. f32 -> vector<2xf16> |
620 | VectorType packedType = |
621 | VectorType::get({expandRatio}, castDstType.getElementType()); |
622 | Value castedValue = |
623 | rewriter.create<vector::BitCastOp>(loc, packedType, packedValue); |
624 | |
625 | // Finally extract the desired scalar. |
626 | rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue, |
627 | index % expandRatio); |
628 | return success(); |
629 | } |
630 | }; |
631 | |
632 | // Shuffles vector.bitcast op after vector.extract_strided_slice op. |
633 | // |
634 | // This transforms IR like: |
635 | // %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16> |
636 | // %0 = vector.extract_strided_slice %cast { |
637 | // offsets = [4], sizes = [4], strides = [1] |
638 | // } : vector<8xf16> to vector<4xf16> |
639 | // Into: |
640 | // %0 = vector.extract_strided_slice %src { |
641 | // offsets = [2], sizes = [2], strides = [1] |
642 | // } : vector<4xf32> to vector<2xf32> |
643 | // %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16> |
struct BubbleDownBitCastForStridedSliceExtract
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
646 | using OpRewritePattern::OpRewritePattern; |
647 | |
  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
649 | PatternRewriter &rewriter) const override { |
650 | auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>(); |
651 | if (!castOp) |
652 | return failure(); |
653 | |
654 | VectorType castSrcType = castOp.getSourceVectorType(); |
655 | VectorType castDstType = castOp.getResultVectorType(); |
656 | assert(castSrcType.getRank() == castDstType.getRank()); |
657 | |
658 | int64_t castSrcLastDim = castSrcType.getShape().back(); |
659 | int64_t castDstLastDim = castDstType.getShape().back(); |
660 | // Require casting to more elements for now; other cases to be implemented. |
661 | if (castSrcLastDim > castDstLastDim) |
662 | return failure(); |
663 | |
664 | // Only accept all one strides for now. |
665 | if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(), |
666 | [](const APInt &val) { return !val.isOne(); })) |
667 | return failure(); |
668 | |
669 | unsigned rank = extractOp.getSourceVectorType().getRank(); |
670 | assert(castDstLastDim % castSrcLastDim == 0); |
671 | int64_t expandRatio = castDstLastDim / castSrcLastDim; |
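    // E.g., for vector<4xf32> bitcast to vector<8xf16>, expandRatio is 2, so
    // offsets and sizes on the f16 side are divided by 2 on the f32 side.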
672 | |
    // If we have fewer offsets than the rank, then implicitly we are selecting
    // the full range for the last bitcasted dimension; other dimensions aren't
    // affected. Otherwise, we need to scale down the last dimension's offset
    // given we are extracting from fewer elements now.
677 | ArrayAttr newOffsets = extractOp.getOffsets(); |
678 | if (newOffsets.size() == rank) { |
679 | SmallVector<int64_t> offsets = getIntValueVector(newOffsets); |
680 | if (offsets.back() % expandRatio != 0) |
681 | return failure(); |
682 | offsets.back() = offsets.back() / expandRatio; |
683 | newOffsets = rewriter.getI64ArrayAttr(offsets); |
684 | } |
685 | |
686 | // Similarly for sizes. |
687 | ArrayAttr newSizes = extractOp.getSizes(); |
688 | if (newSizes.size() == rank) { |
689 | SmallVector<int64_t> sizes = getIntValueVector(newSizes); |
690 | if (sizes.back() % expandRatio != 0) |
691 | return failure(); |
692 | sizes.back() = sizes.back() / expandRatio; |
693 | newSizes = rewriter.getI64ArrayAttr(sizes); |
694 | } |
695 | |
696 | SmallVector<int64_t> dims = |
697 | llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape()); |
698 | dims.back() = dims.back() / expandRatio; |
    VectorType newExtractType =
        VectorType::get(dims, castSrcType.getElementType());

    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
703 | extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets, |
704 | newSizes, extractOp.getStrides()); |
705 | |
706 | rewriter.replaceOpWithNewOp<vector::BitCastOp>( |
707 | extractOp, extractOp.getType(), newExtractOp); |
708 | |
709 | return success(); |
710 | } |
711 | }; |
712 | |
// Shuffles vector.bitcast op before vector.insert op.
714 | // |
715 | // This transforms IR like: |
716 | // %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4> |
717 | // %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8> |
718 | // Into: |
719 | // %0 = vector.bitcast %val : vector<32xi4> to vector<16xi8> |
720 | // %1 = vector.bitcast %dst : vector<8x32xi4> to vector<8x16xi8> |
721 | // %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8> |
722 | // |
723 | struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> { |
724 | using OpRewritePattern::OpRewritePattern; |
725 | |
726 | LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, |
727 | PatternRewriter &rewriter) const override { |
728 | VectorType castSrcType = bitcastOp.getSourceVectorType(); |
729 | VectorType castDstType = bitcastOp.getResultVectorType(); |
730 | |
731 | // 0-D and scalable vectors are not supported yet. |
732 | if (castSrcType.getRank() == 0 || castSrcType.isScalable() || |
733 | castDstType.isScalable()) |
734 | return failure(); |
735 | |
736 | int64_t castSrcLastDim = castSrcType.getShape().back(); |
737 | int64_t castDstLastDim = castDstType.getShape().back(); |
738 | bool isNumElemsShrink = castSrcLastDim >= castDstLastDim; |
739 | int64_t ratio; |
740 | if (isNumElemsShrink) { |
741 | assert(castSrcLastDim % castDstLastDim == 0); |
742 | ratio = castSrcLastDim / castDstLastDim; |
743 | } else { |
744 | assert(castDstLastDim % castSrcLastDim == 0); |
745 | ratio = castDstLastDim / castSrcLastDim; |
746 | } |
747 | |
748 | auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>(); |
749 | if (!insertOp) |
750 | return failure(); |
751 | |
752 | // Only vector sources are supported for now. |
753 | auto insertSrcType = dyn_cast<VectorType>(insertOp.getSourceType()); |
754 | if (!insertSrcType) |
755 | return failure(); |
756 | |
757 | // Bitcast the source. |
758 | SmallVector<int64_t> srcDims(insertSrcType.getShape()); |
759 | srcDims.back() = |
760 | isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio; |
761 | VectorType newCastSrcType = |
762 | VectorType::get(srcDims, castDstType.getElementType()); |
763 | auto newCastSrcOp = rewriter.create<vector::BitCastOp>( |
764 | bitcastOp.getLoc(), newCastSrcType, insertOp.getSource()); |
765 | |
766 | SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape()); |
767 | dstDims.back() = |
768 | isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio; |
769 | VectorType newCastDstType = |
770 | VectorType::get(dstDims, castDstType.getElementType()); |
771 | |
772 | // Bitcast the destination. |
773 | auto newCastDstOp = rewriter.create<vector::BitCastOp>( |
774 | bitcastOp.getLoc(), newCastDstType, insertOp.getDest()); |
775 | |
776 | // Generate new insert. |
777 | rewriter.replaceOpWithNewOp<vector::InsertOp>( |
778 | bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition()); |
779 | return success(); |
780 | } |
781 | }; |
782 | |
783 | // Shuffles vector.bitcast op before vector.insert_strided_slice op. |
784 | // |
785 | // This transforms IR like: |
786 | // %0 = vector.insert_strided_slice %src, %dst { |
787 | // offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16> |
788 | // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32> |
789 | // Into: |
790 | // %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32> |
791 | // %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32> |
//   %2 = vector.insert_strided_slice %0, %1 {
//     offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
794 | struct BubbleUpBitCastForStridedSliceInsert |
795 | : public OpRewritePattern<vector::BitCastOp> { |
796 | using OpRewritePattern::OpRewritePattern; |
797 | |
798 | LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, |
799 | PatternRewriter &rewriter) const override { |
800 | VectorType castSrcType = bitcastOp.getSourceVectorType(); |
801 | VectorType castDstType = bitcastOp.getResultVectorType(); |
802 | assert(castSrcType.getRank() == castDstType.getRank()); |
    // Skip 0-D vectors, which will not come from InsertStridedSliceOp.
804 | if (castSrcType.getRank() == 0) |
805 | return failure(); |
806 | |
807 | int64_t castSrcLastDim = castSrcType.getShape().back(); |
808 | int64_t castDstLastDim = castDstType.getShape().back(); |
809 | // Require casting to less elements for now; other cases to be implemented. |
810 | if (castSrcLastDim < castDstLastDim) |
811 | return failure(); |
812 | |
813 | assert(castSrcLastDim % castDstLastDim == 0); |
814 | int64_t shrinkRatio = castSrcLastDim / castDstLastDim; |
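    // E.g., for vector<8xf16> bitcast to vector<4xf32>, shrinkRatio is 2, so
    // offsets and shapes on the f16 side are divided by 2 on the f32 side.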
815 | |
816 | auto insertOp = |
817 | bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>(); |
818 | if (!insertOp) |
819 | return failure(); |
820 | |
821 | // Only accept all one strides for now. |
822 | if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(), |
823 | [](const APInt &val) { return !val.isOne(); })) |
824 | return failure(); |
825 | |
826 | unsigned rank = insertOp.getSourceVectorType().getRank(); |
827 | // Require insert op to have the same rank for the source and destination |
828 | // vector; other cases to be implemented. |
829 | if (rank != insertOp.getDestVectorType().getRank()) |
830 | return failure(); |
831 | |
832 | // Requires that shape of insert op src is castable to dstType. |
833 | unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth(); |
834 | unsigned destinationWidth = |
835 | castDstType.getElementType().getIntOrFloatBitWidth(); |
836 | unsigned numElements = destinationWidth / sourceWidth; |
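    // E.g., two f16 elements pack into one f32 element, so the insert source
    // must contain a multiple of two f16 elements.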
837 | if (insertOp.getSourceVectorType().getNumElements() % numElements != 0) |
838 | return failure(); |
839 | |
840 | ArrayAttr newOffsets = insertOp.getOffsets(); |
841 | assert(newOffsets.size() == rank); |
842 | SmallVector<int64_t> offsets = getIntValueVector(newOffsets); |
843 | if (offsets.back() % shrinkRatio != 0) |
844 | return failure(); |
845 | offsets.back() = offsets.back() / shrinkRatio; |
846 | newOffsets = rewriter.getI64ArrayAttr(offsets); |
847 | |
848 | SmallVector<int64_t> srcDims = |
849 | llvm::to_vector<4>(insertOp.getSourceVectorType().getShape()); |
850 | srcDims.back() = srcDims.back() / shrinkRatio; |
851 | VectorType newCastSrcType = |
852 | VectorType::get(srcDims, castDstType.getElementType()); |
853 | |
854 | auto newCastSrcOp = rewriter.create<vector::BitCastOp>( |
855 | bitcastOp.getLoc(), newCastSrcType, insertOp.getSource()); |
856 | |
857 | SmallVector<int64_t> dstDims = |
858 | llvm::to_vector<4>(insertOp.getDestVectorType().getShape()); |
859 | dstDims.back() = dstDims.back() / shrinkRatio; |
860 | VectorType newCastDstType = |
861 | VectorType::get(dstDims, castDstType.getElementType()); |
862 | |
863 | auto newCastDstOp = rewriter.create<vector::BitCastOp>( |
864 | bitcastOp.getLoc(), newCastDstType, insertOp.getDest()); |
865 | |
866 | rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>( |
867 | bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets, |
868 | insertOp.getStrides()); |
869 | |
870 | return success(); |
871 | } |
872 | }; |
873 | |
874 | // Breaks down vector.bitcast op |
875 | // |
876 | // This transforms IR like: |
877 | // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32> |
878 | // Into: |
879 | // %cst = vector.splat %c0_f32 : vector<4xf32> |
880 | // %1 = vector.extract_strided_slice %0 { |
881 | // offsets = [0], sizes = [4], strides = [1] |
882 | // } : vector<8xf16> to vector<4xf16> |
883 | // %2 = vector.bitcast %1 : vector<4xf16> to vector<2xf32> |
884 | // %4 = vector.insert_strided_slice %2, %cst { |
885 | // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> |
886 | // %5 = vector.extract_strided_slice %0 { |
887 | // offsets = [4], sizes = [4], strides = [1] |
888 | // } : vector<8xf16> to vector<4xf16> |
889 | // %6 = vector.bitcast %5 : vector<4xf16> to vector<2xf32> |
//  %7 = vector.insert_strided_slice %6, %4 {
891 | // offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> |
892 | struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> { |
893 | using OpRewritePattern::OpRewritePattern; |
894 | |
895 | public: |
896 | BreakDownVectorBitCast(MLIRContext *context, |
897 | std::function<bool(vector::BitCastOp)> controlFn, |
898 | PatternBenefit benefit) |
899 | : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {} |
900 | |
901 | LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, |
902 | PatternRewriter &rewriter) const override { |
903 | |
904 | if (controlFn && !controlFn(bitcastOp)) |
905 | return failure(); |
906 | |
907 | VectorType castSrcType = bitcastOp.getSourceVectorType(); |
908 | VectorType castDstType = bitcastOp.getResultVectorType(); |
909 | assert(castSrcType.getRank() == castDstType.getRank()); |
910 | |
911 | // Only support rank 1 case for now. |
912 | if (castSrcType.getRank() != 1) |
913 | return failure(); |
914 | |
915 | int64_t castSrcLastDim = castSrcType.getShape().back(); |
916 | int64_t castDstLastDim = castDstType.getShape().back(); |
917 | // Require casting to less elements for now; other cases to be implemented. |
918 | if (castSrcLastDim < castDstLastDim) |
919 | return failure(); |
920 | |
921 | assert(castSrcLastDim % castDstLastDim == 0); |
922 | int64_t shrinkRatio = castSrcLastDim / castDstLastDim; |
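    // E.g., for vector<8xf16> to vector<4xf32>, shrinkRatio is 2: the source
    // is split into 2 slices of vector<4xf16>, each bitcast to vector<2xf32>
    // and inserted back at offsets 0 and 2.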
923 | // Nothing to do if it is already bitcasting to a single element. |
924 | if (castSrcLastDim == shrinkRatio) |
925 | return failure(); |
926 | |
927 | Location loc = bitcastOp.getLoc(); |
928 | Type elemType = castDstType.getElementType(); |
929 | assert(elemType.isSignlessIntOrIndexOrFloat()); |
930 | |
931 | Value zero = rewriter.create<arith::ConstantOp>( |
932 | loc, elemType, rewriter.getZeroAttr(elemType)); |
933 | Value res = rewriter.create<SplatOp>(loc, castDstType, zero); |
934 | |
935 | SmallVector<int64_t> sliceShape{castDstLastDim}; |
936 | SmallVector<int64_t> strides{1}; |
937 | VectorType newCastDstType = |
938 | VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio}, |
939 | castDstType.getElementType()); |
940 | |
941 | for (int i = 0, e = shrinkRatio; i < e; ++i) { |
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
943 | loc, bitcastOp.getSource(), ArrayRef<int64_t>{i * castDstLastDim}, |
944 | sliceShape, strides); |
945 | Value bitcast = |
946 | rewriter.create<BitCastOp>(loc, newCastDstType, extracted); |
947 | res = rewriter.create<InsertStridedSliceOp>( |
948 | loc, bitcast, res, |
949 | ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides); |
950 | } |
951 | rewriter.replaceOp(bitcastOp, res); |
952 | return success(); |
953 | } |
954 | |
955 | private: |
956 | std::function<bool(BitCastOp)> controlFn; |
957 | }; |
958 | |
959 | /// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex: |
960 | /// ``` |
961 | /// %a = vector.broadcast %arg1 : index to vector<1x4xindex> |
962 | /// %b = vector.broadcast %arg2 : index to vector<1x4xindex> |
963 | /// %r = arith.addi %a, %b : vector<1x4xindex> |
964 | /// ``` |
965 | /// Gets converted to: |
966 | /// ``` |
/// %r = arith.addi %arg1, %arg2 : index
968 | /// %b = vector.broadcast %r : index to vector<1x4xindex> |
969 | /// ``` |
970 | /// |
971 | /// Both `vector.broadcast` and `vector.splat` are supported as broadcasting |
972 | /// ops. |
973 | struct ReorderElementwiseOpsOnBroadcast final |
974 | : public OpTraitRewritePattern<OpTrait::Elementwise> { |
975 | using OpTraitRewritePattern::OpTraitRewritePattern; |
976 | LogicalResult matchAndRewrite(Operation *op, |
977 | PatternRewriter &rewriter) const override { |
978 | if (op->getNumResults() != 1) |
979 | return failure(); |
    if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
981 | return failure(); |
982 | if (!OpTrait::hasElementwiseMappableTraits(op)) |
983 | return failure(); |
984 | if (op->getNumOperands() == 0 || |
        op->getResults()[0].getType() != op->getOperand(0).getType()) {
986 | return failure(); |
987 | } |
988 | // Avoid operations that only accept vector types, since broadcast |
989 | // source might be scalar types. |
990 | if (isa<vector::FMAOp>(op)) { |
991 | return failure(); |
992 | } |
993 | |
994 | // Get the type of the lhs operand |
    auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
996 | if (!lhsBcastOrSplat || |
997 | !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat)) |
998 | return failure(); |
    auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
1000 | |
1001 | // Make sure that all operands are broadcast from identical types: |
1002 | // * scalar (`vector.broadcast` + `vector.splat`), or |
1003 | // * vector (`vector.broadcast`). |
1004 | // Otherwise the re-ordering wouldn't be safe. |
    if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
1006 | auto bcast = val.getDefiningOp<vector::BroadcastOp>(); |
1007 | if (bcast) |
1008 | return (bcast.getOperand().getType() == lhsBcastOrSplatType); |
1009 | auto splat = val.getDefiningOp<vector::SplatOp>(); |
1010 | if (splat) |
1011 | return (splat.getOperand().getType() == lhsBcastOrSplatType); |
1012 | return false; |
1013 | })) { |
1014 | return failure(); |
1015 | } |
1016 | |
1017 | // Collect the source values before broadcasting |
1018 | SmallVector<Value> srcValues; |
    srcValues.reserve(op->getNumOperands());
    for (Value operand : op->getOperands()) {
      srcValues.push_back(operand.getDefiningOp()->getOperand(0));
1022 | } |
1023 | |
1024 | // Create the "elementwise" Op |
1025 | Operation *elementwiseOp = |
1026 | rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues, |
1027 | lhsBcastOrSplatType, op->getAttrs()); |
1028 | |
1029 | // Replace the original Op with the elementwise Op |
1030 | auto vectorType = op->getResultTypes()[0]; |
1031 | rewriter.replaceOpWithNewOp<vector::BroadcastOp>( |
1032 | op, vectorType, elementwiseOp->getResults()); |
1033 | |
1034 | return success(); |
1035 | } |
1036 | }; |
1037 | |
1038 | // Helper that returns a vector comparison that constructs a mask: |
1039 | // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] |
1040 | // |
1041 | // If `dim == 0` then the result will be a 0-D vector. |
1042 | // |
1043 | // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative, |
1044 | // much more compact, IR for this operation, but LLVM eventually |
1045 | // generates more elaborate instructions for this intrinsic since it |
1046 | // is very conservative on the boundary conditions. |
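//
// For example, with dim = 4, offset o = 2, and bound b = 5, this produces
//   mask = [0, 1, 2, 3] + [2, 2, 2, 2] < [5, 5, 5, 5] = [1, 1, 1, 0].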
1047 | static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op, |
1048 | bool force32BitVectorIndices, int64_t dim, |
1049 | Value b, Value *off = nullptr) { |
1050 | auto loc = op->getLoc(); |
1051 | // If we can assume all indices fit in 32-bit, we perform the vector |
1052 | // comparison in 32-bit to get a higher degree of SIMD parallelism. |
1053 | // Otherwise we perform the vector comparison using 64-bit indices. |
1054 | Type idxType = |
1055 | force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); |
1056 | DenseIntElementsAttr indicesAttr; |
1057 | if (dim == 0 && force32BitVectorIndices) { |
1058 | indicesAttr = DenseIntElementsAttr::get( |
1059 | VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0}); |
1060 | } else if (dim == 0) { |
1061 | indicesAttr = DenseIntElementsAttr::get( |
1062 | VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0}); |
1063 | } else if (force32BitVectorIndices) { |
    indicesAttr = rewriter.getI32VectorAttr(
        llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
  } else {
    indicesAttr = rewriter.getI64VectorAttr(
        llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
1069 | } |
1070 | Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr); |
1071 | // Add in an offset if requested. |
1072 | if (off) { |
    Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
1074 | Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o); |
1075 | indices = rewriter.create<arith::AddIOp>(loc, ov, indices); |
1076 | } |
1077 | // Construct the vector comparison. |
  Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
1079 | Value bounds = |
1080 | rewriter.create<vector::SplatOp>(loc, indices.getType(), bound); |
1081 | return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices, |
1082 | bounds); |
1083 | } |
1084 | |
1085 | template <typename ConcreteOp> |
1086 | struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> { |
1087 | public: |
1088 | explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt, |
1089 | PatternBenefit benefit = 1) |
1090 | : mlir::OpRewritePattern<ConcreteOp>(context, benefit), |
1091 | force32BitVectorIndices(enableIndexOpt) {} |
1092 | |
1093 | LogicalResult matchAndRewrite(ConcreteOp xferOp, |
1094 | PatternRewriter &rewriter) const override { |
1095 | if (!xferOp.hasOutOfBoundsDim()) |
1096 | return failure(); |
1097 | |
1098 | if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty()) |
1099 | return failure(); |
1100 | |
1101 | Location loc = xferOp->getLoc(); |
1102 | VectorType vtp = xferOp.getVectorType(); |
1103 | |
1104 | // Create the in-bounds mask with all elements between [0 .. dim - offset) |
1105 | // set and [dim - offset .. vector_length) unset. |
1106 | // |
1107 | // TODO: when the leaf transfer rank is k > 1, we need the last `k` |
1108 | // dimensions here. |
1109 | unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1; |
1110 | Value off = xferOp.getIndices()[lastIndex]; |
1111 | Value dim = |
        vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
1113 | Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off); |
1114 | Value mask = rewriter.create<vector::CreateMaskOp>( |
1115 | loc, |
1116 | VectorType::get(vtp.getShape(), rewriter.getI1Type(), |
1117 | vtp.getScalableDims()), |
1118 | b); |
1119 | if (xferOp.getMask()) { |
1120 | // Intersect the in-bounds with the mask specified as an op parameter. |
1121 | mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask()); |
1122 | } |
1123 | |
1124 | rewriter.modifyOpInPlace(xferOp, [&]() { |
1125 | xferOp.getMaskMutable().assign(mask); |
1126 | xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true})); |
1127 | }); |
1128 | |
1129 | return success(); |
1130 | } |
1131 | |
1132 | private: |
1133 | const bool force32BitVectorIndices; |
1134 | }; |
1135 | |
1136 | /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only). |
1137 | class VectorCreateMaskOpConversion |
1138 | : public OpRewritePattern<vector::CreateMaskOp> { |
1139 | public: |
1140 | explicit VectorCreateMaskOpConversion(MLIRContext *context, |
1141 | bool enableIndexOpt, |
1142 | PatternBenefit benefit = 1) |
1143 | : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit), |
1144 | force32BitVectorIndices(enableIndexOpt) {} |
1145 | |
1146 | LogicalResult matchAndRewrite(vector::CreateMaskOp op, |
1147 | PatternRewriter &rewriter) const override { |
1148 | auto dstType = op.getType(); |
1149 | if (cast<VectorType>(dstType).isScalable()) |
1150 | return failure(); |
1151 | int64_t rank = dstType.getRank(); |
1152 | if (rank > 1) |
1153 | return failure(); |
1154 | rewriter.replaceOp( |
1155 | op, buildVectorComparison(rewriter, op, force32BitVectorIndices, |
1156 | rank == 0 ? 0 : dstType.getDimSize(0), |
1157 | op.getOperand(0))); |
1158 | return success(); |
1159 | } |
1160 | |
1161 | private: |
1162 | const bool force32BitVectorIndices; |
1163 | }; |
1164 | |
1165 | /// Returns true if all the `i1` elements of `constantOp` are set to `value`. |
1166 | static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) { |
1167 | auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue()); |
1168 | // TODO: Support non-dense constant. |
1169 | if (!denseAttr) |
1170 | return false; |
1171 | |
  assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
1173 | return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value; |
1174 | } |
1175 | |
1176 | /// Folds a select operation between an all-true and all-false vector. For now, |
1177 | /// only single element vectors (i.e., vector<1xi1>) are supported. That is: |
1178 | /// |
1179 | /// %true = arith.constant dense<true> : vector<1xi1> |
1180 | /// %false = arith.constant dense<false> : vector<1xi1> |
1181 | /// %result = arith.select %cond, %true, %false : i1, vector<1xi1> |
1182 | /// => |
1183 | /// %result = vector.broadcast %cond : i1 to vector<1xi1> |
1184 | /// |
1185 | /// InstCombine seems to handle vectors with multiple elements but not the |
1186 | /// single element ones. |
1187 | struct FoldI1Select : public OpRewritePattern<arith::SelectOp> { |
1188 | using OpRewritePattern<arith::SelectOp>::OpRewritePattern; |
1189 | |
1190 | LogicalResult matchAndRewrite(arith::SelectOp selectOp, |
1191 | PatternRewriter &rewriter) const override { |
1192 | auto vecType = dyn_cast<VectorType>(selectOp.getType()); |
1193 | if (!vecType || !vecType.getElementType().isInteger(1)) |
1194 | return failure(); |
1195 | |
1196 | // Only scalar conditions can be folded. |
1197 | Value cond = selectOp.getCondition(); |
1198 | if (isa<VectorType>(Val: cond.getType())) |
1199 | return failure(); |
1200 | |
1201 | // TODO: Support n-D and scalable vectors. |
1202 | if (vecType.getRank() != 1 || vecType.isScalable()) |
1203 | return failure(); |
1204 | |
1205 | // TODO: Support vectors with multiple elements. |
1206 | if (vecType.getShape()[0] != 1) |
1207 | return failure(); |
1208 | |
1209 | auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>(); |
1210 | if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true)) |
1211 | return failure(); |
1212 | |
1213 | auto falseConst = |
1214 | selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>(); |
1215 | if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false)) |
1216 | return failure(); |
1217 | |
1218 | // Replace select with its condition broadcasted to single element vector. |
1219 | auto elemType = rewriter.getIntegerType(vecType.getNumElements()); |
1220 | auto bcastType = VectorType::get(/*shape=*/{1}, elemType); |
1221 | rewriter.replaceOpWithNewOp<vector::BroadcastOp>(selectOp, bcastType, cond); |
1222 | return success(); |
1223 | } |
1224 | }; |
1225 | |
/// Returns the number of dims that can be folded away from transfer ops. It
/// returns a failure if it cannot determine the number of dims to be folded.
/// Example 1: it returns "2" if `srcType` is memref<512x16x1x1xf32> and
/// `vectorType` is vector<16x16x1x1xf32>, because the two innermost dims can
/// be dropped by memref.subview ops.
/// Example 2: it returns "1" if `srcType` is the same memref type with
/// [8192, 16, 8, 1] strides.
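/// Only one dim can be dropped in Example 2 because the second-to-last stride
/// is 8 rather than 1.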
1233 | static FailureOr<size_t> |
1234 | getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) { |
1235 | SmallVector<int64_t> srcStrides; |
1236 | int64_t srcOffset; |
1237 | if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset))) |
1238 | return failure(); |
1239 | |
1240 | // According to vector.transfer_read/write semantics, the vector can be a |
1241 | // slice. Thus, we have to offset the check index with `rankDiff` in |
1242 | // `srcStrides` and source dim sizes. |
1243 | size_t result = 0; |
1244 | int rankDiff = srcType.getRank() - vectorType.getRank(); |
1245 | for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) { |
1246 | // Check that the inner dim size is 1 for both memref type and vector slice. |
1247 | // It can be folded only if they are 1 and the stride is 1. |
1248 | int dim = vectorType.getRank() - i - 1; |
1249 | if (srcStrides[dim + rankDiff] != 1 || |
1250 | srcType.getDimSize(dim + rankDiff) != 1 || |
1251 | vectorType.getDimSize(dim) != 1) |
1252 | break; |
1253 | result++; |
1254 | } |
1255 | return result; |
1256 | } |
1257 | |
1258 | /// Drop inner most contiguous unit dimensions from transfer_read operand. |
1259 | class DropInnerMostUnitDimsTransferRead |
1260 | : public OpRewritePattern<vector::TransferReadOp> { |
1261 | using OpRewritePattern::OpRewritePattern; |
1262 | |
1263 | LogicalResult matchAndRewrite(vector::TransferReadOp readOp, |
1264 | PatternRewriter &rewriter) const override { |
1265 | // TODO: support 0-d corner case. |
1266 | if (readOp.getTransferRank() == 0) |
1267 | return failure(); |
1268 | |
1269 | // TODO: support mask. |
1270 | if (readOp.getMask()) |
1271 | return failure(); |
1272 | |
1273 | auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType()); |
1274 | if (!srcType) |
1275 | return failure(); |
1276 | |
1277 | if (!readOp.getPermutationMap().isMinorIdentity()) |
1278 | return failure(); |
1279 | |
1280 | auto targetType = readOp.getVectorType(); |
1281 | if (targetType.getRank() <= 1) |
1282 | return failure(); |
1283 | |
1284 | FailureOr<size_t> maybeDimsToDrop = |
1285 | getTransferFoldableInnerUnitDims(srcType, targetType); |
    if (failed(maybeDimsToDrop))
1287 | return failure(); |
1288 | |
1289 | size_t dimsToDrop = maybeDimsToDrop.value(); |
1290 | if (dimsToDrop == 0) |
1291 | return failure(); |
1292 | |
1293 | auto resultTargetVecType = |
1294 | VectorType::get(targetType.getShape().drop_back(dimsToDrop), |
1295 | targetType.getElementType()); |
1296 | |
1297 | auto loc = readOp.getLoc(); |
1298 | SmallVector<OpFoldResult> sizes = |
        memref::getMixedSizes(rewriter, loc, readOp.getSource());
1300 | SmallVector<OpFoldResult> offsets(srcType.getRank(), |
1301 | rewriter.getIndexAttr(0)); |
1302 | SmallVector<OpFoldResult> strides(srcType.getRank(), |
1303 | rewriter.getIndexAttr(1)); |
1304 | auto resultMemrefType = |
1305 | cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType( |
1306 | srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes, |
1307 | strides)); |
1308 | ArrayAttr inBoundsAttr = |
1309 | readOp.getInBounds() |
1310 | ? rewriter.getArrayAttr( |
1311 | readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop)) |
1312 | : ArrayAttr(); |
1313 | Value rankedReducedView = rewriter.create<memref::SubViewOp>( |
1314 | loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides); |
1315 | auto permMap = getTransferMinorIdentityMap( |
1316 | cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType); |
1317 | Value result = rewriter.create<vector::TransferReadOp>( |
1318 | loc, resultTargetVecType, rankedReducedView, |
1319 | readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), |
1320 | readOp.getPadding(), |
1321 | // TODO: support mask. |
1322 | /*mask=*/Value(), inBoundsAttr); |
1323 | rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType, |
1324 | result); |
1325 | return success(); |
1326 | } |
1327 | }; |
1328 | |
/// Drop innermost contiguous unit dimensions from the transfer_write operand.
1330 | /// E.g., |
1331 | /// vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0] |
1332 | /// {in_bounds = [true, true, true, true, true]} |
1333 | /// : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32> |
1334 | /// |
1335 | /// will be replaced with |
1336 | /// |
1337 | /// %subview = memref.subview %arg0 |
1338 | /// [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1] |
1339 | /// : memref<1x512x16x1x1xf32> to memref<1x512x16xf32> |
1340 | /// %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32> |
1341 | /// to vector<1x16x16xf32> |
1342 | /// vector.transfer_write %0, %subview[%c0, %arg2, %c0] |
1343 | /// {in_bounds = [true, true, true]} |
1344 | /// : vector<1x16x16xf32>, memref<1x512x16xf32> |
1345 | class DropInnerMostUnitDimsTransferWrite |
1346 | : public OpRewritePattern<vector::TransferWriteOp> { |
1347 | using OpRewritePattern::OpRewritePattern; |
1348 | |
1349 | LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, |
1350 | PatternRewriter &rewriter) const override { |
1351 | // TODO: support 0-d corner case. |
1352 | if (writeOp.getTransferRank() == 0) |
1353 | return failure(); |
1354 | |
1355 | // TODO: support mask. |
1356 | if (writeOp.getMask()) |
1357 | return failure(); |
1358 | |
1359 | auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType()); |
1360 | if (!srcType) |
1361 | return failure(); |
1362 | |
1363 | if (!writeOp.getPermutationMap().isMinorIdentity()) |
1364 | return failure(); |
1365 | |
1366 | auto targetType = writeOp.getVectorType(); |
1367 | if (targetType.getRank() <= 1) |
1368 | return failure(); |
1369 | |
1370 | FailureOr<size_t> maybeDimsToDrop = |
1371 | getTransferFoldableInnerUnitDims(srcType, targetType); |
    if (failed(maybeDimsToDrop))
1373 | return failure(); |
1374 | |
1375 | size_t dimsToDrop = maybeDimsToDrop.value(); |
1376 | if (dimsToDrop == 0) |
1377 | return failure(); |
1378 | |
1379 | auto resultTargetVecType = |
1380 | VectorType::get(targetType.getShape().drop_back(dimsToDrop), |
1381 | targetType.getElementType()); |
1382 | |
1383 | Location loc = writeOp.getLoc(); |
1384 | SmallVector<OpFoldResult> sizes = |
        memref::getMixedSizes(rewriter, loc, writeOp.getSource());
1386 | SmallVector<OpFoldResult> offsets(srcType.getRank(), |
1387 | rewriter.getIndexAttr(0)); |
1388 | SmallVector<OpFoldResult> strides(srcType.getRank(), |
1389 | rewriter.getIndexAttr(1)); |
1390 | auto resultMemrefType = |
1391 | cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType( |
1392 | srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes, |
1393 | strides)); |
1394 | ArrayAttr inBoundsAttr = |
1395 | writeOp.getInBounds() |
1396 | ? rewriter.getArrayAttr( |
1397 | writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop)) |
1398 | : ArrayAttr(); |
1399 | |
1400 | Value rankedReducedView = rewriter.create<memref::SubViewOp>( |
1401 | loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides); |
1402 | auto permMap = getTransferMinorIdentityMap( |
1403 | cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType); |
1404 | |
1405 | auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>( |
1406 | loc, resultTargetVecType, writeOp.getVector()); |
1407 | rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( |
1408 | writeOp, shapeCast, rankedReducedView, |
1409 | writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), |
1410 | // TODO: support mask. |
1411 | /*mask=*/Value(), inBoundsAttr); |
1412 | return success(); |
1413 | } |
1414 | }; |
1415 | |
1416 | /// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul |
/// semantics to a contraction suitable for MMT (matrix-matrix multiplication
1418 | /// with the RHS transposed) lowering. |
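/// For example (an illustrative sketch, not taken from a test): a contraction
/// whose indexing maps are
///
///   [(m, n, k) -> (m, k), (m, n, k) -> (k, n), (m, n, k) -> (m, n)]
///
/// (A row-major, B row-major, C row-major) is rewritten so that %b goes
/// through a vector.transpose and the maps become
///
///   [(m, n, k) -> (m, k), (m, n, k) -> (n, k), (m, n, k) -> (m, n)]
///
/// i.e. the "TNT" canonical form that MMT lowerings expect.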
1419 | struct CanonicalizeContractMatmulToMMT final |
1420 | : OpRewritePattern<vector::ContractionOp> { |
1421 | using OpRewritePattern::OpRewritePattern; |
1422 | |
1423 | using FilterConstraintType = |
1424 | std::function<LogicalResult(vector::ContractionOp op)>; |
1425 | |
1426 | CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit, |
1427 | FilterConstraintType constraint) |
1428 | : OpRewritePattern<vector::ContractionOp>(context, benefit), |
1429 | filter(std::move(constraint)) {} |
1430 | |
1431 | LogicalResult matchAndRewrite(vector::ContractionOp op, |
1432 | PatternRewriter &rewriter) const override { |
1433 | if (failed(filter(op))) |
1434 | return failure(); |
1435 | |
1436 | Location loc = op.getLoc(); |
1437 | Value lhs = op.getLhs(); |
1438 | Value rhs = op.getRhs(); |
1439 | Value res = op.getAcc(); |
1440 | |
1441 | // Set up the parallel/reduction structure in right form. |
1442 | using MapList = ArrayRef<ArrayRef<AffineExpr>>; |
1443 | auto infer = [&](MapList m) { |
1444 | return AffineMap::inferFromExprList(m, op.getContext()); |
1445 | }; |
1446 | AffineExpr m; |
1447 | AffineExpr n; |
1448 | AffineExpr k; |
    bindDims(rewriter.getContext(), m, n, k);
1450 | static constexpr std::array<int64_t, 2> perm = {1, 0}; |
1451 | auto iteratorTypes = op.getIteratorTypes().getValue(); |
1452 | SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray(); |
1453 | if (iteratorTypes.size() != 3 || |
        !vector::isParallelIterator(iteratorTypes[0]) ||
        !vector::isParallelIterator(iteratorTypes[1]) ||
        !vector::isReductionIterator(iteratorTypes[2]))
      return rewriter.notifyMatchFailure(op, "contraction is not a gemm");
1458 | |
1459 | // The canonical form is "TNT" = A row-major, B col-major, C row-major. |
1460 | const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}}); |
1461 | if (maps == canonicalForm) |
1462 | return rewriter.notifyMatchFailure(op, "already in the canonical form" ); |
1463 | |
1464 | // Create a vector transpose making sure to emit zero/sign-extend at the |
1465 | // end. |
1466 | auto createTranspose = [&rewriter, loc](Value mat) -> Value { |
1467 | if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) { |
1468 | Value trans = |
1469 | rewriter.create<vector::TransposeOp>(loc, sext.getIn(), perm); |
1470 | VectorType newType = |
1471 | cast<VectorType>(trans.getType()) |
1472 | .clone(cast<VectorType>(mat.getType()).getElementType()); |
1473 | return rewriter.create<arith::ExtSIOp>(loc, newType, trans); |
1474 | } |
1475 | if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) { |
1476 | Value trans = |
1477 | rewriter.create<vector::TransposeOp>(loc, zext.getIn(), perm); |
1478 | VectorType newType = |
1479 | VectorType::get(cast<VectorType>(trans.getType()).getShape(), |
1480 | cast<VectorType>(mat.getType()).getElementType()); |
1481 | return rewriter.create<arith::ExtUIOp>(loc, newType, trans); |
1482 | } |
1483 | return rewriter.create<vector::TransposeOp>(loc, mat, perm); |
1484 | }; |
1485 | |
1486 | if (maps == infer({{m, k}, {k, n}, {m, n}})) { |
1487 | rhs = createTranspose(rhs); |
1488 | } else if (maps == infer({{k, m}, {n, k}, {m, n}})) { |
1489 | lhs = createTranspose(lhs); |
1490 | } else if (maps == infer({{k, m}, {k, n}, {m, n}})) { |
1491 | rhs = createTranspose(rhs); |
1492 | lhs = createTranspose(lhs); |
1493 | } else if (maps == infer({{k, m}, {k, n}, {n, m}})) { |
      std::swap(rhs, lhs);
1495 | rhs = createTranspose(rhs); |
1496 | lhs = createTranspose(lhs); |
1497 | } else if (maps == infer({{k, m}, {n, k}, {n, m}})) { |
      std::swap(rhs, lhs);
1499 | rhs = createTranspose(rhs); |
1500 | } else if (maps == infer({{m, k}, {k, n}, {n, m}})) { |
      std::swap(lhs, rhs);
1502 | lhs = createTranspose(lhs); |
1503 | } else if (maps == infer({{m, k}, {n, k}, {n, m}})) { |
      std::swap(lhs, rhs);
1505 | } else { |
1506 | return rewriter.notifyMatchFailure(op, "unhandled contraction form" ); |
1507 | } |
1508 | rewriter.replaceOpWithNewOp<vector::ContractionOp>( |
        op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
1510 | op.getIteratorTypes()); |
1511 | return success(); |
1512 | }; |
1513 | |
1514 | private: |
1515 | FilterConstraintType filter; |
1516 | }; |
1517 | |
1518 | /// Pattern to fold arithmetic extensions on floating point data types into |
1519 | /// vector contraction operations. linalg.matmul introduces arithmetic |
/// extensions on its operands. See the mlir snippets below for more details.
1521 | /// ```mlir |
1522 | /// "linalg.matmul"(%lhs, %rhs, %acc) ({ |
1523 | /// ^bb0(%arg1: f16, %arg2: f16, %arg3: f32): |
1524 | /// %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32 |
1525 | /// %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32 |
1526 | /// %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32 |
1527 | /// %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32 |
1528 | /// "linalg.yield"(%acc) : (f32) -> () |
1529 | /// }) |
1530 | /// ``` |
1531 | /// This restricts the native usage of mixed precision NVIDIA Ampere Tensor |
/// Cores, i.e., `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
1533 | /// This pattern folds the arithmetic extensions into the vector contraction and |
1534 | /// enables the usage of native mixed precision Tensor Core instructions. |
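/// For example (an illustrative sketch; the shapes and elided attributes are
/// assumptions, not taken from a test),
///
///   %lhs_f32 = arith.extf %lhs : vector<16x16xf16> to vector<16x16xf32>
///   %rhs_f32 = arith.extf %rhs : vector<16x16xf16> to vector<16x16xf32>
///   %d = vector.contract {...} %lhs_f32, %rhs_f32, %acc
///     : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>
///
/// folds to
///
///   %d = vector.contract {...} %lhs, %rhs, %acc
///     : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32>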
1535 | struct FoldArithExtIntoContractionOp |
1536 | : public OpRewritePattern<vector::ContractionOp> { |
1537 | using OpRewritePattern::OpRewritePattern; |
1538 | |
1539 | LogicalResult matchAndRewrite(vector::ContractionOp contractOp, |
1540 | PatternRewriter &rewriter) const override { |
1541 | |
1542 | auto lhsDefOp = contractOp.getLhs().getDefiningOp<arith::ExtFOp>(); |
1543 | auto rhsDefOp = contractOp.getRhs().getDefiningOp<arith::ExtFOp>(); |
1544 | |
1545 | if (!lhsDefOp || !rhsDefOp) { |
1546 | return rewriter.notifyMatchFailure(contractOp, |
1547 | "no defining op on contract operands" ); |
1548 | } |
1549 | |
1550 | rewriter.replaceOpWithNewOp<vector::ContractionOp>( |
1551 | contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0), |
1552 | contractOp.getAcc(), contractOp.getIndexingMapsAttr(), |
1553 | contractOp.getIteratorTypesAttr()); |
1554 | |
1555 | return success(); |
1556 | } |
1557 | }; |
1558 | |
/// Pattern to fold a chained reduction into a series of vector additions and a
/// final reduction. This form should require fewer subgroup operations.
1561 | /// |
1562 | /// ```mlir |
1563 | /// %a = vector.reduction <add> %x, %acc |
1564 | /// %b = vector.reduction <add> %y, %a |
1565 | /// ==> |
1566 | /// %a = arith.addf %x, %y |
1567 | /// %b = vector.reduction <add> %a, %acc |
1568 | /// ``` |
1569 | struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> { |
1570 | using OpRewritePattern::OpRewritePattern; |
1571 | |
1572 | LogicalResult matchAndRewrite(vector::ReductionOp op, |
1573 | PatternRewriter &rewriter) const override { |
1574 | // TODO: Handle other combining kinds. |
1575 | if (op.getKind() != vector::CombiningKind::ADD) |
1576 | return failure(); |
1577 | |
1578 | // Accumulator is optional. |
1579 | Value acc = op.getAcc(); |
1580 | if (!acc) |
1581 | return failure(); |
1582 | |
1583 | if (!acc.getType().isIntOrFloat()) |
1584 | return failure(); |
1585 | |
1586 | auto parentReduction = acc.getDefiningOp<vector::ReductionOp>(); |
1587 | if (!parentReduction) |
1588 | return failure(); |
1589 | |
1590 | Location loc = op.getLoc(); |
1591 | Value vAdd; |
    if (isa<IntegerType>(acc.getType())) {
1593 | vAdd = rewriter.createOrFold<arith::AddIOp>( |
1594 | loc, parentReduction.getVector(), op.getVector()); |
1595 | } else { |
1596 | vAdd = rewriter.create<arith::AddFOp>(loc, parentReduction.getVector(), |
1597 | op.getVector()); |
1598 | } |
1599 | rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd, |
1600 | parentReduction.getAcc()); |
1601 | return success(); |
1602 | } |
1603 | }; |
1604 | |
1605 | /// For vectors with either leading or trailing unit dim, replaces: |
1606 | /// elementwise(a, b) |
1607 | /// with: |
1608 | /// sc_a = shape_cast(a) |
1609 | /// sc_b = shape_cast(b) |
1610 | /// res = elementwise(sc_a, sc_b) |
1611 | /// return shape_cast(res) |
/// The newly inserted shape_cast Ops drop the unit dim before the elementwise
/// Op and restore it afterwards. Vectors `a` and `b` are required to have
/// rank > 1.
1615 | /// |
1616 | /// Ex: |
1617 | /// ``` |
1618 | /// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32> |
1619 | /// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32> |
1620 | /// ``` |
1621 | /// |
1622 | /// gets converted to: |
1623 | /// |
1624 | /// ``` |
1625 | /// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32> |
1626 | /// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32> |
1627 | /// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32> |
1628 | /// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32> |
1629 | /// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32> |
1630 | /// ``` |
1631 | /// |
1632 | /// Patterns for folding shape_casts should instantly eliminate `%cast_new` and |
1633 | /// `%cast`. |
1634 | struct DropUnitDimFromElementwiseOps final |
1635 | : public OpTraitRewritePattern<OpTrait::Elementwise> { |
1636 | using OpTraitRewritePattern::OpTraitRewritePattern; |
1637 | LogicalResult matchAndRewrite(Operation *op, |
1638 | PatternRewriter &rewriter) const override { |
1639 | if (op->getNumResults() != 1 || op->getNumRegions() != 0) |
1640 | return failure(); |
1641 | |
    auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
1643 | if (!resultVectorType) |
1644 | return failure(); |
1645 | |
1646 | // Check the operand pre-conditions. For `Elementwise` ops all operands are |
1647 | // guaranteed to have identical shapes (with some exceptions such as |
1648 | // `arith.select`) and it suffices to only check one of them. |
    auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
1650 | if (!sourceVectorType) |
1651 | return failure(); |
1652 | if (sourceVectorType.getRank() < 2) |
1653 | return failure(); |
1654 | |
1655 | bool hasTrailingDimUnitFixed = |
1656 | ((sourceVectorType.getShape().back() == 1) && |
1657 | (!sourceVectorType.getScalableDims().back())); |
1658 | bool hasLeadingDimUnitFixed = |
1659 | ((sourceVectorType.getShape().front() == 1) && |
1660 | (!sourceVectorType.getScalableDims().front())); |
1661 | if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed) |
1662 | return failure(); |
1663 | |
1664 | // Drop leading/trailing unit dim by applying vector.shape_cast to all |
1665 | // operands |
1666 | int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1; |
1667 | |
1668 | SmallVector<Value> newOperands; |
1669 | auto loc = op->getLoc(); |
1670 | for (auto operand : op->getOperands()) { |
1671 | auto opVectorType = cast<VectorType>(operand.getType()); |
1672 | VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim); |
1673 | auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand); |
      newOperands.push_back(opSC);
1675 | } |
1676 | |
1677 | VectorType newResultVectorType = |
1678 | VectorType::Builder(resultVectorType).dropDim(dim); |
1679 | // Create an updated elementwise Op without leading/trailing unit dim |
1680 | Operation *elementwiseOp = |
1681 | rewriter.create(loc, op->getName().getIdentifier(), newOperands, |
1682 | newResultVectorType, op->getAttrs()); |
1683 | |
1684 | // Restore the leading/trailing unit dim by applying vector.shape_cast |
1685 | // to the result |
1686 | rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType, |
1687 | elementwiseOp->getResult(0)); |
1688 | |
1689 | return success(); |
1690 | } |
1691 | }; |
1692 | |
1693 | /// Pattern to eliminate redundant zero-constants added to reduction operands. |
1694 | /// It's enough for there to be one initial zero value, so we can eliminate the |
1695 | /// extra ones that feed into `vector.reduction <add>`. These get created by the |
1696 | /// `ChainedReduction` pattern. |
1697 | /// |
1698 | /// ```mlir |
1699 | /// %a = arith.addf %x, %zero |
1700 | /// %b = arith.addf %a, %y |
1701 | /// %c = vector.reduction <add> %b, %acc |
1702 | /// ==> |
1703 | /// %b = arith.addf %a, %y |
1704 | /// %c = vector.reduction <add> %b, %acc |
1705 | /// ``` |
1706 | struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> { |
1707 | using OpRewritePattern::OpRewritePattern; |
1708 | |
1709 | LogicalResult matchAndRewrite(vector::ReductionOp op, |
1710 | PatternRewriter &rewriter) const override { |
1711 | // TODO: Handle other reduction kinds and their identity values. |
1712 | if (op.getKind() != vector::CombiningKind::ADD) |
1713 | return failure(); |
1714 | |
1715 | Type elemType = op.getSourceVectorType().getElementType(); |
    // The integer case should be handled by `arith.addi` folders; only check
    // for floats here.
    if (!isa<FloatType>(elemType))
1719 | return failure(); |
1720 | |
1721 | auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>(); |
1722 | if (!vAdd) |
1723 | return failure(); |
1724 | auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>(); |
1725 | if (!addLhs) |
1726 | return failure(); |
1727 | |
1728 | if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat())) |
1729 | return failure(); |
1730 | |
1731 | auto newAdd = rewriter.create<arith::AddFOp>(vAdd.getLoc(), addLhs.getLhs(), |
1732 | vAdd.getRhs()); |
1733 | rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd, |
1734 | op.getAcc()); |
1735 | return success(); |
1736 | } |
1737 | }; |
1738 | |
1739 | /// Example: |
1740 | /// ``` |
1741 | /// %a = vector.reduction <add> %x : vector<2xf32> into f32 |
1742 | /// ``` |
1743 | /// is transformed into: |
1744 | /// ``` |
1745 | /// %y = vector.extract %x[0] : f32 from vector<2xf32> |
1746 | /// %z = vector.extract %x[1] : f32 from vector<2xf32> |
1747 | /// %a = arith.addf %y, %z : f32 |
1748 | /// ``` |
1749 | struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> { |
1750 | BreakDownVectorReduction(MLIRContext *context, |
                           unsigned maxNumElementsToExtract,
1752 | PatternBenefit benefit) |
1753 | : OpRewritePattern(context, benefit), |
1754 | maxNumElementsToExtract(maxNumElementsToExtract) {} |
1755 | |
1756 | LogicalResult matchAndRewrite(vector::ReductionOp op, |
1757 | PatternRewriter &rewriter) const override { |
1758 | VectorType type = op.getSourceVectorType(); |
1759 | if (type.isScalable() || op.isMasked()) |
1760 | return failure(); |
    assert(type.getRank() == 1 && "Expected a 1-d vector");
1762 | |
1763 | int64_t numElems = type.getNumElements(); |
1764 | if (numElems > maxNumElementsToExtract) { |
1765 | return rewriter.notifyMatchFailure( |
          op, llvm::formatv("has too many vector elements ({0}) to break down "
                            "(max allowed: {1})",
                            numElems, maxNumElementsToExtract));
1769 | } |
1770 | |
1771 | Location loc = op.getLoc(); |
    SmallVector<Value> extracted(numElems, nullptr);
1773 | for (auto [idx, extractedElem] : llvm::enumerate(extracted)) |
1774 | extractedElem = rewriter.create<vector::ExtractOp>( |
1775 | loc, op.getVector(), static_cast<int64_t>(idx)); |
1776 | |
1777 | Value res = extracted.front(); |
1778 | for (auto extractedElem : llvm::drop_begin(extracted)) |
1779 | res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, |
1780 | extractedElem, op.getFastmathAttr()); |
1781 | if (Value acc = op.getAcc()) |
1782 | res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc, |
1783 | op.getFastmathAttr()); |
1784 | |
1785 | rewriter.replaceOp(op, res); |
1786 | return success(); |
1787 | } |
1788 | |
1789 | private: |
  unsigned maxNumElementsToExtract = 0;
1791 | }; |
1792 | |
1793 | } // namespace |
1794 | |
1795 | void mlir::vector::populateFoldArithExtensionPatterns( |
1796 | RewritePatternSet &patterns) { |
  patterns.add<FoldArithExtIntoContractionOp>(patterns.getContext());
1798 | } |
1799 | |
1800 | void mlir::vector::populateVectorMaskMaterializationPatterns( |
1801 | RewritePatternSet &patterns, bool force32BitVectorIndices, |
1802 | PatternBenefit benefit) { |
1803 | patterns.add<VectorCreateMaskOpConversion, |
1804 | MaterializeTransferMask<vector::TransferReadOp>, |
1805 | MaterializeTransferMask<vector::TransferWriteOp>>( |
      patterns.getContext(), force32BitVectorIndices, benefit);
  patterns.add<FoldI1Select>(patterns.getContext(), benefit);
1808 | } |
1809 | |
1810 | void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns, |
1811 | PatternBenefit benefit) { |
  patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);
1813 | } |
1814 | |
1815 | void mlir::vector::populateDropUnitDimWithShapeCastPatterns( |
1816 | RewritePatternSet &patterns, PatternBenefit benefit) { |
1817 | patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>( |
      patterns.getContext(), benefit);
1819 | } |
1820 | |
1821 | void mlir::vector::populateBubbleVectorBitCastOpPatterns( |
1822 | RewritePatternSet &patterns, PatternBenefit benefit) { |
1823 | patterns.add<BubbleDownVectorBitCastForExtract, |
1824 | BubbleDownBitCastForStridedSliceExtract, |
1825 | BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>( |
      patterns.getContext(), benefit);
1827 | } |
1828 | |
1829 | void mlir::vector::populateBreakDownVectorBitCastOpPatterns( |
1830 | RewritePatternSet &patterns, |
1831 | std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) { |
1832 | patterns.add<BreakDownVectorBitCast>(patterns.getContext(), |
1833 | std::move(controlFn), benefit); |
1834 | } |
1835 | |
1836 | void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT( |
1837 | RewritePatternSet &patterns, |
1838 | std::function<LogicalResult(vector::ContractionOp)> constraint, |
1839 | PatternBenefit benefit) { |
  patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
                                                std::move(constraint));
1842 | } |
1843 | |
1844 | void mlir::vector::populateVectorReductionToContractPatterns( |
1845 | RewritePatternSet &patterns, PatternBenefit benefit) { |
1846 | patterns.add<MultiReduceToContract, CombineContractBroadcast, |
1847 | CombineContractABTranspose, CombineContractResultTranspose, |
1848 | ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose>( |
      patterns.getContext(), benefit);
1850 | } |
1851 | |
1852 | void mlir::vector:: |
1853 | populateVectorTransferCollapseInnerMostContiguousDimsPatterns( |
1854 | RewritePatternSet &patterns, PatternBenefit benefit) { |
1855 | patterns.add<DropInnerMostUnitDimsTransferRead, |
               DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
                                                   benefit);
1858 | } |
1859 | |
1860 | void mlir::vector::populateSinkVectorBroadcastPatterns( |
1861 | RewritePatternSet &patterns, PatternBenefit benefit) { |
1862 | patterns.add<ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnBroadcast>( |
      patterns.getContext(), benefit);
1864 | } |
1865 | |
1866 | void mlir::vector::populateChainedVectorReductionFoldingPatterns( |
1867 | RewritePatternSet &patterns, PatternBenefit benefit) { |
  patterns.add<ChainedReduction>(patterns.getContext(), benefit);
  patterns.add<ReduceRedundantZero>(patterns.getContext(),
                                    PatternBenefit(benefit.getBenefit() + 1));
1871 | } |
1872 | |
1873 | void mlir::vector::populateBreakDownVectorReductionPatterns( |
    RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
    PatternBenefit benefit) {
  patterns.add<BreakDownVectorReduction>(patterns.getContext(),
                                         maxNumElementsToExtract, benefit);
1878 | } |
1879 | |
1880 | //===----------------------------------------------------------------------===// |
1881 | // TableGen'd enum attribute definitions |
1882 | //===----------------------------------------------------------------------===// |
1883 | |
1884 | #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc" |
1885 | |