//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg transformation that splits a reduction
// dimension into a parallel and a reduction dimension.
//
//===----------------------------------------------------------------------===//

#include <optional>
#include <utility>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::linalg;

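/// Split the single reduction dimension of `op` into an extra parallel
/// dimension of extent `control.ratio` plus a smaller reduction dimension,
/// then reduce the extra parallel dimension away in a follow-up
/// linalg.generic. Illustrative sketch for the default outer-parallel case
/// (ratio = 16, reduction extent 128; `control.innerParallel` swaps the two
/// new dimensions):
///   O(i) += I(i, k)                // k in [0, 128)
/// becomes
///   a. O_i(kk, i) += I'(i, kk, k)  // I' = tensor.expand_shape of I splitting
///                                  // k into (kk, k) of extents (16, 8);
///                                  // kk is parallel, k is reduced
///   b. O(i) += O_i(kk, i)          // reduce kk away
/// The intermediate O_i is seeded with the reduction's neutral element and the
/// new dimension is placed at `control.index` in the intermediate tensor.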
FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
    RewriterBase &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  SplitReductionOptions control = controlSplitReductionFn(op);
  int64_t ratio = control.ratio;
  unsigned insertSplitIndex = control.index;
  unsigned insertSplitDimension = control.index;
  if (ratio <= 1)
    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);

  if (dims.size() != 1)
    return b.notifyMatchFailure(op, "needs a single reduction dimension");
  unsigned reductionDim = dims[0];
  if (control.innerParallel) {
    insertSplitDimension = reductionDim + 1;
  }
  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDim];
  if (reductionDimSize == ShapedType::kDynamic ||
      reductionDimSize % ratio != 0)
    return b.notifyMatchFailure(
        op, "Reduction dimension not divisible by split ratio");
  if (op.getNumDpsInits() != 1)
    return b.notifyMatchFailure(op, "More than one output in split reduction");
  if (insertSplitIndex > op.getShape(op.getDpsInitOperand(0)).size())
    return b.notifyMatchFailure(op, "Insert dimension position too large "
                                    "compared to intermediate tensor size");

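  // The body of `op` must reduce through exactly one combiner operation whose
  // neutral element is known (e.g. arith.addf with 0.0); the neutral element
  // is used below to seed the intermediate tensor.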
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
      combinerOps.size() != 1)
    return b.notifyMatchFailure(op, "Cannot match the reduction pattern");

  Operation *reductionOp = combinerOps[0];
  std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
  if (!identity.has_value())
    return b.notifyMatchFailure(op, "Unknown identity value for the reduction");

  Location loc = op->getLoc();
  SmallVector<Value> newInputs;
  SmallVector<AffineMap> newMaps;
  // Calculate the new shapes and indexing maps of the input operands.
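  // A reduced operand dimension of extent `s` is reshaped to (ratio, s/ratio)
  // (or (s/ratio, ratio) when `innerParallel` is set) with a
  // tensor.expand_shape, and its indexing map gains the new split dimension;
  // all other dimensions and map results are kept, shifted as needed.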
  for (OpOperand *operand : op.getDpsInputOperands()) {
    AffineMap map = op.getMatchingIndexingMap(operand);
    SmallVector<int64_t> newShape;
    SmallVector<AffineExpr> exprs;
    SmallVector<ReassociationIndices> reassociation;
    unsigned index = 0;
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      unsigned dim = map.getDimPosition(idx);
      if (reductionDim == dim) {
        if (control.innerParallel) {
          newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
          newShape.push_back(ratio); // parallel (insert)
          exprs.push_back(
              b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
        } else {
          newShape.push_back(ratio); // parallel (insert)
          newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
          exprs.push_back(
              b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
        }
        reassociation.push_back({index++, index++});
        continue;
      }
      newShape.push_back(op.getShape(operand)[idx]);
      exprs.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
        AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
    // If the shape is unchanged the input doesn't change.
    if (newShape == op.getShape(operand)) {
      newInputs.push_back(operand->get());
      continue;
    }
    Type newType = RankedTensorType::get(
        newShape,
        cast<RankedTensorType>(operand->get().getType()).getElementType());
    Value newInput = b.create<tensor::ExpandShapeOp>(
        loc, newType, operand->get(), reassociation);
    newInputs.push_back(newInput);
  }

  // Calculate the new output map and shape; the new dimension is inserted at
  // the index returned by `controlSplitReductionFn`.
  SmallVector<int64_t> newOutputShape;
  AffineMap oldOutputMap = op.getMatchingIndexingMap(op.getDpsInitOperand(0));
  ArrayRef<int64_t> oldShape = op.getShape(op.getDpsInitOperand(0));
  SmallVector<AffineExpr> outputExpr;
  for (unsigned idx : llvm::seq<unsigned>(0, oldShape.size() + 1)) {
    if (insertSplitIndex == idx) {
      newOutputShape.push_back(ratio);
      outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
    }
    if (idx < oldShape.size()) {
      newOutputShape.push_back(oldShape[idx]);
      unsigned dim = oldOutputMap.getDimPosition(idx);
      outputExpr.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
    }
  }
  Value emptyOrAllocTensor;
  if (useAlloc) {
    emptyOrAllocTensor = b.create<bufferization::AllocTensorOp>(
        loc,
        RankedTensorType::get(newOutputShape,
                              op.getRegionOutputArgs()[0].getType()),
        ValueRange{});
  } else {
    emptyOrAllocTensor = b.create<tensor::EmptyOp>(
        loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
  }
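  // Seed the intermediate tensor with the reduction's neutral element so that
  // every partial reduction starts from the identity value.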
  Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
  Value identityTensor =
      b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor)
          .getResult(0);

  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                   op.getContext()));
  SmallVector<utils::IteratorType> newIteratorTypes;
  for (auto [index, iteratorType] :
       llvm::enumerate(op.getIteratorTypesArray())) {
    if (insertSplitDimension == index)
      newIteratorTypes.push_back(utils::IteratorType::parallel);
    newIteratorTypes.push_back(iteratorType);
  }
  if (insertSplitDimension == op.getIteratorTypesArray().size()) {
    newIteratorTypes.push_back(utils::IteratorType::parallel);
  }
  // Create the new op matching the original op with an extra parallel
  // dimension.
  GenericOp genericOp = b.create<GenericOp>(
      loc, TypeRange({emptyOrAllocTensor.getType()}), newInputs,
      ValueRange({identityTensor}), newMaps, newIteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
                       genericOp.getRegion().begin());

  // Then create a new reduction that only reduces the newly added dimension
  // from the previous op.
  unsigned intermRank = newOutputShape.size();
  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
  SmallVector<utils::IteratorType> reductionIteratorTypes;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertSplitIndex == i) {
      reductionIteratorTypes.push_back(utils::IteratorType::reduction);
    } else {
      exprs.push_back(b.getAffineDimExpr(i));
      reductionIteratorTypes.push_back(utils::IteratorType::parallel);
    }
  }
  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};

  auto reduction = b.create<GenericOp>(
      loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
      op.getDpsInits(), reductionMaps, reductionIteratorTypes,
      [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
        Operation *clonedReductionOp = b.clone(*reductionOp);
        clonedReductionOp->setOperand(0, inputs[0]);
        clonedReductionOp->setOperand(1, inputs[1]);
        b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
      });
  b.replaceOp(op, reduction.getResults());

  return SplitReductionResult{emptyOrAllocTensor.getDefiningOp(),
                              identityTensor.getDefiningOp<FillOp>(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              reduction};
}

/// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...)
/// TODO: Additional pattern to rewrite f(i, j, k * ratio + kk, ...) into
/// f(i, j, k, kk, ...) with a proper ExpandShapeOp. This is probably better
/// done as a transform to enable better vectorization.
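/// For example (illustrative), with reductionDimPos = 2 and reductionRatio =
/// 16, an operand map (d0, d1, d2) -> (d0, d2) is rewritten into
/// (d0, d1, d2, d3) -> (d0, d2 * 16 + d3), where d2 becomes the new parallel
/// dimension and d3 the remaining reduction dimension.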
static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos,
                                   int64_t reductionRatio) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  auto reductionDimP1 = getAffineDimExpr(reductionDimPos + 1, op.getContext());
  AffineMap map = op.getMatchingIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  AffineMap composeMap = shiftedIdMap.replace(
      reductionDim, reductionDim * reductionRatio + reductionDimP1,
      shiftedIdMap.getNumDims(), /*numSymbols=*/0);
  return map.compose(composeMap);
}

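/// Expand `opOperand`'s indexing map for the split: references to dims at
/// positions `reductionDimPos + 1` and above are shifted up by one, and the
/// new parallel dimension `d_reductionDimPos` is inserted at result position
/// `reductionDimPos`. E.g. with reductionDimPos = 2, an output map
/// (d0, d1, d2) -> (d0, d1) becomes (d0, d1, d2, d3) -> (d0, d1, d2).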
static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos, int64_t size) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  AffineMap map = op.getMatchingIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  return map.compose(shiftedIdMap).insertResult(reductionDim, reductionDimPos);
}

/// Core rewrite implementation: split the reduction by scaling. Instead of
/// reshaping the inputs, the first reduction dimension is reindexed as
/// k -> k * splitFactor + kk, and a shape-only tensor carries the extents of
/// the two resulting dimensions. See the worked example in Step 1 below.
FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
    RewriterBase &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  // Matcher part, enforce preconditions.
  SplitReductionOptions control = controlSplitReductionFn(op);
  if (control.innerParallel)
    return b.notifyMatchFailure(op, "innerParallel not supported");

  int64_t splitFactor = control.ratio;
  unsigned insertSplitDimension = control.index;
  if (splitFactor <= 1)
    return b.notifyMatchFailure(op, "split factor needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  if (dims.empty())
    return b.notifyMatchFailure(op, "needs at least 1 reduction dimension");

  unsigned reductionDimPos = dims[0];
  SmallVector<int64_t> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDimPos];
  if (reductionDimSize == ShapedType::kDynamic ||
      reductionDimSize % splitFactor != 0 ||
      insertSplitDimension >= loopRanges.size())
    return b.notifyMatchFailure(
        op, "first reduction dimension not divisible by split factor");

  SmallVector<Operation *> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
    return b.notifyMatchFailure(op, "cannot match a reduction pattern");

  SmallVector<TypedAttr> neutralElements;
  for (Operation *reductionOp : combinerOps) {
    std::optional<TypedAttr> neutralElement =
        arith::getNeutralElement(reductionOp);
    if (!neutralElement.has_value())
      return b.notifyMatchFailure(op, "cannot find neutral element.");
    neutralElements.push_back(*neutralElement);
  }
  if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; }))
    return b.notifyMatchFailure(op, "unknown reduction neutral");

  // TODO: relax this when multi-reduction support is available.
  if (op.getNumDpsInits() != static_cast<int64_t>(neutralElements.size()))
    return b.notifyMatchFailure(op, "expect one reduction per output");

  // Rewrite part.
  // Step 1. Build the intermediate outputs filled with the proper
  // neutralElements. Such outputs are of the same shape with an extra
  // dimension inserted at `insertSplitDimension`.
  //
  // Consider a minimal example where `k` is reduced:
  //   O(i, j) += I(i, j, k)
  // Assume i=3, j=5, k=128, splitFactor=16 and insertSplitDimension=0.
  // The compute is rewritten as:
  //   a. O_i(kk, i, j) += I(i, j, 16 * kk + k)
  //   b. O(i, j) += O_i(kk, i, j)
  // where kk is the new parallel dimension of extent 128/16 == 8 and k now
  // only ranges over [0, 16). The intermediate tensor O_i has shape 8x3x5.
  Location loc = op->getLoc();
  MLIRContext *context = op.getContext();
  // For now assume outputs are 1-1 with reduction neutralElements.
  // TODO: generalize when multi-reduction support is available.
  SmallVector<Value> newOutputs;
  newOutputs.reserve(op.getNumDpsInits());
  SmallVector<Operation *> emptyOrAllocTensorOps;
  SmallVector<linalg::FillOp> fillOps;
  fillOps.reserve(op.getNumDpsInits());
  for (auto it : llvm::zip(op.getDpsInitsMutable(), neutralElements)) {
    Value rankedTensor = std::get<0>(it).get();
    auto t = cast<RankedTensorType>(rankedTensor.getType());
    RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
        reductionDimSize / splitFactor, insertSplitDimension);
    SmallVector<Value> dims =
        tensor::createDynamicDimValues(b, loc, rankedTensor);
    Value emptyOrAllocTensor;
    if (useAlloc) {
      emptyOrAllocTensor =
          b.create<bufferization::AllocTensorOp>(loc, newT, dims);
    } else {
      emptyOrAllocTensor = b.create<tensor::EmptyOp>(loc, newT.getShape(),
                                                     t.getElementType(), dims);
    }
    Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it));
    fillOps.push_back(
        b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor));
    newOutputs.push_back(fillOps.back().getResult(0));
    emptyOrAllocTensorOps.push_back(emptyOrAllocTensor.getDefiningOp());
  }

  // Step 2. Reindex / expand indexing maps.
  // Reindex existing input indexings: k -> k * splitFactor + k'.
  SmallVector<AffineMap> newMaps;
  newMaps.reserve(op->getNumOperands() + 1);
  for (OpOperand *o : op.getDpsInputOperands())
    newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor));
  // Provision a new indexing for the shape-only tensor.
  auto nDims = op.getNumLoops() + 1;
  auto redDim = getAffineDimExpr(reductionDimPos, context);
  auto redDimP1 = getAffineDimExpr(reductionDimPos + 1, context);
  newMaps.push_back(AffineMap::get(nDims, 0, {redDim, redDimP1}, context));
  // Expand existing output indexings.
  // TODO: a subset of these may not reduce along reducePos and should be
  // reindexed: k -> k * splitFactor + k', when multi-reduction support is
  // available.
  for (OpOperand &o : op.getDpsInitsMutable())
    newMaps.push_back(insertParallelDim(op, o, reductionDimPos,
                                        reductionDimSize / splitFactor));

  // Step 3. Handle operands.
  // Compute the new input tensors.
  SmallVector<Value> newInputs = op.getDpsInputs();
  // Add a single shape-only tensor to carry the dimensions without resorting
  // to more complex inversions.
  newInputs.push_back(b.create<tensor::EmptyOp>(
      loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor},
      b.getIntegerType(1)));
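  // The shape-only tensor is never read inside the region; only its shape
  // matters, as its indexing map carries the extents of the two split
  // dimensions (reductionDimSize / splitFactor and splitFactor).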
  // Output tensors are already good to go.

  // Step 4. Create the new op matching the original op with an extra parallel
  // dimension.
  auto iteratorTypes = op.getIteratorTypesArray();
  iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
                       utils::IteratorType::parallel);
  GenericOp genericOp =
      b.create<GenericOp>(loc, ValueRange(newOutputs).getTypes(), newInputs,
                          newOutputs, newMaps, iteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
                       genericOp.getRegion().begin());
  genericOp.getRegion().front().insertArgument(reductionDimPos,
                                               b.getIntegerType(1), loc);

  // Step 5. Create new reduction ops that only reduce the newly added
  // dimensions from the previous op.
  // For now assume outputs are 1-1 with reduction ops.
  // TODO: a subset of these may not reduce in the first place and do not
  // require a new op, when multi-reduction support is available.
  // TODO: all results can be handled in a single GenericOp, when
  // multi-reduction support is available.
  SmallVector<LinalgOp> results;
  for (auto it :
       llvm::zip(genericOp->getResults(), op.getDpsInits(), combinerOps)) {
    Value reindexedOutput = std::get<0>(it);
    Value originalOutput = std::get<1>(it);
    auto originalOutputType = cast<RankedTensorType>(originalOutput.getType());
    Operation *combinerOp = std::get<2>(it);

    AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
    SmallVector<AffineMap> indexingMaps = {
        map, map.dropResult(insertSplitDimension)};
    SmallVector<utils::IteratorType> reductionIteratorTypes(
        originalOutputType.getRank() + 1, utils::IteratorType::parallel);
    reductionIteratorTypes[insertSplitDimension] =
        utils::IteratorType::reduction;

    // clang-format off
    auto reductionOp = b.create<GenericOp>(
        loc,
        originalOutputType,
        reindexedOutput,
        originalOutput,
        indexingMaps,
        reductionIteratorTypes,
        [combinerOp](OpBuilder &b, Location loc, ValueRange bbArgs) {
          Operation *clonedReductionOp = b.clone(*combinerOp);
          clonedReductionOp->setOperand(0, bbArgs[0]);
          clonedReductionOp->setOperand(1, bbArgs[1]);
          b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
        });
    // clang-format on

    results.push_back(reductionOp);
  }

  // TODO: extend when multi-reduction support is available.
  assert(fillOps.size() == results.size() && results.size() == 1);
  b.replaceOp(op, results.front()->getResults());
  return SplitReductionResult{emptyOrAllocTensorOps.front(), fillOps.front(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              results.front()};
}

namespace {

struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a generic pattern applied to all LinalgOp; the split is
  /// controlled by `controlSplitReductionFn`.
  LinalgSplitReduction(MLIRContext *context,
                       ControlSplitReductionFn controlSplitReductionFn,
                       bool useAlloc = false, PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlSplitReductionFn(std::move(controlSplitReductionFn)),
        useAlloc(useAlloc) {}

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    return splitReduction(rewriter, op, controlSplitReductionFn, useAlloc);
  }

private:
  ControlSplitReductionFn controlSplitReductionFn;
  bool useAlloc;
};

} // namespace

void linalg::populateSplitReductionPattern(
    RewritePatternSet &patterns,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  patterns.add<LinalgSplitReduction>(patterns.getContext(),
                                     controlSplitReductionFn, useAlloc);
}
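
// A minimal usage sketch (illustrative; `ctx` and `funcOp` stand for the
// enclosing MLIRContext and the function being transformed): register the
// pattern with a control function that, e.g., splits the reduction by a
// factor of 4 and inserts the extra parallel dimension at index 0, then run
// the greedy pattern driver.
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateSplitReductionPattern(
//       patterns, [](LinalgOp op) {
//         return SplitReductionOptions{/*ratio=*/4, /*index=*/0,
//                                      /*innerParallel=*/false};
//       });
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));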