1//===- LowerVectorMultiReduction.cpp - Lower `vector.multi_reduction` op --===//
2//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent rewrites and utilities to lower the
10// 'vector.multi_reduction' operation.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
17#include "mlir/Dialect/Vector/Transforms/Passes.h"
18#include "mlir/IR/Builders.h"
19#include "mlir/IR/TypeUtilities.h"
20#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21
22namespace mlir {
23namespace vector {
24#define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
25#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
26} // namespace vector
27} // namespace mlir
28
29#define DEBUG_TYPE "vector-multi-reduction"
30
31using namespace mlir;
32
33namespace {
34/// This file implements the following transformations as composable atomic
35/// patterns.
36
37/// Converts vector.multi_reduction into inner-most/outer-most reduction form
38/// by using vector.transpose
39class InnerOuterDimReductionConversion
40 : public OpRewritePattern<vector::MultiDimReductionOp> {
41public:
42 using OpRewritePattern::OpRewritePattern;
43
44 explicit InnerOuterDimReductionConversion(
45 MLIRContext *context, vector::VectorMultiReductionLowering options,
46 PatternBenefit benefit = 1)
47 : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
48 useInnerDimsForReduction(
49 options == vector::VectorMultiReductionLowering::InnerReduction) {}
50
51 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
52 PatternRewriter &rewriter) const override {
53 // Vector mask setup.
54 OpBuilder::InsertionGuard guard(rewriter);
55 auto maskableOp =
56 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
57 Operation *rootOp;
58 if (maskableOp.isMasked()) {
59 rewriter.setInsertionPoint(maskableOp.getMaskingOp());
60 rootOp = maskableOp.getMaskingOp();
61 } else {
62 rootOp = multiReductionOp;
63 }
64
65 auto src = multiReductionOp.getSource();
66 auto loc = multiReductionOp.getLoc();
67 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
68
69 // Separate reduction and parallel dims
70 auto reductionDimsRange =
71 multiReductionOp.getReductionDims().getAsValueRange<IntegerAttr>();
72 auto reductionDims = llvm::to_vector<4>(llvm::map_range(
73 reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); }));
74 llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
75 reductionDims.end());
76 int64_t reductionSize = reductionDims.size();
77 SmallVector<int64_t, 4> parallelDims;
78 for (int64_t i = 0; i < srcRank; ++i)
79 if (!reductionDimsSet.contains(i))
80 parallelDims.push_back(i);
81
82 // Add transpose only if inner-most/outer-most dimensions are not parallel
83 // and there are parallel dims.
84 if (parallelDims.empty())
85 return failure();
86 if (useInnerDimsForReduction &&
87 (parallelDims ==
88 llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
89 return failure();
90
91 if (!useInnerDimsForReduction &&
92 (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
93 reductionDims.size(),
94 parallelDims.size() + reductionDims.size()))))
95 return failure();
96
97 SmallVector<int64_t, 4> indices;
98 if (useInnerDimsForReduction) {
99 indices.append(parallelDims.begin(), parallelDims.end());
100 indices.append(reductionDims.begin(), reductionDims.end());
101 } else {
102 indices.append(reductionDims.begin(), reductionDims.end());
103 indices.append(parallelDims.begin(), parallelDims.end());
104 }
105
106 // If masked, transpose the original mask.
107 Value transposedMask;
108 if (maskableOp.isMasked()) {
109 transposedMask = rewriter.create<vector::TransposeOp>(
110 loc, maskableOp.getMaskingOp().getMask(), indices);
111 }
112
113 // Transpose reduction source.
114 auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
115 SmallVector<bool> reductionMask(srcRank, false);
116 for (int i = 0; i < reductionSize; ++i) {
117 if (useInnerDimsForReduction)
118 reductionMask[srcRank - i - 1] = true;
119 else
120 reductionMask[i] = true;
121 }
122
123 Operation *newMultiRedOp = rewriter.create<vector::MultiDimReductionOp>(
124 multiReductionOp.getLoc(), transposeOp.getResult(),
125 multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
126 newMultiRedOp =
127 mlir::vector::maskOperation(builder&: rewriter, maskableOp: newMultiRedOp, mask: transposedMask);
128
129 rewriter.replaceOp(op: rootOp, newValues: newMultiRedOp->getResult(idx: 0));
130 return success();
131 }
132
133private:
134 const bool useInnerDimsForReduction;
135};
136
137/// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
138/// dimensions are either inner most or outer most.
139class ReduceMultiDimReductionRank
140 : public OpRewritePattern<vector::MultiDimReductionOp> {
141public:
142 using OpRewritePattern::OpRewritePattern;
143
144 explicit ReduceMultiDimReductionRank(
145 MLIRContext *context, vector::VectorMultiReductionLowering options,
146 PatternBenefit benefit = 1)
147 : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
148 useInnerDimsForReduction(
149 options == vector::VectorMultiReductionLowering::InnerReduction) {}
150
151 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
152 PatternRewriter &rewriter) const override {
153 // Vector mask setup.
154 OpBuilder::InsertionGuard guard(rewriter);
155 auto maskableOp =
156 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
157 Operation *rootOp;
158 if (maskableOp.isMasked()) {
159 rewriter.setInsertionPoint(maskableOp.getMaskingOp());
160 rootOp = maskableOp.getMaskingOp();
161 } else {
162 rootOp = multiReductionOp;
163 }
164
165 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
166 auto srcShape = multiReductionOp.getSourceVectorType().getShape();
167 auto srcScalableDims =
168 multiReductionOp.getSourceVectorType().getScalableDims();
169 auto loc = multiReductionOp.getLoc();
170
171 // If rank less than 2, nothing to do.
172 if (srcRank < 2)
173 return failure();
174
175 // Allow only 1 scalable dimensions. Otherwise we could end-up with e.g.
176 // `vscale * vscale` that's currently not modelled.
177 if (llvm::count(srcScalableDims, true) > 1)
178 return failure();
179
180 // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
181 SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
182 if (srcRank == 2 && reductionMask.front() != reductionMask.back())
183 return failure();
184
185 // 1. Separate reduction and parallel dims.
186 SmallVector<int64_t, 4> parallelDims, parallelShapes;
187 SmallVector<bool, 4> parallelScalableDims;
188 SmallVector<int64_t, 4> reductionDims, reductionShapes;
189 bool isReductionDimScalable = false;
190 for (const auto &it : llvm::enumerate(reductionMask)) {
191 int64_t i = it.index();
192 bool isReduction = it.value();
193 if (isReduction) {
194 reductionDims.push_back(i);
195 reductionShapes.push_back(srcShape[i]);
196 isReductionDimScalable |= srcScalableDims[i];
197 } else {
198 parallelDims.push_back(i);
199 parallelShapes.push_back(srcShape[i]);
200 parallelScalableDims.push_back(srcScalableDims[i]);
201 }
202 }
203
204 // 2. Compute flattened parallel and reduction sizes.
205 int flattenedParallelDim = 0;
206 int flattenedReductionDim = 0;
207 if (!parallelShapes.empty()) {
208 flattenedParallelDim = 1;
209 for (auto d : parallelShapes)
210 flattenedParallelDim *= d;
211 }
212 if (!reductionShapes.empty()) {
213 flattenedReductionDim = 1;
214 for (auto d : reductionShapes)
215 flattenedReductionDim *= d;
216 }
217 // We must at least have some parallel or some reduction.
218 assert((flattenedParallelDim || flattenedReductionDim) &&
219 "expected at least one parallel or reduction dim");
220
221 // 3. Fail if reduction/parallel dims are not contiguous.
222 // Check parallelDims are exactly [0 .. size).
223 int64_t counter = 0;
224 if (useInnerDimsForReduction &&
225 llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
226 return failure();
227 // Check parallelDims are exactly {reductionDims.size()} + [0 .. size).
228 counter = reductionDims.size();
229 if (!useInnerDimsForReduction &&
230 llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
231 return failure();
232
233 // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into
234 // a single parallel (resp. reduction) dim.
235 SmallVector<bool, 2> mask;
236 SmallVector<bool, 2> scalableDims;
237 SmallVector<int64_t, 2> vectorShape;
238 bool isParallelDimScalable = llvm::is_contained(parallelScalableDims, true);
239 if (flattenedParallelDim) {
240 mask.push_back(false);
241 vectorShape.push_back(Elt: flattenedParallelDim);
242 scalableDims.push_back(isParallelDimScalable);
243 }
244 if (flattenedReductionDim) {
245 mask.push_back(true);
246 vectorShape.push_back(Elt: flattenedReductionDim);
247 scalableDims.push_back(isReductionDimScalable);
248 }
249 if (!useInnerDimsForReduction && vectorShape.size() == 2) {
250 std::swap(mask.front(), mask.back());
251 std::swap(vectorShape.front(), vectorShape.back());
252 std::swap(scalableDims.front(), scalableDims.back());
253 }
254
255 Value newVectorMask;
256 if (maskableOp.isMasked()) {
257 Value vectorMask = maskableOp.getMaskingOp().getMask();
258 auto maskCastedType = VectorType::get(
259 vectorShape,
260 llvm::cast<VectorType>(vectorMask.getType()).getElementType());
261 newVectorMask =
262 rewriter.create<vector::ShapeCastOp>(loc, maskCastedType, vectorMask);
263 }
264
265 auto castedType = VectorType::get(
266 vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
267 scalableDims);
268 Value cast = rewriter.create<vector::ShapeCastOp>(
269 loc, castedType, multiReductionOp.getSource());
270
271 Value acc = multiReductionOp.getAcc();
272 if (flattenedParallelDim) {
273 auto accType = VectorType::get(
274 {flattenedParallelDim},
275 multiReductionOp.getSourceVectorType().getElementType(),
276 /*scalableDims=*/{isParallelDimScalable});
277 acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
278 }
279 // 6. Creates the flattened form of vector.multi_reduction with inner/outer
280 // most dim as reduction.
281 Operation *newMultiDimRedOp = rewriter.create<vector::MultiDimReductionOp>(
282 loc, cast, acc, mask, multiReductionOp.getKind());
283 newMultiDimRedOp =
284 mlir::vector::maskOperation(builder&: rewriter, maskableOp: newMultiDimRedOp, mask: newVectorMask);
285
286 // 7. If there are no parallel shapes, the result is a scalar.
287 // TODO: support 0-d vectors when available.
288 if (parallelShapes.empty()) {
289 rewriter.replaceOp(op: rootOp, newValues: newMultiDimRedOp->getResult(idx: 0));
290 return success();
291 }
292
293 // 8. Creates shape cast for the output n-D -> 2-D.
294 VectorType outputCastedType = VectorType::get(
295 parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
296 parallelScalableDims);
297 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
298 rootOp, outputCastedType, newMultiDimRedOp->getResult(0));
299 return success();
300 }
301
302private:
303 const bool useInnerDimsForReduction;
304};
305
306/// Unrolls vector.multi_reduction with outermost reductions
307/// and combines results
308struct TwoDimMultiReductionToElementWise
309 : public OpRewritePattern<vector::MultiDimReductionOp> {
310 using OpRewritePattern::OpRewritePattern;
311
312 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
313 PatternRewriter &rewriter) const override {
314 auto maskableOp =
315 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
316 if (maskableOp.isMasked())
317 // TODO: Support masking.
318 return failure();
319
320 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
321 // Rank-2 ["parallel", "reduce"] or bail.
322 if (srcRank != 2)
323 return failure();
324
325 if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
326 return failure();
327
328 auto loc = multiReductionOp.getLoc();
329 ArrayRef<int64_t> srcShape =
330 multiReductionOp.getSourceVectorType().getShape();
331
332 Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
333 if (!elementType.isIntOrIndexOrFloat())
334 return failure();
335
336 Value result = multiReductionOp.getAcc();
337 for (int64_t i = 0; i < srcShape[0]; i++) {
338 auto operand = rewriter.create<vector::ExtractOp>(
339 loc, multiReductionOp.getSource(), i);
340 result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
341 operand, result);
342 }
343
344 rewriter.replaceOp(multiReductionOp, result);
345 return success();
346 }
347};
348
349/// Converts 2d vector.multi_reduction with inner most reduction dimension into
350/// a sequence of vector.reduction ops.
351struct TwoDimMultiReductionToReduction
352 : public OpRewritePattern<vector::MultiDimReductionOp> {
353 using OpRewritePattern::OpRewritePattern;
354
355 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
356 PatternRewriter &rewriter) const override {
357 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
358 if (srcRank != 2)
359 return failure();
360
361 if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
362 return failure();
363
364 // Vector mask setup.
365 OpBuilder::InsertionGuard guard(rewriter);
366 auto maskableOp =
367 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
368 Operation *rootOp;
369 if (maskableOp.isMasked()) {
370 rewriter.setInsertionPoint(maskableOp.getMaskingOp());
371 rootOp = maskableOp.getMaskingOp();
372 } else {
373 rootOp = multiReductionOp;
374 }
375
376 auto loc = multiReductionOp.getLoc();
377 Value result = rewriter.create<arith::ConstantOp>(
378 loc, multiReductionOp.getDestType(),
379 rewriter.getZeroAttr(multiReductionOp.getDestType()));
380 int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
381
382 for (int i = 0; i < outerDim; ++i) {
383 auto v = rewriter.create<vector::ExtractOp>(
384 loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
385 auto acc = rewriter.create<vector::ExtractOp>(
386 loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
387 Operation *reductionOp = rewriter.create<vector::ReductionOp>(
388 loc, multiReductionOp.getKind(), v, acc);
389
390 // If masked, slice the mask and mask the new reduction operation.
391 if (maskableOp.isMasked()) {
392 Value mask = rewriter.create<vector::ExtractOp>(
393 loc, maskableOp.getMaskingOp().getMask(), ArrayRef<int64_t>{i});
394 reductionOp = mlir::vector::maskOperation(builder&: rewriter, maskableOp: reductionOp, mask);
395 }
396
397 result = rewriter.create<vector::InsertElementOp>(
398 loc, reductionOp->getResult(0), result,
399 rewriter.create<arith::ConstantIndexOp>(loc, i));
400 }
401
402 rewriter.replaceOp(op: rootOp, newValues: result);
403 return success();
404 }
405};
406
407/// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
408/// form with both a single parallel and reduction dimension.
409/// This is achieved with a simple vector.shape_cast that inserts a leading 1.
410/// The case with a single parallel dimension is a noop and folds away
411/// separately.
412struct OneDimMultiReductionToTwoDim
413 : public OpRewritePattern<vector::MultiDimReductionOp> {
414 using OpRewritePattern::OpRewritePattern;
415
416 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
417 PatternRewriter &rewriter) const override {
418 auto srcRank = multiReductionOp.getSourceVectorType().getRank();
419 // Rank-1 or bail.
420 if (srcRank != 1)
421 return failure();
422
423 // Vector mask setup.
424 OpBuilder::InsertionGuard guard(rewriter);
425 auto maskableOp =
426 cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
427 Operation *rootOp;
428 Value mask;
429 if (maskableOp.isMasked()) {
430 rewriter.setInsertionPoint(maskableOp.getMaskingOp());
431 rootOp = maskableOp.getMaskingOp();
432 mask = maskableOp.getMaskingOp().getMask();
433 } else {
434 rootOp = multiReductionOp;
435 }
436
437 auto loc = multiReductionOp.getLoc();
438 auto srcVectorType = multiReductionOp.getSourceVectorType();
439 auto srcShape = srcVectorType.getShape();
440 auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
441 srcVectorType.getElementType());
442 auto accType =
443 VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
444 assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
445 "multi_reduction with a single dimension expects a scalar result");
446
447 // If the unique dim is reduced and we insert a parallel in front, we need a
448 // {false, true} mask.
449 SmallVector<bool, 2> reductionMask{false, true};
450
451 /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
452 Value cast = rewriter.create<vector::ShapeCastOp>(
453 loc, castedType, multiReductionOp.getSource());
454 Value castAcc = rewriter.create<vector::BroadcastOp>(
455 loc, accType, multiReductionOp.getAcc());
456 Value castMask;
457 if (maskableOp.isMasked()) {
458 auto maskType = llvm::cast<ShapedType>(mask.getType());
459 auto castMaskType =
460 VectorType::get(ArrayRef<int64_t>{1, maskType.getShape().back()},
461 maskType.getElementType());
462 castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask);
463 }
464
465 Operation *newOp = rewriter.create<vector::MultiDimReductionOp>(
466 loc, cast, castAcc, reductionMask, multiReductionOp.getKind());
467 newOp = vector::maskOperation(builder&: rewriter, maskableOp: newOp, mask: castMask);
468
469 rewriter.replaceOpWithNewOp<vector::ExtractOp>(rootOp, newOp->getResult(0),
470 ArrayRef<int64_t>{0});
471 return success();
472 }
473};
474
475struct LowerVectorMultiReductionPass
476 : public vector::impl::LowerVectorMultiReductionBase<
477 LowerVectorMultiReductionPass> {
478 LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
479 this->loweringStrategy = option;
480 }
481
482 void runOnOperation() override {
483 Operation *op = getOperation();
484 MLIRContext *context = op->getContext();
485
486 RewritePatternSet loweringPatterns(context);
487 populateVectorMultiReductionLoweringPatterns(loweringPatterns,
488 this->loweringStrategy);
489
490 if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
491 signalPassFailure();
492 }
493
494 void getDependentDialects(DialectRegistry &registry) const override {
495 registry.insert<vector::VectorDialect>();
496 }
497};
498
499} // namespace
500
501void mlir::vector::populateVectorMultiReductionLoweringPatterns(
502 RewritePatternSet &patterns, VectorMultiReductionLowering options,
503 PatternBenefit benefit) {
504 patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
505 patterns.getContext(), options, benefit);
506 patterns.add<OneDimMultiReductionToTwoDim>(arg: patterns.getContext(), args&: benefit);
507 if (options == VectorMultiReductionLowering ::InnerReduction)
508 patterns.add<TwoDimMultiReductionToReduction>(arg: patterns.getContext(),
509 args&: benefit);
510 else
511 patterns.add<TwoDimMultiReductionToElementWise>(arg: patterns.getContext(),
512 args&: benefit);
513}
514
515std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
516 vector::VectorMultiReductionLowering option) {
517 return std::make_unique<LowerVectorMultiReductionPass>(option);
518}
519

// Source: mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp