//===- LowerVectorMultiReduction.cpp - Lower `vector.multi_reduction` op --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.multi_reduction' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace vector {
#define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
} // namespace vector
} // namespace mlir

#define DEBUG_TYPE "vector-multi-reduction"

using namespace mlir;
32
namespace {
/// This file implements the following transformations as composable atomic
/// patterns.

37/// Converts vector.multi_reduction into inner-most/outer-most reduction form
38/// by using vector.transpose
39class InnerOuterDimReductionConversion
40 : public OpRewritePattern<vector::MultiDimReductionOp> {
41public:
42 using OpRewritePattern::OpRewritePattern;
43
44 explicit InnerOuterDimReductionConversion(
45 MLIRContext *context, vector::VectorMultiReductionLowering options,
46 PatternBenefit benefit = 1)
47 : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
48 useInnerDimsForReduction(
49 options == vector::VectorMultiReductionLowering::InnerReduction) {}
50
51 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
52 PatternRewriter &rewriter) const override {
53 // Vector mask setup.
54 OpBuilder::InsertionGuard guard(rewriter);
55 auto maskableOp =
56 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
57 Operation *rootOp;
58 if (maskableOp.isMasked()) {
59 rewriter.setInsertionPoint(maskableOp.getMaskingOp());
60 rootOp = maskableOp.getMaskingOp();
61 } else {
62 rootOp = multiReductionOp;
63 }
64
65 auto src = multiReductionOp.getSource();
66 auto loc = multiReductionOp.getLoc();
67 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
68
69 // Separate reduction and parallel dims
70 ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
71 llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
72 reductionDims.end());
73 int64_t reductionSize = reductionDims.size();
74 SmallVector<int64_t, 4> parallelDims;
75 for (int64_t i = 0; i < srcRank; ++i)
76 if (!reductionDimsSet.contains(i))
77 parallelDims.push_back(i);
78
79 // Add transpose only if inner-most/outer-most dimensions are not parallel
80 // and there are parallel dims.
81 if (parallelDims.empty())
82 return failure();
83 if (useInnerDimsForReduction &&
84 (parallelDims ==
85 llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
86 return failure();
87
88 if (!useInnerDimsForReduction &&
89 (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
90 reductionDims.size(),
91 parallelDims.size() + reductionDims.size()))))
92 return failure();
93
94 SmallVector<int64_t, 4> indices;
95 if (useInnerDimsForReduction) {
96 indices.append(parallelDims.begin(), parallelDims.end());
97 indices.append(reductionDims.begin(), reductionDims.end());
98 } else {
99 indices.append(reductionDims.begin(), reductionDims.end());
100 indices.append(parallelDims.begin(), parallelDims.end());
101 }
102
103 // If masked, transpose the original mask.
104 Value transposedMask;
105 if (maskableOp.isMasked()) {
106 transposedMask = rewriter.create<vector::TransposeOp>(
107 loc, maskableOp.getMaskingOp().getMask(), indices);
108 }
109
110 // Transpose reduction source.
111 auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
112 SmallVector<bool> reductionMask(srcRank, false);
113 for (int i = 0; i < reductionSize; ++i) {
114 if (useInnerDimsForReduction)
115 reductionMask[srcRank - i - 1] = true;
116 else
117 reductionMask[i] = true;
118 }
119
120 Operation *newMultiRedOp = rewriter.create<vector::MultiDimReductionOp>(
121 multiReductionOp.getLoc(), transposeOp.getResult(),
122 multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
123 newMultiRedOp =
124 mlir::vector::maskOperation(builder&: rewriter, maskableOp: newMultiRedOp, mask: transposedMask);
125
126 rewriter.replaceOp(op: rootOp, newValues: newMultiRedOp->getResult(idx: 0));
127 return success();
128 }
129
130private:
131 const bool useInnerDimsForReduction;
132};
133
134/// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
135/// dimensions are either inner most or outer most.
136class ReduceMultiDimReductionRank
137 : public OpRewritePattern<vector::MultiDimReductionOp> {
138public:
139 using OpRewritePattern::OpRewritePattern;
140
141 explicit ReduceMultiDimReductionRank(
142 MLIRContext *context, vector::VectorMultiReductionLowering options,
143 PatternBenefit benefit = 1)
144 : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
145 useInnerDimsForReduction(
146 options == vector::VectorMultiReductionLowering::InnerReduction) {}
147
148 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
149 PatternRewriter &rewriter) const override {
150 // Vector mask setup.
151 OpBuilder::InsertionGuard guard(rewriter);
152 auto maskableOp =
153 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
154 Operation *rootOp;
155 if (maskableOp.isMasked()) {
156 rewriter.setInsertionPoint(maskableOp.getMaskingOp());
157 rootOp = maskableOp.getMaskingOp();
158 } else {
159 rootOp = multiReductionOp;
160 }
161
162 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
163 auto srcShape = multiReductionOp.getSourceVectorType().getShape();
164 auto srcScalableDims =
165 multiReductionOp.getSourceVectorType().getScalableDims();
166 auto loc = multiReductionOp.getLoc();
167
168 // If rank less than 2, nothing to do.
169 if (srcRank < 2)
170 return failure();
171
172 // Allow only 1 scalable dimensions. Otherwise we could end-up with e.g.
173 // `vscale * vscale` that's currently not modelled.
174 if (llvm::count(srcScalableDims, true) > 1)
175 return failure();
176
177 // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
178 SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
179 if (srcRank == 2 && reductionMask.front() != reductionMask.back())
180 return failure();
181
182 // 1. Separate reduction and parallel dims.
183 SmallVector<int64_t, 4> parallelDims, parallelShapes;
184 SmallVector<bool, 4> parallelScalableDims;
185 SmallVector<int64_t, 4> reductionDims, reductionShapes;
186 bool isReductionDimScalable = false;
187 for (const auto &it : llvm::enumerate(reductionMask)) {
188 int64_t i = it.index();
189 bool isReduction = it.value();
190 if (isReduction) {
191 reductionDims.push_back(i);
192 reductionShapes.push_back(srcShape[i]);
193 isReductionDimScalable |= srcScalableDims[i];
194 } else {
195 parallelDims.push_back(i);
196 parallelShapes.push_back(srcShape[i]);
197 parallelScalableDims.push_back(srcScalableDims[i]);
198 }
199 }
200
201 // 2. Compute flattened parallel and reduction sizes.
202 int flattenedParallelDim = 0;
203 int flattenedReductionDim = 0;
204 if (!parallelShapes.empty()) {
205 flattenedParallelDim = 1;
206 for (auto d : parallelShapes)
207 flattenedParallelDim *= d;
208 }
209 if (!reductionShapes.empty()) {
210 flattenedReductionDim = 1;
211 for (auto d : reductionShapes)
212 flattenedReductionDim *= d;
213 }
214 // We must at least have some parallel or some reduction.
215 assert((flattenedParallelDim || flattenedReductionDim) &&
216 "expected at least one parallel or reduction dim");
217
218 // 3. Fail if reduction/parallel dims are not contiguous.
219 // Check parallelDims are exactly [0 .. size).
220 int64_t counter = 0;
221 if (useInnerDimsForReduction &&
222 llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
223 return failure();
224 // Check parallelDims are exactly {reductionDims.size()} + [0 .. size).
225 counter = reductionDims.size();
226 if (!useInnerDimsForReduction &&
227 llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
228 return failure();
229
230 // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into
231 // a single parallel (resp. reduction) dim.
232 SmallVector<bool, 2> mask;
233 SmallVector<bool, 2> scalableDims;
234 SmallVector<int64_t, 2> vectorShape;
235 bool isParallelDimScalable = llvm::is_contained(parallelScalableDims, true);
236 if (flattenedParallelDim) {
237 mask.push_back(false);
238 vectorShape.push_back(Elt: flattenedParallelDim);
239 scalableDims.push_back(isParallelDimScalable);
240 }
241 if (flattenedReductionDim) {
242 mask.push_back(true);
243 vectorShape.push_back(Elt: flattenedReductionDim);
244 scalableDims.push_back(isReductionDimScalable);
245 }
246 if (!useInnerDimsForReduction && vectorShape.size() == 2) {
247 std::swap(mask.front(), mask.back());
248 std::swap(vectorShape.front(), vectorShape.back());
249 std::swap(scalableDims.front(), scalableDims.back());
250 }
251
252 Value newVectorMask;
253 if (maskableOp.isMasked()) {
254 Value vectorMask = maskableOp.getMaskingOp().getMask();
255 auto maskCastedType = VectorType::get(
256 vectorShape,
257 llvm::cast<VectorType>(vectorMask.getType()).getElementType());
258 newVectorMask =
259 rewriter.create<vector::ShapeCastOp>(loc, maskCastedType, vectorMask);
260 }
261
262 auto castedType = VectorType::get(
263 vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
264 scalableDims);
265 Value cast = rewriter.create<vector::ShapeCastOp>(
266 loc, castedType, multiReductionOp.getSource());
267
268 Value acc = multiReductionOp.getAcc();
269 if (flattenedParallelDim) {
270 auto accType = VectorType::get(
271 {flattenedParallelDim},
272 multiReductionOp.getSourceVectorType().getElementType(),
273 /*scalableDims=*/{isParallelDimScalable});
274 acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
275 }
276 // 6. Creates the flattened form of vector.multi_reduction with inner/outer
277 // most dim as reduction.
278 Operation *newMultiDimRedOp = rewriter.create<vector::MultiDimReductionOp>(
279 loc, cast, acc, mask, multiReductionOp.getKind());
280 newMultiDimRedOp =
281 mlir::vector::maskOperation(builder&: rewriter, maskableOp: newMultiDimRedOp, mask: newVectorMask);
282
283 // 7. If there are no parallel shapes, the result is a scalar.
284 // TODO: support 0-d vectors when available.
285 if (parallelShapes.empty()) {
286 rewriter.replaceOp(op: rootOp, newValues: newMultiDimRedOp->getResult(idx: 0));
287 return success();
288 }
289
290 // 8. Creates shape cast for the output n-D -> 2-D.
291 VectorType outputCastedType = VectorType::get(
292 parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
293 parallelScalableDims);
294 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
295 rootOp, outputCastedType, newMultiDimRedOp->getResult(0));
296 return success();
297 }
298
299private:
300 const bool useInnerDimsForReduction;
301};
302
303/// Unrolls vector.multi_reduction with outermost reductions
304/// and combines results
305struct TwoDimMultiReductionToElementWise
306 : public OpRewritePattern<vector::MultiDimReductionOp> {
307 using OpRewritePattern::OpRewritePattern;
308
309 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
310 PatternRewriter &rewriter) const override {
311 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
312 // Rank-2 ["parallel", "reduce"] or bail.
313 if (srcRank != 2)
314 return failure();
315
316 if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
317 return failure();
318
319 auto loc = multiReductionOp.getLoc();
320 ArrayRef<int64_t> srcShape =
321 multiReductionOp.getSourceVectorType().getShape();
322
323 Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
324 if (!elementType.isIntOrIndexOrFloat())
325 return failure();
326
327 OpBuilder::InsertionGuard guard(rewriter);
328 auto maskableOp =
329 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
330 Operation *rootOp;
331 Value mask = nullptr;
332 if (maskableOp.isMasked()) {
333 rewriter.setInsertionPoint(maskableOp.getMaskingOp());
334 rootOp = maskableOp.getMaskingOp();
335 mask = maskableOp.getMaskingOp().getMask();
336 } else {
337 rootOp = multiReductionOp;
338 }
339
340 Value result = multiReductionOp.getAcc();
341 for (int64_t i = 0; i < srcShape[0]; i++) {
342 auto operand = rewriter.create<vector::ExtractOp>(
343 loc, multiReductionOp.getSource(), i);
344 Value extractMask = nullptr;
345 if (mask) {
346 extractMask = rewriter.create<vector::ExtractOp>(loc, mask, i);
347 }
348 result =
349 makeArithReduction(rewriter, loc, multiReductionOp.getKind(), operand,
350 result, /*fastmath=*/nullptr, extractMask);
351 }
352
353 rewriter.replaceOp(op: rootOp, newValues: result);
354 return success();
355 }
356};
357
358/// Converts 2d vector.multi_reduction with inner most reduction dimension into
359/// a sequence of vector.reduction ops.
360struct TwoDimMultiReductionToReduction
361 : public OpRewritePattern<vector::MultiDimReductionOp> {
362 using OpRewritePattern::OpRewritePattern;
363
364 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
365 PatternRewriter &rewriter) const override {
366 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
367 if (srcRank != 2)
368 return failure();
369
370 if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
371 return failure();
372
373 // Vector mask setup.
374 OpBuilder::InsertionGuard guard(rewriter);
375 auto maskableOp =
376 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
377 Operation *rootOp;
378 if (maskableOp.isMasked()) {
379 rewriter.setInsertionPoint(maskableOp.getMaskingOp());
380 rootOp = maskableOp.getMaskingOp();
381 } else {
382 rootOp = multiReductionOp;
383 }
384
385 auto loc = multiReductionOp.getLoc();
386 Value result = rewriter.create<arith::ConstantOp>(
387 loc, multiReductionOp.getDestType(),
388 rewriter.getZeroAttr(multiReductionOp.getDestType()));
389 int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
390
391 for (int i = 0; i < outerDim; ++i) {
392 auto v = rewriter.create<vector::ExtractOp>(
393 loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
394 auto acc = rewriter.create<vector::ExtractOp>(
395 loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
396 Operation *reductionOp = rewriter.create<vector::ReductionOp>(
397 loc, multiReductionOp.getKind(), v, acc);
398
399 // If masked, slice the mask and mask the new reduction operation.
400 if (maskableOp.isMasked()) {
401 Value mask = rewriter.create<vector::ExtractOp>(
402 loc, maskableOp.getMaskingOp().getMask(), ArrayRef<int64_t>{i});
403 reductionOp = mlir::vector::maskOperation(builder&: rewriter, maskableOp: reductionOp, mask);
404 }
405
406 result = rewriter.create<vector::InsertOp>(loc, reductionOp->getResult(0),
407 result, i);
408 }
409
410 rewriter.replaceOp(op: rootOp, newValues: result);
411 return success();
412 }
413};
414
415/// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
416/// form with both a single parallel and reduction dimension.
417/// This is achieved with a simple vector.shape_cast that inserts a leading 1.
418/// The case with a single parallel dimension is a noop and folds away
419/// separately.
420struct OneDimMultiReductionToTwoDim
421 : public OpRewritePattern<vector::MultiDimReductionOp> {
422 using OpRewritePattern::OpRewritePattern;
423
424 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
425 PatternRewriter &rewriter) const override {
426 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
427 // Rank-1 or bail.
428 if (srcRank != 1)
429 return failure();
430
431 // Vector mask setup.
432 OpBuilder::InsertionGuard guard(rewriter);
433 auto maskableOp =
434 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
435 Operation *rootOp;
436 Value mask;
437 if (maskableOp.isMasked()) {
438 rewriter.setInsertionPoint(maskableOp.getMaskingOp());
439 rootOp = maskableOp.getMaskingOp();
440 mask = maskableOp.getMaskingOp().getMask();
441 } else {
442 rootOp = multiReductionOp;
443 }
444
445 auto loc = multiReductionOp.getLoc();
446 auto srcVectorType = multiReductionOp.getSourceVectorType();
447 auto srcShape = srcVectorType.getShape();
448 auto castedType = VectorType::get(
449 ArrayRef<int64_t>{1, srcShape.back()}, srcVectorType.getElementType(),
450 ArrayRef<bool>{false, srcVectorType.getScalableDims().back()});
451
452 auto accType =
453 VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
454 assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
455 "multi_reduction with a single dimension expects a scalar result");
456
457 // If the unique dim is reduced and we insert a parallel in front, we need a
458 // {false, true} mask.
459 SmallVector<bool, 2> reductionMask{false, true};
460
461 /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
462 Value cast = rewriter.create<vector::ShapeCastOp>(
463 loc, castedType, multiReductionOp.getSource());
464 Value castAcc = rewriter.create<vector::BroadcastOp>(
465 loc, accType, multiReductionOp.getAcc());
466 Value castMask;
467 if (maskableOp.isMasked()) {
468 auto maskType = llvm::cast<VectorType>(mask.getType());
469 auto castMaskType = VectorType::get(
470 ArrayRef<int64_t>{1, maskType.getShape().back()},
471 maskType.getElementType(),
472 ArrayRef<bool>{false, maskType.getScalableDims().back()});
473 castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask);
474 }
475
476 Operation *newOp = rewriter.create<vector::MultiDimReductionOp>(
477 loc, cast, castAcc, reductionMask, multiReductionOp.getKind());
478 newOp = vector::maskOperation(builder&: rewriter, maskableOp: newOp, mask: castMask);
479
480 rewriter.replaceOpWithNewOp<vector::ExtractOp>(rootOp, newOp->getResult(0),
481 ArrayRef<int64_t>{0});
482 return success();
483 }
484};
485
486struct LowerVectorMultiReductionPass
487 : public vector::impl::LowerVectorMultiReductionBase<
488 LowerVectorMultiReductionPass> {
489 LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
490 this->loweringStrategy = option;
491 }
492
493 void runOnOperation() override {
494 Operation *op = getOperation();
495 MLIRContext *context = op->getContext();
496
497 RewritePatternSet loweringPatterns(context);
498 populateVectorMultiReductionLoweringPatterns(loweringPatterns,
499 this->loweringStrategy);
500
501 if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
502 signalPassFailure();
503 }
504
505 void getDependentDialects(DialectRegistry &registry) const override {
506 registry.insert<vector::VectorDialect>();
507 }
508};

} // namespace

512void mlir::vector::populateVectorMultiReductionLoweringPatterns(
513 RewritePatternSet &patterns, VectorMultiReductionLowering options,
514 PatternBenefit benefit) {
515 patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
516 patterns.getContext(), options, benefit);
517 patterns.add<OneDimMultiReductionToTwoDim>(arg: patterns.getContext(), args&: benefit);
518 if (options == VectorMultiReductionLowering ::InnerReduction)
519 patterns.add<TwoDimMultiReductionToReduction>(arg: patterns.getContext(),
520 args&: benefit);
521 else
522 patterns.add<TwoDimMultiReductionToElementWise>(arg: patterns.getContext(),
523 args&: benefit);
524}
526std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
527 vector::VectorMultiReductionLowering option) {
528 return std::make_unique<LowerVectorMultiReductionPass>(option);
529}