//===- LowerVectorMultiReduction.cpp - Lower `vector.multi_reduction` op --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.multi_reduction' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace vector {
#define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
} // namespace vector
} // namespace mlir

#define DEBUG_TYPE "vector-multi-reduction"

using namespace mlir;

namespace {
/// This file implements the following transformations as composable atomic
/// patterns.

/// Converts vector.multi_reduction into inner-most/outer-most reduction form
/// by using vector.transpose.
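///
/// For example, with the `InnerReduction` strategy, a reduction over a middle
/// dimension (the IR below is illustrative only: SSA names are invented and
/// the exact assembly syntax may differ across MLIR versions):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [1]
///          : vector<2x3x4xf32> to vector<2x4xf32>
///
/// is rewritten so that the reduced dimension becomes innermost:
///
///   %t = vector.transpose %src, [0, 2, 1]
///          : vector<2x3x4xf32> to vector<2x4x3xf32>
///   %0 = vector.multi_reduction <add>, %t, %acc [2]
///          : vector<2x4x3xf32> to vector<2x4xf32>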
class InnerOuterDimReductionConversion
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  explicit InnerOuterDimReductionConversion(
      MLIRContext *context, vector::VectorMultiReductionLowering options,
      PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }

    auto src = multiReductionOp.getSource();
    auto loc = multiReductionOp.getLoc();
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();

    // Separate reduction and parallel dims.
    ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
    llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
                                                  reductionDims.end());
    int64_t reductionSize = reductionDims.size();
    SmallVector<int64_t, 4> parallelDims;
    for (int64_t i = 0; i < srcRank; ++i)
      if (!reductionDimsSet.contains(i))
        parallelDims.push_back(i);

    // Add transpose only if inner-most/outer-most dimensions are not parallel
    // and there are parallel dims.
    if (parallelDims.empty())
      return failure();
    if (useInnerDimsForReduction &&
        (parallelDims ==
         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
      return failure();

    if (!useInnerDimsForReduction &&
        (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
                             reductionDims.size(),
                             parallelDims.size() + reductionDims.size()))))
      return failure();

    SmallVector<int64_t, 4> indices;
    if (useInnerDimsForReduction) {
      indices.append(parallelDims.begin(), parallelDims.end());
      indices.append(reductionDims.begin(), reductionDims.end());
    } else {
      indices.append(reductionDims.begin(), reductionDims.end());
      indices.append(parallelDims.begin(), parallelDims.end());
    }

    // If masked, transpose the original mask.
    Value transposedMask;
    if (maskableOp.isMasked()) {
      transposedMask = rewriter.create<vector::TransposeOp>(
          loc, maskableOp.getMaskingOp().getMask(), indices);
    }

    // Transpose reduction source.
    auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
    SmallVector<bool> reductionMask(srcRank, false);
    for (int i = 0; i < reductionSize; ++i) {
      if (useInnerDimsForReduction)
        reductionMask[srcRank - i - 1] = true;
      else
        reductionMask[i] = true;
    }

    Operation *newMultiRedOp = rewriter.create<vector::MultiDimReductionOp>(
        multiReductionOp.getLoc(), transposeOp.getResult(),
        multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
    newMultiRedOp =
        mlir::vector::maskOperation(rewriter, newMultiRedOp, transposedMask);

    rewriter.replaceOp(rootOp, newMultiRedOp->getResult(0));
    return success();
  }

private:
  const bool useInnerDimsForReduction;
};

/// Reduces the rank of vector.multi_reduction from n-D to 2-D, given that all
/// reduction dimensions are either innermost or outermost.
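///
/// For example, with the `InnerReduction` strategy (the IR below is
/// illustrative only: SSA names are invented and the exact assembly syntax may
/// differ across MLIR versions):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [2, 3]
///          : vector<2x3x4x5xf32> to vector<2x3xf32>
///
/// is flattened, via vector.shape_cast, into a rank-2 reduction:
///
///   %s = vector.shape_cast %src : vector<2x3x4x5xf32> to vector<6x20xf32>
///   %a = vector.shape_cast %acc : vector<2x3xf32> to vector<6xf32>
///   %r = vector.multi_reduction <add>, %s, %a [1]
///          : vector<6x20xf32> to vector<6xf32>
///   %0 = vector.shape_cast %r : vector<6xf32> to vector<2x3xf32>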
class ReduceMultiDimReductionRank
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  explicit ReduceMultiDimReductionRank(
      MLIRContext *context, vector::VectorMultiReductionLowering options,
      PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }

    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    auto srcShape = multiReductionOp.getSourceVectorType().getShape();
    auto srcScalableDims =
        multiReductionOp.getSourceVectorType().getScalableDims();
    auto loc = multiReductionOp.getLoc();

    // If rank less than 2, nothing to do.
    if (srcRank < 2)
      return failure();

    // Allow only 1 scalable dimension. Otherwise we could end up with, e.g.,
    // `vscale * vscale`, which is currently not modelled.
    if (llvm::count(srcScalableDims, true) > 1)
      return failure();

    // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"], bail.
    SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
    if (srcRank == 2 && reductionMask.front() != reductionMask.back())
      return failure();

    // 1. Separate reduction and parallel dims.
    SmallVector<int64_t, 4> parallelDims, parallelShapes;
    SmallVector<bool, 4> parallelScalableDims;
    SmallVector<int64_t, 4> reductionDims, reductionShapes;
    bool isReductionDimScalable = false;
    for (const auto &it : llvm::enumerate(reductionMask)) {
      int64_t i = it.index();
      bool isReduction = it.value();
      if (isReduction) {
        reductionDims.push_back(i);
        reductionShapes.push_back(srcShape[i]);
        isReductionDimScalable |= srcScalableDims[i];
      } else {
        parallelDims.push_back(i);
        parallelShapes.push_back(srcShape[i]);
        parallelScalableDims.push_back(srcScalableDims[i]);
      }
    }

    // 2. Compute flattened parallel and reduction sizes.
    int flattenedParallelDim = 0;
    int flattenedReductionDim = 0;
    if (!parallelShapes.empty()) {
      flattenedParallelDim = 1;
      for (auto d : parallelShapes)
        flattenedParallelDim *= d;
    }
    if (!reductionShapes.empty()) {
      flattenedReductionDim = 1;
      for (auto d : reductionShapes)
        flattenedReductionDim *= d;
    }
    // We must at least have some parallel or some reduction.
    assert((flattenedParallelDim || flattenedReductionDim) &&
           "expected at least one parallel or reduction dim");

    // 3. Fail if reduction/parallel dims are not contiguous.
    // Check parallelDims are exactly [0 .. size).
    int64_t counter = 0;
    if (useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();
    // Check parallelDims are exactly {reductionDims.size()} + [0 .. size).
    counter = reductionDims.size();
    if (!useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();

    // 4. Shape cast to collapse consecutive parallel (resp. reduction) dims
    // into a single parallel (resp. reduction) dim.
    SmallVector<bool, 2> mask;
    SmallVector<bool, 2> scalableDims;
    SmallVector<int64_t, 2> vectorShape;
    bool isParallelDimScalable = llvm::is_contained(parallelScalableDims, true);
    if (flattenedParallelDim) {
      mask.push_back(false);
      vectorShape.push_back(flattenedParallelDim);
      scalableDims.push_back(isParallelDimScalable);
    }
    if (flattenedReductionDim) {
      mask.push_back(true);
      vectorShape.push_back(flattenedReductionDim);
      scalableDims.push_back(isReductionDimScalable);
    }
    if (!useInnerDimsForReduction && vectorShape.size() == 2) {
      std::swap(mask.front(), mask.back());
      std::swap(vectorShape.front(), vectorShape.back());
      std::swap(scalableDims.front(), scalableDims.back());
    }

    Value newVectorMask;
    if (maskableOp.isMasked()) {
      Value vectorMask = maskableOp.getMaskingOp().getMask();
      auto maskCastedType = VectorType::get(
          vectorShape,
          llvm::cast<VectorType>(vectorMask.getType()).getElementType());
      newVectorMask =
          rewriter.create<vector::ShapeCastOp>(loc, maskCastedType, vectorMask);
    }

    auto castedType = VectorType::get(
        vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
        scalableDims);
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.getSource());

    // 5. Shape cast the accumulator to the flattened parallel shape, if any.
    Value acc = multiReductionOp.getAcc();
    if (flattenedParallelDim) {
      auto accType = VectorType::get(
          {flattenedParallelDim},
          multiReductionOp.getSourceVectorType().getElementType(),
          /*scalableDims=*/{isParallelDimScalable});
      acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
    }
    // 6. Create the flattened form of vector.multi_reduction with the
    // inner-most/outer-most dim as reduction.
    Operation *newMultiDimRedOp = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, acc, mask, multiReductionOp.getKind());
    newMultiDimRedOp =
        mlir::vector::maskOperation(rewriter, newMultiDimRedOp, newVectorMask);

    // 7. If there are no parallel shapes, the result is a scalar.
    // TODO: support 0-d vectors when available.
    if (parallelShapes.empty()) {
      rewriter.replaceOp(rootOp, newMultiDimRedOp->getResult(0));
      return success();
    }

    // 8. Create a shape cast to expand the flattened result back to the
    // original parallel shape.
    VectorType outputCastedType = VectorType::get(
        parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
        parallelScalableDims);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        rootOp, outputCastedType, newMultiDimRedOp->getResult(0));
    return success();
  }

private:
  const bool useInnerDimsForReduction;
};

/// Unrolls vector.multi_reduction with an outermost reduction dimension and
/// combines the partial results element-wise.
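///
/// For example (the IR below is illustrative only: SSA names are invented and
/// the exact assembly syntax may differ across MLIR versions):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [0]
///          : vector<2x8xf32> to vector<8xf32>
///
/// becomes, roughly, a chain of row extracts combined element-wise:
///
///   %r0 = vector.extract %src[0] : vector<8xf32> from vector<2x8xf32>
///   %a0 = arith.addf %r0, %acc : vector<8xf32>
///   %r1 = vector.extract %src[1] : vector<8xf32> from vector<2x8xf32>
///   %0  = arith.addf %r1, %a0 : vector<8xf32>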
struct TwoDimMultiReductionToElementWise
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-2 ["reduce", "parallel"] or bail.
    if (srcRank != 2)
      return failure();

    if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
      return failure();

    auto loc = multiReductionOp.getLoc();
    ArrayRef<int64_t> srcShape =
        multiReductionOp.getSourceVectorType().getShape();

    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
    if (!elementType.isIntOrIndexOrFloat())
      return failure();

    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    Value mask = nullptr;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = multiReductionOp;
    }

    Value result = multiReductionOp.getAcc();
    for (int64_t i = 0; i < srcShape[0]; i++) {
      auto operand = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getSource(), i);
      Value extractMask = nullptr;
      if (mask) {
        extractMask = rewriter.create<vector::ExtractOp>(loc, mask, i);
      }
      result =
          makeArithReduction(rewriter, loc, multiReductionOp.getKind(), operand,
                             result, /*fastmath=*/nullptr, extractMask);
    }

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};

/// Converts a 2-D vector.multi_reduction with the innermost reduction
/// dimension into a sequence of vector.reduction ops.
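///
/// For example (the IR below is illustrative only: SSA names are invented and
/// the exact assembly syntax may differ across MLIR versions):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [1]
///          : vector<2x8xf32> to vector<2xf32>
///
/// becomes, roughly, one vector.reduction per outer row:
///
///   %cst = arith.constant dense<0.0> : vector<2xf32>
///   %v0 = vector.extract %src[0] : vector<8xf32> from vector<2x8xf32>
///   %a0 = vector.extract %acc[0] : f32 from vector<2xf32>
///   %r0 = vector.reduction <add>, %v0, %a0 : vector<8xf32> into f32
///   %i0 = vector.insert %r0, %cst [0] : f32 into vector<2xf32>
///   ... (and similarly for row 1, inserting into %i0)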
struct TwoDimMultiReductionToReduction
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    if (srcRank != 2)
      return failure();

    if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
      return failure();

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }

    auto loc = multiReductionOp.getLoc();
    Value result = rewriter.create<arith::ConstantOp>(
        loc, multiReductionOp.getDestType(),
        rewriter.getZeroAttr(multiReductionOp.getDestType()));
    int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];

    for (int i = 0; i < outerDim; ++i) {
      auto v = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
      auto acc = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
      Operation *reductionOp = rewriter.create<vector::ReductionOp>(
          loc, multiReductionOp.getKind(), v, acc);

      // If masked, slice the mask and mask the new reduction operation.
      if (maskableOp.isMasked()) {
        Value mask = rewriter.create<vector::ExtractOp>(
            loc, maskableOp.getMaskingOp().getMask(), ArrayRef<int64_t>{i});
        reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
      }

      result = rewriter.create<vector::InsertOp>(loc, reductionOp->getResult(0),
                                                 result, i);
    }

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};

/// Converts a 1-D vector.multi_reduction with a single reduction dimension to
/// a 2-D form with one parallel and one reduction dimension.
/// This is achieved with a simple vector.shape_cast that inserts a leading 1.
/// The case with a single parallel dimension is a noop and folds away
/// separately.
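///
/// For example (the IR below is illustrative only: SSA names are invented and
/// the exact assembly syntax may differ across MLIR versions):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [0] : vector<8xf32> to f32
///
/// becomes, roughly:
///
///   %s = vector.shape_cast %src : vector<8xf32> to vector<1x8xf32>
///   %a = vector.broadcast %acc : f32 to vector<1xf32>
///   %r = vector.multi_reduction <add>, %s, %a [1]
///          : vector<1x8xf32> to vector<1xf32>
///   %0 = vector.extract %r[0] : f32 from vector<1xf32>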
struct OneDimMultiReductionToTwoDim
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-1 or bail.
    if (srcRank != 1)
      return failure();

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = multiReductionOp;
    }

    auto loc = multiReductionOp.getLoc();
    auto srcVectorType = multiReductionOp.getSourceVectorType();
    auto srcShape = srcVectorType.getShape();
    auto castedType = VectorType::get(
        ArrayRef<int64_t>{1, srcShape.back()}, srcVectorType.getElementType(),
        ArrayRef<bool>{false, srcVectorType.getScalableDims().back()});

    auto accType =
        VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
    assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
           "multi_reduction with a single dimension expects a scalar result");

    // Since the unique source dim is reduced and we insert a parallel dim in
    // front, we need a {false, true} reduction mask.
    SmallVector<bool, 2> reductionMask{false, true};

    /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.getSource());
    Value castAcc = rewriter.create<vector::BroadcastOp>(
        loc, accType, multiReductionOp.getAcc());
    Value castMask;
    if (maskableOp.isMasked()) {
      auto maskType = llvm::cast<VectorType>(mask.getType());
      auto castMaskType = VectorType::get(
          ArrayRef<int64_t>{1, maskType.getShape().back()},
          maskType.getElementType(),
          ArrayRef<bool>{false, maskType.getScalableDims().back()});
      castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask);
    }

    Operation *newOp = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, castAcc, reductionMask, multiReductionOp.getKind());
    newOp = vector::maskOperation(rewriter, newOp, castMask);

    rewriter.replaceOpWithNewOp<vector::ExtractOp>(rootOp, newOp->getResult(0),
                                                   ArrayRef<int64_t>{0});
    return success();
  }
};

struct LowerVectorMultiReductionPass
    : public vector::impl::LowerVectorMultiReductionBase<
          LowerVectorMultiReductionPass> {
  LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
    this->loweringStrategy = option;
  }

  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();

    RewritePatternSet loweringPatterns(context);
    populateVectorMultiReductionLoweringPatterns(loweringPatterns,
                                                 this->loweringStrategy);

    if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
      signalPassFailure();
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect>();
  }
};

} // namespace

void mlir::vector::populateVectorMultiReductionLoweringPatterns(
    RewritePatternSet &patterns, VectorMultiReductionLowering options,
    PatternBenefit benefit) {
  patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
      patterns.getContext(), options, benefit);
  patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext(), benefit);
  if (options == VectorMultiReductionLowering::InnerReduction)
    patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
                                                  benefit);
  else
    patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
                                                    benefit);
}
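
// Note: the patterns registered above are typically driven through a greedy
// rewrite, as done in LowerVectorMultiReductionPass::runOnOperation() above.
// A minimal client-side sketch (assuming `ctx` and a root `op` are available):
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorMultiReductionLoweringPatterns(
//       patterns, vector::VectorMultiReductionLowering::InnerReduction);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     return failure();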

std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
    vector::VectorMultiReductionLowering option) {
  return std::make_unique<LowerVectorMultiReductionPass>(option);
}