//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file implements the linalg transformation that splits a reduction
// dimension into a parallel and a reduction dimension.
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include <optional> |
15 | #include <utility> |
16 | |
17 | #include "mlir/Analysis/SliceAnalysis.h" |
18 | #include "mlir/Dialect/Arith/IR/Arith.h" |
19 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
20 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
21 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
22 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
23 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
24 | #include "mlir/Dialect/Tensor/Utils/Utils.h" |
25 | #include "mlir/IR/PatternMatch.h" |
26 | |
27 | using namespace mlir; |
28 | using namespace mlir::linalg; |
29 | |
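/// Split the single reduction dimension of `op` into an outer parallel partial
/// reduction followed by a final reduction of the partial results. A minimal
/// sketch of the rewrite, assuming a f32 sum reduction of a tensor<32xf32>
/// into a tensor<f32> and a control function returning
/// {ratio = 4, index = 0, innerParallel = false} (the exact IR depends on the
/// op being matched):
///
///   %exp = tensor.expand_shape %in [[0, 1]]
///       : tensor<32xf32> into tensor<4x8xf32>
///   %fill = linalg.fill ins(%identity : f32)
///       outs(%empty : tensor<4xf32>) -> tensor<4xf32>
///   // Partial reductions, parallel over the new dimension d0.
///   %partial = linalg.generic
///       {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
///                         affine_map<(d0, d1) -> (d0)>],
///        iterator_types = ["parallel", "reduction"]}
///       ins(%exp : tensor<4x8xf32>) outs(%fill : tensor<4xf32>) {...}
///   // Final reduction of the partial results.
///   %res = linalg.generic
///       {indexing_maps = [affine_map<(d0) -> (d0)>,
///                         affine_map<(d0) -> ()>],
///        iterator_types = ["reduction"]}
///       ins(%partial : tensor<4xf32>) outs(%out : tensor<f32>) {...}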
30 | FailureOr<SplitReductionResult> mlir::linalg::splitReduction( |
31 | RewriterBase &b, LinalgOp op, |
32 | const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) { |
33 | OpBuilder::InsertionGuard guard(b); |
34 | b.setInsertionPoint(op); |
35 | |
36 | SplitReductionOptions control = controlSplitReductionFn(op); |
37 | int64_t ratio = control.ratio; |
38 | unsigned insertSplitIndex = control.index; |
39 | unsigned insertSplitDimension = control.index; |
40 | if (ratio <= 1) |
    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");
42 | |
43 | SmallVector<unsigned> dims; |
44 | op.getReductionDims(dims); |
45 | |
46 | if (dims.size() != 1) |
    return b.notifyMatchFailure(op, "needs a single reduction dimension");
48 | unsigned reductionDim = dims[0]; |
49 | if (control.innerParallel) { |
50 | insertSplitDimension = reductionDim + 1; |
51 | } |
52 | SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges(); |
53 | int64_t reductionDimSize = loopRanges[reductionDim]; |
54 | if (reductionDimSize == ShapedType::kDynamic || reductionDimSize % ratio != 0) |
    return b.notifyMatchFailure(
        op, "Reduction dimension not divisible by split ratio");
57 | if (op.getNumDpsInits() != 1) |
    return b.notifyMatchFailure(op, "More than one output in split reduction");
59 | if (insertSplitIndex > op.getShape(op.getDpsInitOperand(0)).size()) |
    return b.notifyMatchFailure(op, "Insert dimension position too large "
                                    "compared to intermediate tensor size");
62 | |
63 | SmallVector<Operation *, 4> combinerOps; |
64 | if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) || |
65 | combinerOps.size() != 1) |
    return b.notifyMatchFailure(op, "Cannot match the reduction pattern");
67 | |
68 | Operation *reductionOp = combinerOps[0]; |
69 | std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp); |
70 | if (!identity.has_value()) |
    return b.notifyMatchFailure(op, "Unknown identity value for the reduction");
72 | |
73 | Location loc = op->getLoc(); |
74 | SmallVector<Value> newInputs; |
75 | SmallVector<AffineMap> newMaps; |
76 | // Calculate the new shapes and indexing maps of the input operands. |
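  // For example (a sketch, assuming ratio=4, insertSplitDimension=0 and
  // reduction dimension d1): an input map (d0, d1) -> (d0, d1) on a
  // tensor<16x32xf32> becomes (d0, d1, d2) -> (d1, d0, d2) on the expanded
  // tensor<16x4x8xf32>, with d0 the newly inserted parallel dimension.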
77 | for (OpOperand *operand : op.getDpsInputOperands()) { |
78 | AffineMap map = op.getMatchingIndexingMap(operand); |
79 | SmallVector<int64_t> newShape; |
80 | SmallVector<AffineExpr> exprs; |
81 | SmallVector<ReassociationIndices> reassociation; |
82 | unsigned index = 0; |
83 | for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) { |
84 | unsigned dim = map.getDimPosition(idx); |
85 | if (reductionDim == dim) { |
86 | if (control.innerParallel) { |
87 | newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce |
88 | newShape.push_back(ratio); // parallel (insert) |
89 | exprs.push_back( |
90 | b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1)); |
91 | exprs.push_back(b.getAffineDimExpr(insertSplitDimension)); |
92 | } else { |
93 | newShape.push_back(ratio); // parallel (insert) |
94 | newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce |
95 | exprs.push_back(b.getAffineDimExpr(insertSplitDimension)); |
96 | exprs.push_back( |
97 | b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1)); |
98 | } |
99 | reassociation.push_back({index++, index++}); |
100 | continue; |
101 | } |
102 | newShape.push_back(op.getShape(operand)[idx]); |
103 | exprs.push_back( |
104 | b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1)); |
105 | reassociation.push_back({index++}); |
106 | } |
107 | newMaps.push_back( |
108 | AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext())); |
    // If the shape is unchanged, the input doesn't change.
110 | if (newShape == op.getShape(operand)) { |
111 | newInputs.push_back(operand->get()); |
112 | continue; |
113 | } |
114 | Type newType = RankedTensorType::get( |
115 | newShape, |
116 | cast<RankedTensorType>(operand->get().getType()).getElementType()); |
117 | Value newInput = b.create<tensor::ExpandShapeOp>( |
118 | loc, newType, operand->get(), reassociation); |
119 | newInputs.push_back(newInput); |
120 | } |
121 | |
  // Calculate the new output map and shape; the new dimension is inserted
  // based on the index returned by `controlSplitReductionFn`.
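  // Continuing the sketch above: an output map (d0, d1) -> (d0) on a
  // tensor<16xf32> becomes (d0, d1, d2) -> (d0, d1) on the intermediate
  // tensor<4x16xf32>.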
124 | SmallVector<int64_t> newOutputShape; |
125 | AffineMap oldOutputMap = op.getMatchingIndexingMap(op.getDpsInitOperand(0)); |
126 | ArrayRef<int64_t> oldShape = op.getShape(op.getDpsInitOperand(0)); |
127 | SmallVector<AffineExpr> outputExpr; |
128 | for (unsigned idx : llvm::seq<unsigned>(0, oldShape.size() + 1)) { |
129 | if (insertSplitIndex == idx) { |
130 | newOutputShape.push_back(ratio); |
131 | outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension)); |
132 | } |
133 | if (idx < oldShape.size()) { |
134 | newOutputShape.push_back(oldShape[idx]); |
135 | unsigned dim = oldOutputMap.getDimPosition(idx); |
136 | outputExpr.push_back( |
137 | b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1)); |
138 | } |
139 | } |
140 | Value emptyOrAllocTensor; |
141 | if (useAlloc) { |
142 | emptyOrAllocTensor = b.create<bufferization::AllocTensorOp>( |
143 | loc, |
144 | RankedTensorType::get(newOutputShape, |
145 | op.getRegionOutputArgs()[0].getType()), |
146 | ValueRange{}); |
147 | } else { |
148 | emptyOrAllocTensor = b.create<tensor::EmptyOp>( |
149 | loc, newOutputShape, op.getRegionOutputArgs()[0].getType()); |
150 | } |
151 | Value constantOp = b.create<arith::ConstantOp>(loc, *identity); |
152 | Value identityTensor = |
153 | b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor) |
154 | .getResult(0); |
155 | |
  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                   op.getContext()));
158 | SmallVector<utils::IteratorType> newIteratorTypes; |
159 | for (auto [index, iteratorType] : |
160 | llvm::enumerate(op.getIteratorTypesArray())) { |
161 | if (insertSplitDimension == index) |
162 | newIteratorTypes.push_back(utils::IteratorType::parallel); |
163 | newIteratorTypes.push_back(iteratorType); |
164 | } |
165 | if (insertSplitDimension == op.getIteratorTypesArray().size()) { |
166 | newIteratorTypes.push_back(utils::IteratorType::parallel); |
167 | } |
168 | // Create the new op matching the original op with an extra parallel |
169 | // dimension. |
170 | GenericOp genericOp = b.create<GenericOp>( |
171 | loc, TypeRange({emptyOrAllocTensor.getType()}), newInputs, |
172 | ValueRange({identityTensor}), newMaps, newIteratorTypes); |
173 | b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(), |
174 | genericOp.getRegion().begin()); |
175 | |
  // Then create a new reduction op that only reduces the newly added dimension
  // from the previous op.
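  // Continuing the sketch: the intermediate tensor<4x16xf32> is reduced back
  // to tensor<16xf32> with maps [(d0, d1) -> (d0, d1), (d0, d1) -> (d1)] and
  // iterator types ["reduction", "parallel"].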
178 | unsigned intermRank = newOutputShape.size(); |
  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
  SmallVector<utils::IteratorType> reductionIteratorTypes;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertSplitIndex == i) {
      reductionIteratorTypes.push_back(utils::IteratorType::reduction);
    } else {
      exprs.push_back(b.getAffineDimExpr(i));
187 | reductionIteratorTypes.push_back(utils::IteratorType::parallel); |
188 | } |
189 | } |
190 | AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext()); |
191 | SmallVector<AffineMap> reductionMaps = {inputMap, outputMap}; |
192 | |
193 | auto reduction = b.create<GenericOp>( |
194 | loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}), |
195 | op.getDpsInits(), reductionMaps, reductionIteratorTypes, |
196 | [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) { |
197 | Operation *clonedReductionOp = b.clone(*reductionOp); |
198 | clonedReductionOp->setOperand(0, inputs[0]); |
199 | clonedReductionOp->setOperand(1, inputs[1]); |
200 | b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0)); |
201 | }); |
202 | b.replaceOp(op, reduction.getResults()); |
203 | |
204 | return SplitReductionResult{emptyOrAllocTensor.getDefiningOp(), |
205 | identityTensor.getDefiningOp<FillOp>(), |
206 | cast<LinalgOp>(genericOp.getOperation()), |
207 | reduction}; |
208 | } |
209 | |
210 | /// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...) |
211 | /// TODO: Additional pattern to rewrite f(i, j, k * ratio + kk, ...) into |
212 | /// f(i, j, k, kk, ...) with a proper ExpandShapeOp. This is probably better |
213 | /// done as a transform to enable better vectorization. |
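/// E.g., for a map (d0, d1) -> (d0, d1) with reductionDimPos=1 and
/// reductionRatio=4, this returns (d0, d1, d2) -> (d0, d1 * 4 + d2).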
214 | static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand, |
215 | unsigned reductionDimPos, |
216 | int64_t reductionRatio) { |
217 | auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext()); |
218 | auto reductionDimP1 = getAffineDimExpr(reductionDimPos + 1, op.getContext()); |
219 | AffineMap map = op.getMatchingIndexingMap(&opOperand); |
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  AffineMap composeMap = shiftedIdMap.replace(
      reductionDim, reductionDim * reductionRatio + reductionDimP1,
      shiftedIdMap.getNumDims(), /*numSymbols=*/0);
  return map.compose(composeMap);
227 | } |
228 | |
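/// Insert a new parallel dimension at position `reductionDimPos` into the
/// indexing map of `opOperand`. E.g., for a map (d0, d1) -> (d0) with
/// reductionDimPos=1, this returns (d0, d1, d2) -> (d0, d1), exposing the new
/// parallel dimension d1 in the operand's indexing.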
229 | static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand, |
230 | unsigned reductionDimPos, int64_t size) { |
231 | auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext()); |
232 | AffineMap map = op.getMatchingIndexingMap(&opOperand); |
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  return map.compose(shiftedIdMap).insertResult(reductionDim, reductionDimPos);
237 | } |
238 | |
/// Core rewrite implementation of split reduction by scaling: instead of
/// reshaping the inputs, the reduction dimension k is reindexed as
/// k * splitFactor + kk and a new parallel dimension is exposed.
240 | FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling( |
241 | RewriterBase &b, LinalgOp op, |
242 | const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) { |
243 | OpBuilder::InsertionGuard guard(b); |
244 | b.setInsertionPoint(op); |
245 | |
246 | // Matcher part, enforce preconditions. |
247 | SplitReductionOptions control = controlSplitReductionFn(op); |
248 | if (control.innerParallel) |
    return b.notifyMatchFailure(op, "innerParallel not supported");
250 | |
251 | int64_t splitFactor = control.ratio; |
252 | unsigned insertSplitDimension = control.index; |
253 | if (splitFactor <= 1) |
    return b.notifyMatchFailure(op, "split factor needs to be greater than 1");
255 | |
256 | SmallVector<unsigned> dims; |
257 | op.getReductionDims(dims); |
258 | if (dims.empty()) |
    return b.notifyMatchFailure(op, "needs at least 1 reduction dimension");
260 | |
261 | unsigned reductionDimPos = dims[0]; |
262 | SmallVector<int64_t> loopRanges = op.getStaticLoopRanges(); |
263 | int64_t reductionDimSize = loopRanges[reductionDimPos]; |
264 | if (reductionDimSize == ShapedType::kDynamic || |
265 | reductionDimSize % splitFactor != 0 || |
266 | insertSplitDimension >= loopRanges.size()) |
    return b.notifyMatchFailure(
        op, "first reduction dimension not divisible by split factor");
269 | |
270 | SmallVector<Operation *> combinerOps; |
271 | if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps)) |
    return b.notifyMatchFailure(op, "cannot match a reduction pattern");
273 | |
274 | SmallVector<TypedAttr> neutralElements; |
275 | for (Operation *reductionOp : combinerOps) { |
276 | std::optional<TypedAttr> neutralElement = |
277 | arith::getNeutralElement(reductionOp); |
278 | if (!neutralElement.has_value()) |
      return b.notifyMatchFailure(op, "cannot find neutral element.");
280 | neutralElements.push_back(*neutralElement); |
281 | } |
282 | if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; })) |
    return b.notifyMatchFailure(op, "unknown reduction neutral");
284 | |
285 | // TODO: relax this when multi-reduction support is available. |
286 | if (op.getNumDpsInits() != static_cast<int64_t>(neutralElements.size())) |
    return b.notifyMatchFailure(op, "expect one reduction per output");
288 | |
289 | // Rewrite part. |
290 | // Step 1. Build the intermediate outputs filled with the proper |
291 | // neutralElements. Such outputs are of the same shape with an extra dimension |
292 | // inserted at `insertSplitDimension`. |
293 | // |
294 | // Consider a minimal example where `k` is reduced: |
295 | // O(i, j) += I(i, j, k) |
296 | // Assume i=3, j=5, k=128, splitFactor=16 and insertSplitDimension=0. |
297 | // The compute is rewritten as: |
298 | // a. O_i(kk, i, j) += I(i, j, 16 * k + kk) |
299 | // b. O(i, j) += O_i(kk, i, j) |
300 | // The intermediate tensor O_i is of shape (128/16)x3x5 == 8x3x5. |
301 | Location loc = op->getLoc(); |
302 | MLIRContext *context = op.getContext(); |
303 | // For now assume outputs are 1-1 with reduction neutralElements. |
304 | // TODO: generalize when multi-reduction support is available. |
305 | SmallVector<Value> newOutputs; |
  newOutputs.reserve(op.getNumDpsInits());
307 | SmallVector<Operation *> emptyOrAllocTensorOps; |
308 | SmallVector<linalg::FillOp> fillOps; |
309 | fillOps.reserve(op.getNumDpsInits()); |
310 | for (auto it : llvm::zip(op.getDpsInitsMutable(), neutralElements)) { |
311 | Value rankedTensor = std::get<0>(it).get(); |
312 | auto t = cast<RankedTensorType>(rankedTensor.getType()); |
313 | RankedTensorType newT = RankedTensorType::Builder(t).insertDim( |
314 | reductionDimSize / splitFactor, insertSplitDimension); |
315 | SmallVector<Value> dims = |
316 | tensor::createDynamicDimValues(b, loc, rankedTensor); |
317 | Value emptyOrAllocTensor; |
318 | if (useAlloc) { |
319 | emptyOrAllocTensor = |
320 | b.create<bufferization::AllocTensorOp>(loc, newT, dims); |
321 | } else { |
322 | emptyOrAllocTensor = b.create<tensor::EmptyOp>(loc, newT.getShape(), |
323 | t.getElementType(), dims); |
324 | } |
325 | Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it)); |
326 | fillOps.push_back( |
327 | b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor)); |
328 | newOutputs.push_back(fillOps.back().getResult(0)); |
329 | emptyOrAllocTensorOps.push_back(emptyOrAllocTensor.getDefiningOp()); |
330 | } |
331 | |
332 | // Step 2. Reindex / expand indexing maps. |
333 | // Reindex existing input indexings: k -> k * splitFactor + k'. |
334 | SmallVector<AffineMap> newMaps; |
  newMaps.reserve(op->getNumOperands() + 1);
336 | for (OpOperand *o : op.getDpsInputOperands()) |
337 | newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor)); |
338 | // Provision a new indexing for the shape-only tensor. |
339 | auto nDims = op.getNumLoops() + 1; |
  auto redDim = getAffineDimExpr(reductionDimPos, context);
  auto redDimP1 = getAffineDimExpr(reductionDimPos + 1, context);
  newMaps.push_back(AffineMap::get(nDims, 0, {redDim, redDimP1}, context));
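  // E.g., with three original loops and reductionDimPos=2, this map is
  // (d0, d1, d2, d3) -> (d2, d3); it indexes the shape-only
  // (size/factor) x factor tensor added in Step 3 below.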
343 | // Expand existing output indexings. |
344 | // TODO: a subset of these may not reduce along reducePos and should be |
345 | // reindexed: k -> k * splitFactor + k', when multi-reduction support is |
346 | // available. |
347 | for (OpOperand &o : op.getDpsInitsMutable()) |
348 | newMaps.push_back(insertParallelDim(op, o, reductionDimPos, |
349 | reductionDimSize / splitFactor)); |
350 | |
351 | // Step 3. Handle operands. |
352 | // Compute the new input tensors. |
353 | SmallVector<Value> newInputs = op.getDpsInputs(); |
354 | // Add a single shape-only tensor to carry the dimensions without resorting to |
355 | // more complex inversions. |
356 | newInputs.push_back(b.create<tensor::EmptyOp>( |
357 | loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor}, |
358 | b.getIntegerType(1))); |
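  // In the running example above this is a tensor<8x16xi1>; only its shape
  // matters, its i1 elements are never read.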
359 | // Output tensors are already good to go. |
360 | |
361 | // Step 4. Create the new op matching the original op with an extra parallel |
362 | // dimension. |
363 | auto iteratorTypes = op.getIteratorTypesArray(); |
364 | iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos, |
365 | utils::IteratorType::parallel); |
366 | GenericOp genericOp = |
367 | b.create<GenericOp>(loc, ValueRange(newOutputs).getTypes(), newInputs, |
368 | newOutputs, newMaps, iteratorTypes); |
369 | b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(), |
370 | genericOp.getRegion().begin()); |
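  // Add a (never used) block argument for the new shape-only input.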
371 | genericOp.getRegion().front().insertArgument(reductionDimPos, |
372 | b.getIntegerType(1), loc); |
373 | |
374 | // Step 5. Create new reduction ops that only reduce the newly added |
375 | // dimensions from the previous op. |
376 | // For now assume outputs are 1-1 with reduction ops. |
377 | // TODO: a subset of these may not reduce in the first place and do not |
378 | // require a new op, when multi-reduction support is available. |
379 | // TODO: all results can be handled in a single GenericOp, when |
380 | // multi-reduction support is available. |
381 | SmallVector<LinalgOp> results; |
382 | for (auto it : |
383 | llvm::zip(genericOp->getResults(), op.getDpsInits(), combinerOps)) { |
384 | Value reindexedOutput = std::get<0>(it); |
385 | Value originalOutput = std::get<1>(it); |
386 | auto originalOutputType = cast<RankedTensorType>(originalOutput.getType()); |
387 | Operation *combinerOp = std::get<2>(it); |
388 | |
389 | AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1); |
390 | SmallVector<AffineMap> indexingMaps = { |
391 | map, map.dropResult(insertSplitDimension)}; |
392 | SmallVector<utils::IteratorType> reductionIteratorTypes( |
393 | originalOutputType.getRank() + 1, utils::IteratorType::parallel); |
394 | reductionIteratorTypes[insertSplitDimension] = |
395 | utils::IteratorType::reduction; |
396 | |
397 | // clang-format off |
398 | auto reductionOp = b.create<GenericOp>( |
399 | loc, |
400 | originalOutputType, |
401 | reindexedOutput, |
402 | originalOutput, |
403 | indexingMaps, |
404 | reductionIteratorTypes, |
405 | [combinerOp](OpBuilder &b, Location loc, ValueRange bbArgs) { |
406 | Operation *clonedReductionOp = b.clone(*combinerOp); |
407 | clonedReductionOp->setOperand(0, bbArgs[0]); |
408 | clonedReductionOp->setOperand(1, bbArgs[1]); |
409 | b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0)); |
410 | }); |
411 | // clang-format on |
412 | |
413 | results.push_back(reductionOp); |
414 | } |
415 | |
416 | // TODO: extend when multi-reduction support is available. |
417 | assert(fillOps.size() == results.size() && results.size() == 1); |
418 | b.replaceOp(op, results.front()->getResults()); |
419 | return SplitReductionResult{emptyOrAllocTensorOps.front(), fillOps.front(), |
420 | cast<LinalgOp>(genericOp.getOperation()), |
421 | results.front()}; |
422 | } |
423 | |
424 | namespace { |
425 | |
426 | struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> { |
  /// Construct a generic pattern applied to all LinalgOp, controlled by
  /// `controlSplitReductionFn`.
428 | LinalgSplitReduction(MLIRContext *context, |
429 | ControlSplitReductionFn controlSplitReductionFn, |
430 | bool useAlloc = false, PatternBenefit benefit = 1) |
431 | : OpInterfaceRewritePattern<LinalgOp>(context, benefit), |
432 | controlSplitReductionFn(std::move(controlSplitReductionFn)), |
433 | useAlloc(useAlloc) {} |
434 | |
435 | LogicalResult matchAndRewrite(LinalgOp op, |
436 | PatternRewriter &rewriter) const override { |
437 | return splitReduction(rewriter, op, controlSplitReductionFn, useAlloc); |
438 | } |
439 | |
440 | private: |
441 | ControlSplitReductionFn controlSplitReductionFn; |
442 | bool useAlloc; |
443 | }; |
444 | |
445 | } // namespace |
446 | |
447 | void linalg::populateSplitReductionPattern( |
448 | RewritePatternSet &patterns, |
449 | const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) { |
  patterns.add<LinalgSplitReduction>(patterns.getContext(),
                                     controlSplitReductionFn, useAlloc);
452 | } |
453 | |