//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg transformation that splits a reduction
// dimension into a parallel dimension and a smaller reduction dimension.
//
//===----------------------------------------------------------------------===//

#include <optional>
#include <utility>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::linalg;

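/// Split the single reduction dimension of `op` into an extra parallel
/// dimension of size `ratio` plus a smaller reduction dimension. As a sketch
/// (assuming a 1-D summation with ratio=4, index=0 and innerParallel=false;
/// the exact IR depends on the op and on `controlSplitReductionFn`), the
/// reduction
///
///   O() += I(k)           // k = 16
///
/// is rewritten into a partial reduction into an intermediate tensor,
/// followed by a final reduction over the inserted parallel dimension:
///
///   O_i(kk) += I'(kk, k') // I' = tensor.expand_shape of I into 4x4
///   O()     += O_i(kk)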
FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
    RewriterBase &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  SplitReductionOptions control = controlSplitReductionFn(op);
  int64_t ratio = control.ratio;
  unsigned insertSplitIndex = control.index;
  unsigned insertSplitDimension = control.index;
  if (ratio <= 1)
    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);

  if (dims.size() != 1)
    return b.notifyMatchFailure(op, "needs a single reduction dimension");
  unsigned reductionDim = dims[0];
  if (control.innerParallel) {
    insertSplitDimension = reductionDim + 1;
  }
  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDim];
  if (reductionDimSize == ShapedType::kDynamic || reductionDimSize % ratio != 0)
    return b.notifyMatchFailure(
        op, "Reduction dimension not divisible by split ratio");
  if (op.getNumDpsInits() != 1)
    return b.notifyMatchFailure(op, "More than one output in split reduction");
  if (insertSplitIndex > op.getShape(op.getDpsInitOperand(0)).size())
    return b.notifyMatchFailure(op, "Insert dimension position too large "
                                    "compared to intermediate tensor size");

  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
      combinerOps.size() != 1)
    return b.notifyMatchFailure(op, "Cannot match the reduction pattern");

  Operation *reductionOp = combinerOps[0];
  std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
  if (!identity.has_value())
    return b.notifyMatchFailure(op, "Unknown identity value for the reduction");

  Location loc = op->getLoc();
  SmallVector<Value> newInputs;
  SmallVector<AffineMap> newMaps;
  // Calculate the new shapes and indexing maps of the input operands.
  for (OpOperand *operand : op.getDpsInputOperands()) {
    AffineMap map = op.getMatchingIndexingMap(operand);
    SmallVector<int64_t> newShape;
    SmallVector<AffineExpr> exprs;
    SmallVector<ReassociationIndices> reassociation;
    unsigned index = 0;
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      unsigned dim = map.getDimPosition(idx);
      if (reductionDim == dim) {
        // The reduction dimension is split in two. With `innerParallel` the
        // inserted parallel (split) dimension is innermost; otherwise it is
        // outermost.
        if (control.innerParallel) {
          newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
          newShape.push_back(ratio); // parallel (insert)
          exprs.push_back(
              b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
        } else {
          newShape.push_back(ratio); // parallel (insert)
          newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
          exprs.push_back(
              b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
        }
        reassociation.push_back({index++, index++});
        continue;
      }
      newShape.push_back(op.getShape(operand)[idx]);
      exprs.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
        AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
    // If the shape is unchanged the input doesn't change.
    if (newShape == op.getShape(operand)) {
      newInputs.push_back(operand->get());
      continue;
    }
    Type newType = RankedTensorType::get(
        newShape,
        cast<RankedTensorType>(operand->get().getType()).getElementType());

    Value newInput = b.create<tensor::ExpandShapeOp>(
        loc, newType, operand->get(), reassociation);
    newInputs.push_back(newInput);
  }

  // Calculate the new output map and shape; we insert the new dimension
  // based on the index returned by `controlSplitReductionFn`.
  SmallVector<int64_t> newOutputShape;
  AffineMap oldOutputMap = op.getMatchingIndexingMap(op.getDpsInitOperand(0));
  ArrayRef<int64_t> oldShape = op.getShape(op.getDpsInitOperand(0));
  SmallVector<AffineExpr> outputExpr;
  for (unsigned idx : llvm::seq<unsigned>(0, oldShape.size() + 1)) {
    if (insertSplitIndex == idx) {
      newOutputShape.push_back(ratio);
      outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
    }
    if (idx < oldShape.size()) {
      newOutputShape.push_back(oldShape[idx]);
      unsigned dim = oldOutputMap.getDimPosition(idx);
      outputExpr.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
    }
  }
  Value emptyOrAllocTensor;
  if (useAlloc) {
    emptyOrAllocTensor = b.create<bufferization::AllocTensorOp>(
        loc,
        RankedTensorType::get(newOutputShape,
                              op.getRegionOutputArgs()[0].getType()),
        ValueRange{});
  } else {
    emptyOrAllocTensor = b.create<tensor::EmptyOp>(
        loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
  }
  Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
  Value identityTensor =
      b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor)
          .getResult(0);

  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                   op.getContext()));
  SmallVector<utils::IteratorType> newIteratorTypes;
  for (auto [index, iteratorType] :
       llvm::enumerate(op.getIteratorTypesArray())) {
    if (insertSplitDimension == index)
      newIteratorTypes.push_back(utils::IteratorType::parallel);
    newIteratorTypes.push_back(iteratorType);
  }
  if (insertSplitDimension == op.getIteratorTypesArray().size()) {
    newIteratorTypes.push_back(utils::IteratorType::parallel);
  }
  // Create the new op matching the original op with an extra parallel
  // dimension.
  GenericOp genericOp = b.create<GenericOp>(
      loc, TypeRange({emptyOrAllocTensor.getType()}), newInputs,
      ValueRange({identityTensor}), newMaps, newIteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
                       genericOp.getRegion().begin());

  // Then create a new reduction that only reduces the newly added dimension
  // from the previous op.
  unsigned intermRank = newOutputShape.size();
  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
  SmallVector<utils::IteratorType> reductionIteratorTypes;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertSplitIndex == i) {
      reductionIteratorTypes.push_back(utils::IteratorType::reduction);
    } else {
      exprs.push_back(b.getAffineDimExpr(i));
      reductionIteratorTypes.push_back(utils::IteratorType::parallel);
    }
  }
  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};

  auto reduction = b.create<GenericOp>(
      loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
      op.getDpsInits(), reductionMaps, reductionIteratorTypes,
      [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
        Operation *clonedReductionOp = b.clone(*reductionOp);
        clonedReductionOp->setOperand(0, inputs[0]);
        clonedReductionOp->setOperand(1, inputs[1]);
        b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
      });
  b.replaceOp(op, reduction.getResults());

  return SplitReductionResult{emptyOrAllocTensor.getDefiningOp(),
                              identityTensor.getDefiningOp<FillOp>(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              reduction};
}

/// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...)
/// TODO: Additional pattern to rewrite f(i, j, k * ratio + kk, ...) into
/// f(i, j, k, kk, ...) with a proper ExpandShapeOp. This is probably better
/// done as a transform to enable better vectorization.
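///
/// For example (a sketch, assuming the LHS operand of a matmul with
/// reductionDimPos=2 and reductionRatio=4), the operand map
///   (d0, d1, d2) -> (d0, d2)
/// becomes
///   (d0, d1, d2, d3) -> (d0, d2 * 4 + d3)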
static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos,
                                   int64_t reductionRatio) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  auto reductionDimP1 = getAffineDimExpr(reductionDimPos + 1, op.getContext());
  AffineMap map = op.getMatchingIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  AffineMap composeMap = shiftedIdMap.replace(
      reductionDim, reductionDim * reductionRatio + reductionDimP1,
      shiftedIdMap.getNumDims(), /*numSymbols=*/0);
  return map.compose(composeMap);
}

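/// Expand `opOperand`'s indexing map with an extra parallel dimension: loop
/// dimensions at or above `reductionDimPos + 1` are shifted up by one, and
/// d(reductionDimPos) is re-inserted as a result at position
/// `reductionDimPos`. As a sketch (assuming the output operand of a matmul
/// with reductionDimPos=2), the operand map
///   (d0, d1, d2) -> (d0, d1)
/// becomes
///   (d0, d1, d2, d3) -> (d0, d1, d2)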
static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos, int64_t size) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  AffineMap map = op.getMatchingIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  return map.compose(shiftedIdMap).insertResult(reductionDim, reductionDimPos);
}

/// Core rewrite implementation.
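/// Unlike `splitReduction` above, this variant does not reshape the inputs
/// with `tensor.expand_shape`; instead it reindexes the reduction access as
/// k -> k * splitFactor + k' and threads an extra shape-only tensor through
/// the op to carry the bounds of the two new loops (see the worked example
/// in the function body).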
FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
    RewriterBase &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  // Matcher part, enforce preconditions.
  SplitReductionOptions control = controlSplitReductionFn(op);
  if (control.innerParallel)
    return b.notifyMatchFailure(op, "innerParallel not supported");

  int64_t splitFactor = control.ratio;
  unsigned insertSplitDimension = control.index;
  if (splitFactor <= 1)
    return b.notifyMatchFailure(op, "split factor needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  if (dims.empty())
    return b.notifyMatchFailure(op, "needs at least 1 reduction dimension");

  unsigned reductionDimPos = dims[0];
  SmallVector<int64_t> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDimPos];
  if (reductionDimSize == ShapedType::kDynamic ||
      reductionDimSize % splitFactor != 0 ||
      insertSplitDimension >= loopRanges.size())
    return b.notifyMatchFailure(
        op, "first reduction dimension not divisible by split factor");

  SmallVector<Operation *> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
    return b.notifyMatchFailure(op, "cannot match a reduction pattern");

  SmallVector<TypedAttr> neutralElements;
  for (Operation *reductionOp : combinerOps) {
    std::optional<TypedAttr> neutralElement =
        arith::getNeutralElement(reductionOp);
    if (!neutralElement.has_value())
      return b.notifyMatchFailure(op, "cannot find neutral element.");
    neutralElements.push_back(*neutralElement);
  }
  if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; }))
    return b.notifyMatchFailure(op, "unknown reduction neutral");

  // TODO: relax this when multi-reduction support is available.
  if (op.getNumDpsInits() != static_cast<int64_t>(neutralElements.size()))
    return b.notifyMatchFailure(op, "expect one reduction per output");

  // Rewrite part.
  // Step 1. Build the intermediate outputs filled with the proper
  // neutralElements. Such outputs are of the same shape with an extra dimension
  // inserted at `insertSplitDimension`.
  //
  // Consider a minimal example where `k` is reduced:
  //     O(i, j) += I(i, j, k)
  // Assume i=3, j=5, k=128, splitFactor=16 and insertSplitDimension=0.
  // The compute is rewritten as:
  //      a. O_i(kk, i, j) += I(i, j, 16 * k + kk)
  //      b. O(i, j) += O_i(kk, i, j)
  // The intermediate tensor O_i is of shape (128/16)x3x5 == 8x3x5.
  Location loc = op->getLoc();
  MLIRContext *context = op.getContext();
  // For now assume outputs are 1-1 with reduction neutralElements.
  // TODO: generalize when multi-reduction support is available.
  SmallVector<Value> newOutputs;
  newOutputs.reserve(op.getNumDpsInits());
  SmallVector<Operation *> emptyOrAllocTensorOps;
  SmallVector<linalg::FillOp> fillOps;
  fillOps.reserve(op.getNumDpsInits());
  for (auto it : llvm::zip(op.getDpsInitsMutable(), neutralElements)) {
    Value rankedTensor = std::get<0>(it).get();
    auto t = cast<RankedTensorType>(rankedTensor.getType());
    RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
        reductionDimSize / splitFactor, insertSplitDimension);
    SmallVector<Value> dims =
        tensor::createDynamicDimValues(b, loc, rankedTensor);
    Value emptyOrAllocTensor;
    if (useAlloc) {
      emptyOrAllocTensor =
          b.create<bufferization::AllocTensorOp>(loc, newT, dims);
    } else {
      emptyOrAllocTensor = b.create<tensor::EmptyOp>(loc, newT.getShape(),
                                                     t.getElementType(), dims);
    }
    Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it));
    fillOps.push_back(
        b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor));
    newOutputs.push_back(fillOps.back().getResult(0));
    emptyOrAllocTensorOps.push_back(emptyOrAllocTensor.getDefiningOp());
  }

  // Step 2. Reindex / expand indexing maps.
  // Reindex existing input indexings: k -> k * splitFactor + k'.
  SmallVector<AffineMap> newMaps;
  newMaps.reserve(op->getNumOperands() + 1);
  for (OpOperand *o : op.getDpsInputOperands())
    newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor));
  // Provision a new indexing for the shape-only tensor.
  auto nDims = op.getNumLoops() + 1;
  auto redDim = getAffineDimExpr(reductionDimPos, context);
  auto redDimP1 = getAffineDimExpr(reductionDimPos + 1, context);
  newMaps.push_back(AffineMap::get(nDims, 0, {redDim, redDimP1}, context));
  // Expand existing output indexings.
  // TODO: a subset of these may not reduce along reducePos and should be
  // reindexed: k -> k * splitFactor + k', when multi-reduction support is
  // available.
  for (OpOperand &o : op.getDpsInitsMutable())
    newMaps.push_back(insertParallelDim(op, o, reductionDimPos,
                                        reductionDimSize / splitFactor));

  // Step 3. Handle operands.
  // Compute the new input tensors.
  SmallVector<Value> newInputs = op.getDpsInputs();
  // Add a single shape-only tensor to carry the dimensions without resorting
  // to more complex inversions.
  newInputs.push_back(b.create<tensor::EmptyOp>(
      loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor},
      b.getIntegerType(1)));
  // Output tensors are already good to go.

  // Step 4. Create the new op matching the original op with an extra parallel
  // dimension.
  auto iteratorTypes = op.getIteratorTypesArray();
  iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
                       utils::IteratorType::parallel);
  GenericOp genericOp =
      b.create<GenericOp>(loc, ValueRange(newOutputs).getTypes(), newInputs,
                          newOutputs, newMaps, iteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
                       genericOp.getRegion().begin());
  genericOp.getRegion().front().insertArgument(reductionDimPos,
                                               b.getIntegerType(1), loc);

  // Step 5. Create new reduction ops that only reduce the newly added
  // dimensions from the previous op.
  // For now assume outputs are 1-1 with reduction ops.
  // TODO: a subset of these may not reduce in the first place and do not
  // require a new op, when multi-reduction support is available.
  // TODO: all results can be handled in a single GenericOp, when
  // multi-reduction support is available.
  SmallVector<LinalgOp> results;
  for (auto it :
       llvm::zip(genericOp->getResults(), op.getDpsInits(), combinerOps)) {
    Value reindexedOutput = std::get<0>(it);
    Value originalOutput = std::get<1>(it);
    auto originalOutputType = cast<RankedTensorType>(originalOutput.getType());
    Operation *combinerOp = std::get<2>(it);

    AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
    SmallVector<AffineMap> indexingMaps = {
        map, map.dropResult(insertSplitDimension)};
    SmallVector<utils::IteratorType> reductionIteratorTypes(
        originalOutputType.getRank() + 1, utils::IteratorType::parallel);
    reductionIteratorTypes[insertSplitDimension] =
        utils::IteratorType::reduction;

    // clang-format off
    auto reductionOp = b.create<GenericOp>(
        loc,
        originalOutputType,
        reindexedOutput,
        originalOutput,
        indexingMaps,
        reductionIteratorTypes,
        [combinerOp](OpBuilder &b, Location loc, ValueRange bbArgs) {
          Operation *clonedReductionOp = b.clone(*combinerOp);
          clonedReductionOp->setOperand(0, bbArgs[0]);
          clonedReductionOp->setOperand(1, bbArgs[1]);
          b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
        });
    // clang-format on

    results.push_back(reductionOp);
  }

  // TODO: extend when multi-reduction support is available.
  assert(fillOps.size() == results.size() && results.size() == 1);
  b.replaceOp(op, results.front()->getResults());
  return SplitReductionResult{emptyOrAllocTensorOps.front(), fillOps.front(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              results.front()};
}

namespace {

struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a generic pattern applied to all LinalgOp, controlled by
  /// `controlSplitReductionFn`.
  LinalgSplitReduction(MLIRContext *context,
                       ControlSplitReductionFn controlSplitReductionFn,
                       bool useAlloc = false, PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlSplitReductionFn(std::move(controlSplitReductionFn)),
        useAlloc(useAlloc) {}

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    return splitReduction(rewriter, op, controlSplitReductionFn, useAlloc);
  }

private:
  ControlSplitReductionFn controlSplitReductionFn;
  bool useAlloc;
};

} // namespace

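// Example use (a sketch; the control function and the options shown are
// illustrative and would normally be supplied by the caller):
//   RewritePatternSet patterns(ctx);
//   populateSplitReductionPattern(patterns, [](LinalgOp op) {
//     return SplitReductionOptions{/*ratio=*/4, /*index=*/0,
//                                  /*innerParallel=*/false};
//   });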
void linalg::populateSplitReductionPattern(
    RewritePatternSet &patterns,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  patterns.add<LinalgSplitReduction>(patterns.getContext(),
                                     controlSplitReductionFn, useAlloc);
}