//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg transformation that splits a reduction
// dimension into a parallel and a reduction dimension.
//
//===----------------------------------------------------------------------===//

#include <optional>
#include <utility>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::linalg;

FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
    RewriterBase &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);
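
  // A single reduction dimension of static extent N is split into an extra
  // parallel dimension of extent `ratio` and a remaining reduction of extent
  // N / ratio. A first generic op produces an intermediate tensor carrying
  // the extra dimension; a second generic op then reduces that dimension
  // into the original output. E.g., a sum over k = 128 with ratio = 16 first
  // produces partial sums with an extra dimension of extent 16, which are
  // then combined back into the original output shape.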

  SplitReductionOptions control = controlSplitReductionFn(op);
  int64_t ratio = control.ratio;
  unsigned insertSplitIndex = control.index;
  unsigned insertSplitDimension = control.index;
  if (ratio <= 1)
    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);

  if (dims.size() != 1)
    return b.notifyMatchFailure(op, "needs a single reduction dimension");
  unsigned reductionDim = dims[0];
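  // With an inner-parallel split the new parallel dimension is placed right
  // after the reduction dimension; otherwise it is inserted at the position
  // requested by the control function.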
  if (control.innerParallel) {
    insertSplitDimension = reductionDim + 1;
  }
  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDim];
  if (reductionDimSize == ShapedType::kDynamic || reductionDimSize % ratio != 0)
    return b.notifyMatchFailure(
        op, "Reduction dimension not divisible by split ratio");
  if (op.getNumDpsInits() != 1)
    return b.notifyMatchFailure(op, "More than one output in split reduction");
  if (insertSplitIndex > op.getShape(op.getDpsInitOperand(0)).size())
    return b.notifyMatchFailure(op, "Insert dimension position too large "
                                    "compared to intermediate tensor size");

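  // matchReduction (from SliceAnalysis) walks the value yielded into output 0
  // and collects the combiner ops; the split currently requires exactly one
  // combiner (e.g. a single arith.addf).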
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
      combinerOps.size() != 1)
    return b.notifyMatchFailure(op, "Cannot match the reduction pattern");

  Operation *reductionOp = combinerOps[0];
  std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
  if (!identity.has_value())
    return b.notifyMatchFailure(op, "Unknown identity value for the reduction");

  Location loc = op->getLoc();
  SmallVector<Value> newInputs;
  SmallVector<AffineMap> newMaps;
  // Calculate the new shapes and indexing maps of the input operands.
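  // For example, an input dimension of extent 128 indexed by the reduction
  // loop is expanded into (16, 8) for ratio = 16 (or (8, 16) with
  // innerParallel), and the operand is reshaped with a tensor.expand_shape.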
  for (OpOperand *operand : op.getDpsInputOperands()) {
    AffineMap map = op.getMatchingIndexingMap(operand);
    SmallVector<int64_t> newShape;
    SmallVector<AffineExpr> exprs;
    SmallVector<ReassociationIndices> reassociation;
    unsigned index = 0;
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      unsigned dim = map.getDimPosition(idx);
      if (reductionDim == dim) {
        if (control.innerParallel) {
          newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
          newShape.push_back(ratio); // parallel (insert)
          exprs.push_back(
              b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
        } else {
          newShape.push_back(ratio); // parallel (insert)
          newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
          exprs.push_back(
              b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
        }
        reassociation.push_back({index++, index++});
        continue;
      }
      newShape.push_back(op.getShape(operand)[idx]);
      exprs.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
        AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
    // If the shape is unchanged the input doesn't change.
    if (newShape == op.getShape(operand)) {
      newInputs.push_back(operand->get());
      continue;
    }
    Type newType = RankedTensorType::get(
        newShape,
        cast<RankedTensorType>(operand->get().getType()).getElementType());

    Value newInput = b.create<tensor::ExpandShapeOp>(
        loc, newType, operand->get(), reassociation);
    newInputs.push_back(newInput);
  }

  // Calculate the new output map and shape; the new dimension is inserted at
  // the index returned by `controlSplitReductionFn`.
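  // For example, an output of shape (3, 5) with ratio = 16 and
  // insertSplitIndex = 1 yields an intermediate output of shape (3, 16, 5).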
  SmallVector<int64_t> newOutputShape;
  AffineMap oldOutputMap = op.getMatchingIndexingMap(op.getDpsInitOperand(0));
  ArrayRef<int64_t> oldShape = op.getShape(op.getDpsInitOperand(0));
  SmallVector<AffineExpr> outputExpr;
  for (unsigned idx : llvm::seq<unsigned>(0, oldShape.size() + 1)) {
    if (insertSplitIndex == idx) {
      newOutputShape.push_back(ratio);
      outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
    }
    if (idx < oldShape.size()) {
      newOutputShape.push_back(oldShape[idx]);
      unsigned dim = oldOutputMap.getDimPosition(idx);
      outputExpr.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
    }
  }
  Value emptyOrAllocTensor;
  if (useAlloc) {
    emptyOrAllocTensor = b.create<bufferization::AllocTensorOp>(
        loc,
        RankedTensorType::get(newOutputShape,
                              op.getRegionOutputArgs()[0].getType()),
        ValueRange{});
  } else {
    emptyOrAllocTensor = b.create<tensor::EmptyOp>(
        loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
  }
  Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
  Value identityTensor =
      b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor)
          .getResult(0);

  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                   op.getContext()));
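  // The iterator types are the original ones with an extra parallel iterator
  // inserted at `insertSplitDimension`.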
  SmallVector<utils::IteratorType> newIteratorTypes;
  for (auto [index, iteratorType] :
       llvm::enumerate(op.getIteratorTypesArray())) {
    if (insertSplitDimension == index)
      newIteratorTypes.push_back(utils::IteratorType::parallel);
    newIteratorTypes.push_back(iteratorType);
  }
  if (insertSplitDimension == op.getIteratorTypesArray().size()) {
    newIteratorTypes.push_back(utils::IteratorType::parallel);
  }
  // Create the new op matching the original op with an extra parallel
  // dimension.
  GenericOp genericOp = b.create<GenericOp>(
      loc, TypeRange({emptyOrAllocTensor.getType()}), newInputs,
      ValueRange({identityTensor}), newMaps, newIteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
                       genericOp.getRegion().begin());

  // Then create a new reduction that only reduces the newly added dimension
  // from the previous op.
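  // Continuing the example above, the (3, 16, 5) intermediate is reduced
  // along the inserted dimension back into the original (3, 5) output.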
  unsigned intermRank = newOutputShape.size();
  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
  SmallVector<utils::IteratorType> reductionIteratorTypes;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertSplitIndex == i) {
      reductionIteratorTypes.push_back(utils::IteratorType::reduction);
    } else {
      exprs.push_back(b.getAffineDimExpr(i));
      reductionIteratorTypes.push_back(utils::IteratorType::parallel);
    }
  }
  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};

  auto reduction = b.create<GenericOp>(
      loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
      op.getDpsInits(), reductionMaps, reductionIteratorTypes,
      [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
        Operation *clonedReductionOp = b.clone(*reductionOp);
        clonedReductionOp->setOperand(0, inputs[0]);
        clonedReductionOp->setOperand(1, inputs[1]);
        b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
      });
  b.replaceOp(op, reduction.getResults());

  return SplitReductionResult{emptyOrAllocTensor.getDefiningOp(),
                              identityTensor.getDefiningOp<FillOp>(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              reduction};
}

/// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...)
/// TODO: Additional pattern to rewrite f(i, j, k * ratio + kk, ...) into
/// f(i, j, k, kk, ...) with a proper ExpandShapeOp. This is probably better
/// done as a transform to enable better vectorization.
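/// For example, for a 3-loop op with operand map (d0, d2), reductionDimPos = 2
/// and reductionRatio = 4, the returned map is (d0, d2 * 4 + d3) over 4 loops,
/// with d3 the newly inserted loop.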
static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos,
                                   int64_t reductionRatio) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  auto reductionDimP1 = getAffineDimExpr(reductionDimPos + 1, op.getContext());
  AffineMap map = op.getMatchingIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  AffineMap composeMap = shiftedIdMap.replace(
      reductionDim, reductionDim * reductionRatio + reductionDimP1,
      shiftedIdMap.getNumDims(), /*numSymbols=*/0);
  return map.compose(composeMap);
}

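/// Insert the new parallel dimension at result position `reductionDimPos` in
/// the indexing map of the output `opOperand`, remapping it to the expanded
/// loop space. For example, a 3-loop op with output map (d0, d1) and
/// reductionDimPos = 2 yields (d0, d1, d2) over 4 loops.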
static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos, int64_t size) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  AffineMap map = op.getMatchingIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  return map.compose(shiftedIdMap).insertResult(reductionDim, reductionDimPos);
}

/// Core rewrite implementation.
FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
    RewriterBase &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  // Matcher part, enforce preconditions.
  SplitReductionOptions control = controlSplitReductionFn(op);
  if (control.innerParallel)
    return b.notifyMatchFailure(op, "innerParallel not supported");

  int64_t splitFactor = control.ratio;
  unsigned insertSplitDimension = control.index;
  if (splitFactor <= 1)
    return b.notifyMatchFailure(op, "split factor needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  if (dims.empty())
    return b.notifyMatchFailure(op, "needs at least 1 reduction dimension");

  unsigned reductionDimPos = dims[0];
  SmallVector<int64_t> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDimPos];
  if (reductionDimSize == ShapedType::kDynamic ||
      reductionDimSize % splitFactor != 0 ||
      insertSplitDimension >= loopRanges.size())
    return b.notifyMatchFailure(
        op, "first reduction dimension not divisible by split factor");

  SmallVector<Operation *> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
    return b.notifyMatchFailure(op, "cannot match a reduction pattern");

  SmallVector<TypedAttr> neutralElements;
  for (Operation *reductionOp : combinerOps) {
    std::optional<TypedAttr> neutralElement =
        arith::getNeutralElement(reductionOp);
    if (!neutralElement.has_value())
      return b.notifyMatchFailure(op, "cannot find neutral element.");
    neutralElements.push_back(*neutralElement);
  }
  if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; }))
    return b.notifyMatchFailure(op, "unknown reduction neutral");

  // TODO: relax this when multi-reduction support is available.
  if (op.getNumDpsInits() != static_cast<int64_t>(neutralElements.size()))
    return b.notifyMatchFailure(op, "expect one reduction per output");

  // Rewrite part.
  // Step 1. Build the intermediate outputs filled with the proper
  // neutralElements. Such outputs are of the same shape with an extra
  // dimension inserted at `insertSplitDimension`.
  //
  // Consider a minimal example where `k` is reduced:
  //     O(i, j) += I(i, j, k)
  // Assume i=3, j=5, k=128, splitFactor=16 and insertSplitDimension=0.
  // The compute is rewritten as:
  //   a. O_i(k, i, j) += I(i, j, 16 * k + kk)
  //   b. O(i, j) += O_i(k, i, j)
  // where k now has extent 128/16 = 8 and kk has extent 16. The intermediate
  // tensor O_i is of shape (128/16)x3x5 == 8x3x5.
  Location loc = op->getLoc();
  MLIRContext *context = op.getContext();
  // For now assume outputs are 1-1 with reduction neutralElements.
  // TODO: generalize when multi-reduction support is available.
  SmallVector<Value> newOutputs;
  newOutputs.reserve(op.getNumDpsInits());
  SmallVector<Operation *> emptyOrAllocTensorOps;
  SmallVector<linalg::FillOp> fillOps;
  fillOps.reserve(op.getNumDpsInits());
  for (auto it : llvm::zip(op.getDpsInitsMutable(), neutralElements)) {
    Value rankedTensor = std::get<0>(it).get();
    auto t = cast<RankedTensorType>(rankedTensor.getType());
    RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
        reductionDimSize / splitFactor, insertSplitDimension);
    SmallVector<Value> dims =
        tensor::createDynamicDimValues(b, loc, rankedTensor);
    Value emptyOrAllocTensor;
    if (useAlloc) {
      emptyOrAllocTensor =
          b.create<bufferization::AllocTensorOp>(loc, newT, dims);
    } else {
      emptyOrAllocTensor = b.create<tensor::EmptyOp>(loc, newT.getShape(),
                                                     t.getElementType(), dims);
    }
    Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it));
    fillOps.push_back(
        b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor));
    newOutputs.push_back(fillOps.back().getResult(0));
    emptyOrAllocTensorOps.push_back(emptyOrAllocTensor.getDefiningOp());
  }

  // Step 2. Reindex / expand indexing maps.
  // Reindex existing input indexings: k -> k * splitFactor + k'.
  SmallVector<AffineMap> newMaps;
  newMaps.reserve(op->getNumOperands() + 1);
  for (OpOperand *o : op.getDpsInputOperands())
    newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor));
  // Provision a new indexing for the shape-only tensor.
  auto nDims = op.getNumLoops() + 1;
  auto redDim = getAffineDimExpr(reductionDimPos, context);
  auto redDimP1 = getAffineDimExpr(reductionDimPos + 1, context);
  newMaps.push_back(AffineMap::get(nDims, 0, {redDim, redDimP1}, context));
  // Expand existing output indexings.
  // TODO: a subset of these may not reduce along reducePos and should be
  // reindexed: k -> k * splitFactor + k', when multi-reduction support is
  // available.
  for (OpOperand &o : op.getDpsInitsMutable())
    newMaps.push_back(insertParallelDim(op, o, reductionDimPos,
                                        reductionDimSize / splitFactor));

  // Step 3. Handle operands.
  // Compute the new input tensors.
  SmallVector<Value> newInputs = op.getDpsInputs();
  // Add a single shape-only tensor to carry the dimensions without resorting to
  // more complex inversions.
  newInputs.push_back(b.create<tensor::EmptyOp>(
      loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor},
      b.getIntegerType(1)));
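  // Only the shape of this tensor matters: it carries the extents of the two
  // new loops (reductionDimSize / splitFactor and splitFactor); its i1
  // elements are never read.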
  // Output tensors are already good to go.

  // Step 4. Create the new op matching the original op with an extra parallel
  // dimension.
  auto iteratorTypes = op.getIteratorTypesArray();
  iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
                       utils::IteratorType::parallel);
  GenericOp genericOp =
      b.create<GenericOp>(loc, ValueRange(newOutputs).getTypes(), newInputs,
                          newOutputs, newMaps, iteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
                       genericOp.getRegion().begin());
  genericOp.getRegion().front().insertArgument(reductionDimPos,
                                               b.getIntegerType(1), loc);

  // Step 5. Create new reduction ops that only reduce the newly added
  // dimensions from the previous op.
  // For now assume outputs are 1-1 with reduction ops.
  // TODO: a subset of these may not reduce in the first place and do not
  // require a new op, when multi-reduction support is available.
  // TODO: all results can be handled in a single GenericOp, when
  // multi-reduction support is available.
  SmallVector<LinalgOp> results;
  for (auto it :
       llvm::zip(genericOp->getResults(), op.getDpsInits(), combinerOps)) {
    Value reindexedOutput = std::get<0>(it);
    Value originalOutput = std::get<1>(it);
    auto originalOutputType = cast<RankedTensorType>(originalOutput.getType());
    Operation *combinerOp = std::get<2>(it);

    AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
    SmallVector<AffineMap> indexingMaps = {
        map, map.dropResult(insertSplitDimension)};
    SmallVector<utils::IteratorType> reductionIteratorTypes(
        originalOutputType.getRank() + 1, utils::IteratorType::parallel);
    reductionIteratorTypes[insertSplitDimension] =
        utils::IteratorType::reduction;

    // clang-format off
    auto reductionOp = b.create<GenericOp>(
        loc,
        originalOutputType,
        reindexedOutput,
        originalOutput,
        indexingMaps,
        reductionIteratorTypes,
        [combinerOp](OpBuilder &b, Location loc, ValueRange bbArgs) {
          Operation *clonedReductionOp = b.clone(*combinerOp);
          clonedReductionOp->setOperand(0, bbArgs[0]);
          clonedReductionOp->setOperand(1, bbArgs[1]);
          b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
        });
    // clang-format on

    results.push_back(reductionOp);
  }

  // TODO: extend when multi-reduction support is available.
  assert(fillOps.size() == results.size() && results.size() == 1);
  b.replaceOp(op, results.front()->getResults());
  return SplitReductionResult{emptyOrAllocTensorOps.front(), fillOps.front(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              results.front()};
}

namespace {

struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
  LinalgSplitReduction(MLIRContext *context,
                       ControlSplitReductionFn controlSplitReductionFn,
                       bool useAlloc = false, PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlSplitReductionFn(std::move(controlSplitReductionFn)),
        useAlloc(useAlloc) {}

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    return splitReduction(rewriter, op, controlSplitReductionFn, useAlloc);
  }

private:
  ControlSplitReductionFn controlSplitReductionFn;
  bool useAlloc;
};

} // namespace

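// A minimal usage sketch (illustrative only; the surrounding driver code and
// the control values below are assumptions of this example, not part of this
// file):
//   RewritePatternSet patterns(ctx);
//   linalg::populateSplitReductionPattern(
//       patterns, [](LinalgOp) {
//         return SplitReductionOptions{/*ratio=*/8, /*index=*/0,
//                                      /*innerParallel=*/false};
//       });
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));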
void linalg::populateSplitReductionPattern(
    RewritePatternSet &patterns,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  patterns.add<LinalgSplitReduction>(patterns.getContext(),
                                     controlSplitReductionFn, useAlloc);
}