1 | //===- LowerVectorMultiReduction.cpp - Lower `vector.multi_reduction` op --===// |
2 | // |
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements target-independent rewrites and utilities to lower the |
10 | // 'vector.multi_reduction' operation. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
15 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
16 | #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" |
17 | #include "mlir/Dialect/Vector/Transforms/Passes.h" |
18 | #include "mlir/IR/Builders.h" |
19 | #include "mlir/IR/TypeUtilities.h" |
20 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
21 | |
22 | namespace mlir { |
23 | namespace vector { |
24 | #define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION |
25 | #include "mlir/Dialect/Vector/Transforms/Passes.h.inc" |
26 | } // namespace vector |
27 | } // namespace mlir |
28 | |
29 | #define DEBUG_TYPE "vector-multi-reduction" |
30 | |
31 | using namespace mlir; |
32 | |
33 | namespace { |
34 | /// This file implements the following transformations as composable atomic |
35 | /// patterns. |
36 | |
37 | /// Converts vector.multi_reduction into inner-most/outer-most reduction form |
38 | /// by using vector.transpose |
39 | class InnerOuterDimReductionConversion |
40 | : public OpRewritePattern<vector::MultiDimReductionOp> { |
41 | public: |
42 | using OpRewritePattern::OpRewritePattern; |
43 | |
44 | explicit InnerOuterDimReductionConversion( |
45 | MLIRContext *context, vector::VectorMultiReductionLowering options, |
46 | PatternBenefit benefit = 1) |
47 | : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit), |
48 | useInnerDimsForReduction( |
49 | options == vector::VectorMultiReductionLowering::InnerReduction) {} |
50 | |
51 | LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, |
52 | PatternRewriter &rewriter) const override { |
53 | // Vector mask setup. |
54 | OpBuilder::InsertionGuard guard(rewriter); |
55 | auto maskableOp = |
56 | cast<vector::MaskableOpInterface>(multiReductionOp.getOperation()); |
57 | Operation *rootOp; |
58 | if (maskableOp.isMasked()) { |
59 | rewriter.setInsertionPoint(maskableOp.getMaskingOp()); |
60 | rootOp = maskableOp.getMaskingOp(); |
61 | } else { |
62 | rootOp = multiReductionOp; |
63 | } |
64 | |
65 | auto src = multiReductionOp.getSource(); |
66 | auto loc = multiReductionOp.getLoc(); |
67 | auto srcRank = multiReductionOp.getSourceVectorType().getRank(); |
68 | |
69 | // Separate reduction and parallel dims |
70 | auto reductionDimsRange = |
71 | multiReductionOp.getReductionDims().getAsValueRange<IntegerAttr>(); |
72 | auto reductionDims = llvm::to_vector<4>(llvm::map_range( |
73 | reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); })); |
74 | llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(), |
75 | reductionDims.end()); |
76 | int64_t reductionSize = reductionDims.size(); |
77 | SmallVector<int64_t, 4> parallelDims; |
78 | for (int64_t i = 0; i < srcRank; ++i) |
79 | if (!reductionDimsSet.contains(i)) |
80 | parallelDims.push_back(i); |
81 | |
82 | // Add transpose only if inner-most/outer-most dimensions are not parallel |
83 | // and there are parallel dims. |
84 | if (parallelDims.empty()) |
85 | return failure(); |
86 | if (useInnerDimsForReduction && |
87 | (parallelDims == |
88 | llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size())))) |
89 | return failure(); |
90 | |
91 | if (!useInnerDimsForReduction && |
92 | (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>( |
93 | reductionDims.size(), |
94 | parallelDims.size() + reductionDims.size())))) |
95 | return failure(); |
96 | |
97 | SmallVector<int64_t, 4> indices; |
98 | if (useInnerDimsForReduction) { |
99 | indices.append(parallelDims.begin(), parallelDims.end()); |
100 | indices.append(reductionDims.begin(), reductionDims.end()); |
101 | } else { |
102 | indices.append(reductionDims.begin(), reductionDims.end()); |
103 | indices.append(parallelDims.begin(), parallelDims.end()); |
104 | } |
105 | |
106 | // If masked, transpose the original mask. |
107 | Value transposedMask; |
108 | if (maskableOp.isMasked()) { |
109 | transposedMask = rewriter.create<vector::TransposeOp>( |
110 | loc, maskableOp.getMaskingOp().getMask(), indices); |
111 | } |
112 | |
113 | // Transpose reduction source. |
114 | auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices); |
115 | SmallVector<bool> reductionMask(srcRank, false); |
116 | for (int i = 0; i < reductionSize; ++i) { |
117 | if (useInnerDimsForReduction) |
118 | reductionMask[srcRank - i - 1] = true; |
119 | else |
120 | reductionMask[i] = true; |
121 | } |
122 | |
123 | Operation *newMultiRedOp = rewriter.create<vector::MultiDimReductionOp>( |
124 | multiReductionOp.getLoc(), transposeOp.getResult(), |
125 | multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind()); |
126 | newMultiRedOp = |
127 | mlir::vector::maskOperation(builder&: rewriter, maskableOp: newMultiRedOp, mask: transposedMask); |
128 | |
129 | rewriter.replaceOp(op: rootOp, newValues: newMultiRedOp->getResult(idx: 0)); |
130 | return success(); |
131 | } |
132 | |
133 | private: |
134 | const bool useInnerDimsForReduction; |
135 | }; |
136 | |
137 | /// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction |
138 | /// dimensions are either inner most or outer most. |
139 | class ReduceMultiDimReductionRank |
140 | : public OpRewritePattern<vector::MultiDimReductionOp> { |
141 | public: |
142 | using OpRewritePattern::OpRewritePattern; |
143 | |
144 | explicit ReduceMultiDimReductionRank( |
145 | MLIRContext *context, vector::VectorMultiReductionLowering options, |
146 | PatternBenefit benefit = 1) |
147 | : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit), |
148 | useInnerDimsForReduction( |
149 | options == vector::VectorMultiReductionLowering::InnerReduction) {} |
150 | |
151 | LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, |
152 | PatternRewriter &rewriter) const override { |
153 | // Vector mask setup. |
154 | OpBuilder::InsertionGuard guard(rewriter); |
155 | auto maskableOp = |
156 | cast<vector::MaskableOpInterface>(multiReductionOp.getOperation()); |
157 | Operation *rootOp; |
158 | if (maskableOp.isMasked()) { |
159 | rewriter.setInsertionPoint(maskableOp.getMaskingOp()); |
160 | rootOp = maskableOp.getMaskingOp(); |
161 | } else { |
162 | rootOp = multiReductionOp; |
163 | } |
164 | |
165 | auto srcRank = multiReductionOp.getSourceVectorType().getRank(); |
166 | auto srcShape = multiReductionOp.getSourceVectorType().getShape(); |
167 | auto srcScalableDims = |
168 | multiReductionOp.getSourceVectorType().getScalableDims(); |
169 | auto loc = multiReductionOp.getLoc(); |
170 | |
171 | // If rank less than 2, nothing to do. |
172 | if (srcRank < 2) |
173 | return failure(); |
174 | |
175 | // Allow only 1 scalable dimensions. Otherwise we could end-up with e.g. |
176 | // `vscale * vscale` that's currently not modelled. |
177 | if (llvm::count(srcScalableDims, true) > 1) |
178 | return failure(); |
179 | |
180 | // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail. |
181 | SmallVector<bool> reductionMask = multiReductionOp.getReductionMask(); |
182 | if (srcRank == 2 && reductionMask.front() != reductionMask.back()) |
183 | return failure(); |
184 | |
185 | // 1. Separate reduction and parallel dims. |
186 | SmallVector<int64_t, 4> parallelDims, parallelShapes; |
187 | SmallVector<bool, 4> parallelScalableDims; |
188 | SmallVector<int64_t, 4> reductionDims, reductionShapes; |
189 | bool isReductionDimScalable = false; |
190 | for (const auto &it : llvm::enumerate(reductionMask)) { |
191 | int64_t i = it.index(); |
192 | bool isReduction = it.value(); |
193 | if (isReduction) { |
194 | reductionDims.push_back(i); |
195 | reductionShapes.push_back(srcShape[i]); |
196 | isReductionDimScalable |= srcScalableDims[i]; |
197 | } else { |
198 | parallelDims.push_back(i); |
199 | parallelShapes.push_back(srcShape[i]); |
200 | parallelScalableDims.push_back(srcScalableDims[i]); |
201 | } |
202 | } |
203 | |
204 | // 2. Compute flattened parallel and reduction sizes. |
205 | int flattenedParallelDim = 0; |
206 | int flattenedReductionDim = 0; |
207 | if (!parallelShapes.empty()) { |
208 | flattenedParallelDim = 1; |
209 | for (auto d : parallelShapes) |
210 | flattenedParallelDim *= d; |
211 | } |
212 | if (!reductionShapes.empty()) { |
213 | flattenedReductionDim = 1; |
214 | for (auto d : reductionShapes) |
215 | flattenedReductionDim *= d; |
216 | } |
217 | // We must at least have some parallel or some reduction. |
218 | assert((flattenedParallelDim || flattenedReductionDim) && |
219 | "expected at least one parallel or reduction dim" ); |
220 | |
221 | // 3. Fail if reduction/parallel dims are not contiguous. |
222 | // Check parallelDims are exactly [0 .. size). |
223 | int64_t counter = 0; |
224 | if (useInnerDimsForReduction && |
225 | llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; })) |
226 | return failure(); |
227 | // Check parallelDims are exactly {reductionDims.size()} + [0 .. size). |
228 | counter = reductionDims.size(); |
229 | if (!useInnerDimsForReduction && |
230 | llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; })) |
231 | return failure(); |
232 | |
233 | // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into |
234 | // a single parallel (resp. reduction) dim. |
235 | SmallVector<bool, 2> mask; |
236 | SmallVector<bool, 2> scalableDims; |
237 | SmallVector<int64_t, 2> vectorShape; |
238 | bool isParallelDimScalable = llvm::is_contained(parallelScalableDims, true); |
239 | if (flattenedParallelDim) { |
240 | mask.push_back(false); |
241 | vectorShape.push_back(Elt: flattenedParallelDim); |
242 | scalableDims.push_back(isParallelDimScalable); |
243 | } |
244 | if (flattenedReductionDim) { |
245 | mask.push_back(true); |
246 | vectorShape.push_back(Elt: flattenedReductionDim); |
247 | scalableDims.push_back(isReductionDimScalable); |
248 | } |
249 | if (!useInnerDimsForReduction && vectorShape.size() == 2) { |
250 | std::swap(mask.front(), mask.back()); |
251 | std::swap(vectorShape.front(), vectorShape.back()); |
252 | std::swap(scalableDims.front(), scalableDims.back()); |
253 | } |
254 | |
255 | Value newVectorMask; |
256 | if (maskableOp.isMasked()) { |
257 | Value vectorMask = maskableOp.getMaskingOp().getMask(); |
258 | auto maskCastedType = VectorType::get( |
259 | vectorShape, |
260 | llvm::cast<VectorType>(vectorMask.getType()).getElementType()); |
261 | newVectorMask = |
262 | rewriter.create<vector::ShapeCastOp>(loc, maskCastedType, vectorMask); |
263 | } |
264 | |
265 | auto castedType = VectorType::get( |
266 | vectorShape, multiReductionOp.getSourceVectorType().getElementType(), |
267 | scalableDims); |
268 | Value cast = rewriter.create<vector::ShapeCastOp>( |
269 | loc, castedType, multiReductionOp.getSource()); |
270 | |
271 | Value acc = multiReductionOp.getAcc(); |
272 | if (flattenedParallelDim) { |
273 | auto accType = VectorType::get( |
274 | {flattenedParallelDim}, |
275 | multiReductionOp.getSourceVectorType().getElementType(), |
276 | /*scalableDims=*/{isParallelDimScalable}); |
277 | acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc); |
278 | } |
279 | // 6. Creates the flattened form of vector.multi_reduction with inner/outer |
280 | // most dim as reduction. |
281 | Operation *newMultiDimRedOp = rewriter.create<vector::MultiDimReductionOp>( |
282 | loc, cast, acc, mask, multiReductionOp.getKind()); |
283 | newMultiDimRedOp = |
284 | mlir::vector::maskOperation(builder&: rewriter, maskableOp: newMultiDimRedOp, mask: newVectorMask); |
285 | |
286 | // 7. If there are no parallel shapes, the result is a scalar. |
287 | // TODO: support 0-d vectors when available. |
288 | if (parallelShapes.empty()) { |
289 | rewriter.replaceOp(op: rootOp, newValues: newMultiDimRedOp->getResult(idx: 0)); |
290 | return success(); |
291 | } |
292 | |
293 | // 8. Creates shape cast for the output n-D -> 2-D. |
294 | VectorType outputCastedType = VectorType::get( |
295 | parallelShapes, multiReductionOp.getSourceVectorType().getElementType(), |
296 | parallelScalableDims); |
297 | rewriter.replaceOpWithNewOp<vector::ShapeCastOp>( |
298 | rootOp, outputCastedType, newMultiDimRedOp->getResult(0)); |
299 | return success(); |
300 | } |
301 | |
302 | private: |
303 | const bool useInnerDimsForReduction; |
304 | }; |
305 | |
306 | /// Unrolls vector.multi_reduction with outermost reductions |
307 | /// and combines results |
308 | struct TwoDimMultiReductionToElementWise |
309 | : public OpRewritePattern<vector::MultiDimReductionOp> { |
310 | using OpRewritePattern::OpRewritePattern; |
311 | |
312 | LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, |
313 | PatternRewriter &rewriter) const override { |
314 | auto maskableOp = |
315 | cast<vector::MaskableOpInterface>(multiReductionOp.getOperation()); |
316 | if (maskableOp.isMasked()) |
317 | // TODO: Support masking. |
318 | return failure(); |
319 | |
320 | auto srcRank = multiReductionOp.getSourceVectorType().getRank(); |
321 | // Rank-2 ["parallel", "reduce"] or bail. |
322 | if (srcRank != 2) |
323 | return failure(); |
324 | |
325 | if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0)) |
326 | return failure(); |
327 | |
328 | auto loc = multiReductionOp.getLoc(); |
329 | ArrayRef<int64_t> srcShape = |
330 | multiReductionOp.getSourceVectorType().getShape(); |
331 | |
332 | Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType()); |
333 | if (!elementType.isIntOrIndexOrFloat()) |
334 | return failure(); |
335 | |
336 | Value result = multiReductionOp.getAcc(); |
337 | for (int64_t i = 0; i < srcShape[0]; i++) { |
338 | auto operand = rewriter.create<vector::ExtractOp>( |
339 | loc, multiReductionOp.getSource(), i); |
340 | result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(), |
341 | operand, result); |
342 | } |
343 | |
344 | rewriter.replaceOp(multiReductionOp, result); |
345 | return success(); |
346 | } |
347 | }; |
348 | |
349 | /// Converts 2d vector.multi_reduction with inner most reduction dimension into |
350 | /// a sequence of vector.reduction ops. |
351 | struct TwoDimMultiReductionToReduction |
352 | : public OpRewritePattern<vector::MultiDimReductionOp> { |
353 | using OpRewritePattern::OpRewritePattern; |
354 | |
355 | LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, |
356 | PatternRewriter &rewriter) const override { |
357 | auto srcRank = multiReductionOp.getSourceVectorType().getRank(); |
358 | if (srcRank != 2) |
359 | return failure(); |
360 | |
361 | if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1)) |
362 | return failure(); |
363 | |
364 | // Vector mask setup. |
365 | OpBuilder::InsertionGuard guard(rewriter); |
366 | auto maskableOp = |
367 | cast<vector::MaskableOpInterface>(multiReductionOp.getOperation()); |
368 | Operation *rootOp; |
369 | if (maskableOp.isMasked()) { |
370 | rewriter.setInsertionPoint(maskableOp.getMaskingOp()); |
371 | rootOp = maskableOp.getMaskingOp(); |
372 | } else { |
373 | rootOp = multiReductionOp; |
374 | } |
375 | |
376 | auto loc = multiReductionOp.getLoc(); |
377 | Value result = rewriter.create<arith::ConstantOp>( |
378 | loc, multiReductionOp.getDestType(), |
379 | rewriter.getZeroAttr(multiReductionOp.getDestType())); |
380 | int outerDim = multiReductionOp.getSourceVectorType().getShape()[0]; |
381 | |
382 | for (int i = 0; i < outerDim; ++i) { |
383 | auto v = rewriter.create<vector::ExtractOp>( |
384 | loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i}); |
385 | auto acc = rewriter.create<vector::ExtractOp>( |
386 | loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i}); |
387 | Operation *reductionOp = rewriter.create<vector::ReductionOp>( |
388 | loc, multiReductionOp.getKind(), v, acc); |
389 | |
390 | // If masked, slice the mask and mask the new reduction operation. |
391 | if (maskableOp.isMasked()) { |
392 | Value mask = rewriter.create<vector::ExtractOp>( |
393 | loc, maskableOp.getMaskingOp().getMask(), ArrayRef<int64_t>{i}); |
394 | reductionOp = mlir::vector::maskOperation(builder&: rewriter, maskableOp: reductionOp, mask); |
395 | } |
396 | |
397 | result = rewriter.create<vector::InsertElementOp>( |
398 | loc, reductionOp->getResult(0), result, |
399 | rewriter.create<arith::ConstantIndexOp>(loc, i)); |
400 | } |
401 | |
402 | rewriter.replaceOp(op: rootOp, newValues: result); |
403 | return success(); |
404 | } |
405 | }; |
406 | |
407 | /// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d |
408 | /// form with both a single parallel and reduction dimension. |
409 | /// This is achieved with a simple vector.shape_cast that inserts a leading 1. |
410 | /// The case with a single parallel dimension is a noop and folds away |
411 | /// separately. |
412 | struct OneDimMultiReductionToTwoDim |
413 | : public OpRewritePattern<vector::MultiDimReductionOp> { |
414 | using OpRewritePattern::OpRewritePattern; |
415 | |
416 | LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, |
417 | PatternRewriter &rewriter) const override { |
418 | auto srcRank = multiReductionOp.getSourceVectorType().getRank(); |
419 | // Rank-1 or bail. |
420 | if (srcRank != 1) |
421 | return failure(); |
422 | |
423 | // Vector mask setup. |
424 | OpBuilder::InsertionGuard guard(rewriter); |
425 | auto maskableOp = |
426 | cast<vector::MaskableOpInterface>(multiReductionOp.getOperation()); |
427 | Operation *rootOp; |
428 | Value mask; |
429 | if (maskableOp.isMasked()) { |
430 | rewriter.setInsertionPoint(maskableOp.getMaskingOp()); |
431 | rootOp = maskableOp.getMaskingOp(); |
432 | mask = maskableOp.getMaskingOp().getMask(); |
433 | } else { |
434 | rootOp = multiReductionOp; |
435 | } |
436 | |
437 | auto loc = multiReductionOp.getLoc(); |
438 | auto srcVectorType = multiReductionOp.getSourceVectorType(); |
439 | auto srcShape = srcVectorType.getShape(); |
440 | auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()}, |
441 | srcVectorType.getElementType()); |
442 | auto accType = |
443 | VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType()); |
444 | assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) && |
445 | "multi_reduction with a single dimension expects a scalar result" ); |
446 | |
447 | // If the unique dim is reduced and we insert a parallel in front, we need a |
448 | // {false, true} mask. |
449 | SmallVector<bool, 2> reductionMask{false, true}; |
450 | |
451 | /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0) |
452 | Value cast = rewriter.create<vector::ShapeCastOp>( |
453 | loc, castedType, multiReductionOp.getSource()); |
454 | Value castAcc = rewriter.create<vector::BroadcastOp>( |
455 | loc, accType, multiReductionOp.getAcc()); |
456 | Value castMask; |
457 | if (maskableOp.isMasked()) { |
458 | auto maskType = llvm::cast<ShapedType>(mask.getType()); |
459 | auto castMaskType = |
460 | VectorType::get(ArrayRef<int64_t>{1, maskType.getShape().back()}, |
461 | maskType.getElementType()); |
462 | castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask); |
463 | } |
464 | |
465 | Operation *newOp = rewriter.create<vector::MultiDimReductionOp>( |
466 | loc, cast, castAcc, reductionMask, multiReductionOp.getKind()); |
467 | newOp = vector::maskOperation(builder&: rewriter, maskableOp: newOp, mask: castMask); |
468 | |
469 | rewriter.replaceOpWithNewOp<vector::ExtractOp>(rootOp, newOp->getResult(0), |
470 | ArrayRef<int64_t>{0}); |
471 | return success(); |
472 | } |
473 | }; |
474 | |
475 | struct LowerVectorMultiReductionPass |
476 | : public vector::impl::LowerVectorMultiReductionBase< |
477 | LowerVectorMultiReductionPass> { |
478 | LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) { |
479 | this->loweringStrategy = option; |
480 | } |
481 | |
482 | void runOnOperation() override { |
483 | Operation *op = getOperation(); |
484 | MLIRContext *context = op->getContext(); |
485 | |
486 | RewritePatternSet loweringPatterns(context); |
487 | populateVectorMultiReductionLoweringPatterns(loweringPatterns, |
488 | this->loweringStrategy); |
489 | |
490 | if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns)))) |
491 | signalPassFailure(); |
492 | } |
493 | |
494 | void getDependentDialects(DialectRegistry ®istry) const override { |
495 | registry.insert<vector::VectorDialect>(); |
496 | } |
497 | }; |
498 | |
499 | } // namespace |
500 | |
501 | void mlir::vector::populateVectorMultiReductionLoweringPatterns( |
502 | RewritePatternSet &patterns, VectorMultiReductionLowering options, |
503 | PatternBenefit benefit) { |
504 | patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>( |
505 | patterns.getContext(), options, benefit); |
506 | patterns.add<OneDimMultiReductionToTwoDim>(arg: patterns.getContext(), args&: benefit); |
507 | if (options == VectorMultiReductionLowering ::InnerReduction) |
508 | patterns.add<TwoDimMultiReductionToReduction>(arg: patterns.getContext(), |
509 | args&: benefit); |
510 | else |
511 | patterns.add<TwoDimMultiReductionToElementWise>(arg: patterns.getContext(), |
512 | args&: benefit); |
513 | } |
514 | |
/// Creates the multi_reduction lowering pass; `option` selects whether
/// reductions are rewritten to target inner-most or outer-most dimensions.
std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
    vector::VectorMultiReductionLowering option) {
  return std::make_unique<LowerVectorMultiReductionPass>(option);
}
519 | |