//===- AsyncParallelFor.cpp - Implementation of Async Parallel For --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the scf.parallel to scf.for + async.execute conversion
// pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Async/Passes.h"

#include "PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_ASYNCPARALLELFOR
#include "mlir/Dialect/Async/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::async;

#define DEBUG_TYPE "async-parallel-for"

namespace {

// Rewrite scf.parallel operation into multiple concurrent async.execute
// operations over non-overlapping subranges of the original loop.
//
// Example:
//
//   scf.parallel (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) {
//     "do_some_compute"(%i, %j): () -> ()
//   }
//
// Converted to:
//
//   // Parallel compute function that executes the parallel body region for
//   // a subset of the parallel iteration space defined by the one-dimensional
//   // compute block index.
//   func parallel_compute_function(%block_index : index, %block_size : index,
//                                  <parallel operation properties>, ...) {
//     // Compute multi-dimensional loop bounds for %block_index.
//     %block_lbi, %block_lbj = ...
//     %block_ubi, %block_ubj = ...
//
//     // Clone parallel operation body into the scf.for loop nest.
//     scf.for %i = %block_lbi to %block_ubi {
//       scf.for %j = %block_lbj to %block_ubj {
//         "do_some_compute"(%i, %j): () -> ()
//       }
//     }
//   }
//
// And a dispatch function depending on the `asyncDispatch` option.
//
// When async dispatch is on (pseudocode):
//
//   %block_size = ... compute parallel compute block size
//   %block_count = ... compute the number of compute blocks
//
//   func @async_dispatch(%block_start : index, %block_end : index, ...) {
//     // Keep splitting block range until we reach a range of size 1.
//     while (%block_end - %block_start > 1) {
//       %mid_index = %block_start + (%block_end - %block_start) / 2;
//       async.execute { call @async_dispatch(%mid_index, %block_end); }
//       %block_end = %mid_index
//     }
//
//     // Call parallel compute function for a single block.
//     call @parallel_compute_fn(%block_start, %block_size, ...);
//   }
//
//   // Launch async dispatch for the [0, %block_count) range.
//   call @async_dispatch(%c0, %block_count);
//
// When async dispatch is off:
//
//   %block_size = ... compute parallel compute block size
//   %block_count = ... compute the number of compute blocks
//
//   scf.for %block_index = %c0 to %block_count {
//     call @parallel_compute_fn(%block_index, %block_size, ...)
//   }
//
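// As a usage sketch (the pass option names below are assumed to match the
// tablegen pass definition in Passes.td and are shown for illustration only),
// the conversion can be exercised from the command line as:
//
//   mlir-opt input.mlir \
//     --async-parallel-for="async-dispatch=true num-workers=8 min-task-size=1000"
//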
struct AsyncParallelForPass
    : public impl::AsyncParallelForBase<AsyncParallelForPass> {
  AsyncParallelForPass() = default;

  AsyncParallelForPass(bool asyncDispatch, int32_t numWorkerThreads,
                       int32_t minTaskSize) {
    this->asyncDispatch = asyncDispatch;
    this->numWorkerThreads = numWorkerThreads;
    this->minTaskSize = minTaskSize;
  }

  void runOnOperation() override;
};

struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
public:
  AsyncParallelForRewrite(
      MLIRContext *ctx, bool asyncDispatch, int32_t numWorkerThreads,
      AsyncMinTaskSizeComputationFunction computeMinTaskSize)
      : OpRewritePattern(ctx), asyncDispatch(asyncDispatch),
        numWorkerThreads(numWorkerThreads),
        computeMinTaskSize(std::move(computeMinTaskSize)) {}

  LogicalResult matchAndRewrite(scf::ParallelOp op,
                                PatternRewriter &rewriter) const override;

private:
  bool asyncDispatch;
  int32_t numWorkerThreads;
  AsyncMinTaskSizeComputationFunction computeMinTaskSize;
};

struct ParallelComputeFunctionType {
  FunctionType type;
  SmallVector<Value> captures;
};

// Helper struct to parse parallel compute function argument list.
struct ParallelComputeFunctionArgs {
  BlockArgument blockIndex();
  BlockArgument blockSize();
  ArrayRef<BlockArgument> tripCounts();
  ArrayRef<BlockArgument> lowerBounds();
  ArrayRef<BlockArgument> upperBounds();
  ArrayRef<BlockArgument> steps();
  ArrayRef<BlockArgument> captures();

  unsigned numLoops;
  ArrayRef<BlockArgument> args;
};

struct ParallelComputeFunctionBounds {
  SmallVector<IntegerAttr> tripCounts;
  SmallVector<IntegerAttr> lowerBounds;
  SmallVector<IntegerAttr> upperBounds;
  SmallVector<IntegerAttr> steps;
};

struct ParallelComputeFunction {
  unsigned numLoops;
  func::FuncOp func;
  llvm::SmallVector<Value> captures;
};

} // namespace

BlockArgument ParallelComputeFunctionArgs::blockIndex() { return args[0]; }
BlockArgument ParallelComputeFunctionArgs::blockSize() { return args[1]; }

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::tripCounts() {
  return args.drop_front(2).take_front(numLoops);
}

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::lowerBounds() {
  return args.drop_front(2 + 1 * numLoops).take_front(numLoops);
}

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::upperBounds() {
  return args.drop_front(2 + 2 * numLoops).take_front(numLoops);
}

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::steps() {
  return args.drop_front(2 + 3 * numLoops).take_front(numLoops);
}

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::captures() {
  return args.drop_front(2 + 4 * numLoops);
}

template <typename ValueRange>
static SmallVector<IntegerAttr> integerConstants(ValueRange values) {
  SmallVector<IntegerAttr> attrs(values.size());
  for (unsigned i = 0; i < values.size(); ++i)
    matchPattern(values[i], m_Constant(&attrs[i]));
  return attrs;
}

// Converts a one-dimensional iteration index in the [0, tripCount) interval
// into a multidimensional iteration coordinate.
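//
// For example (illustrative values): with tripCounts = [5, 3] the linear
// index 11 maps to coords = [3, 2], because the innermost coordinate is
// 11 % 3 = 2, and the remaining index 11 / 3 = 3 becomes the outer coordinate.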
static SmallVector<Value> delinearize(ImplicitLocOpBuilder &b, Value index,
                                      ArrayRef<Value> tripCounts) {
  assert(!tripCounts.empty() && "tripCounts must not be empty");
  SmallVector<Value> coords(tripCounts.size());

  for (ssize_t i = tripCounts.size() - 1; i >= 0; --i) {
    coords[i] = b.create<arith::RemSIOp>(index, tripCounts[i]);
    index = b.create<arith::DivSIOp>(index, tripCounts[i]);
  }

  return coords;
}

// Returns a function type and implicit captures for a parallel compute
// function. We'll need the list of implicit captures to set up the block and
// value mapping when we clone the body of the parallel operation.
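//
// For example (hypothetical shapes, for illustration only), a two-loop
// scf.parallel that implicitly captures one memref<?xf32> value gets the
// compute function type:
//
//   (index, index,   // blockIndex, blockSize
//    index, index,   // trip counts
//    index, index,   // lower bounds
//    index, index,   // upper bounds
//    index, index,   // steps
//    memref<?xf32>   // implicit capture
//   ) -> ()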
static ParallelComputeFunctionType
getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter) {
  // Values implicitly captured by the parallel operation.
  llvm::SetVector<Value> captures;
  getUsedValuesDefinedAbove(op.getRegion(), op.getRegion(), captures);

  SmallVector<Type> inputs;
  inputs.reserve(2 + 4 * op.getNumLoops() + captures.size());

  Type indexTy = rewriter.getIndexType();

  // One-dimensional iteration space defined by the block index and size.
  inputs.push_back(indexTy); // blockIndex
  inputs.push_back(indexTy); // blockSize

  // Multi-dimensional parallel iteration space defined by the loop trip
  // counts.
  for (unsigned i = 0; i < op.getNumLoops(); ++i)
    inputs.push_back(indexTy); // loop tripCount

  // Parallel operation lower bound, upper bound and step. Lower bound, upper
  // bound and step passed as contiguous arguments:
  //   call @compute(%lb0, %lb1, ..., %ub0, %ub1, ..., %step0, %step1, ...)
  // Note: all of these operands are of `index` type, so pushing three types
  // per loop below produces exactly the contiguous-layout function type
  // described above.
  for (unsigned i = 0; i < op.getNumLoops(); ++i) {
    inputs.push_back(indexTy); // lower bound
    inputs.push_back(indexTy); // upper bound
    inputs.push_back(indexTy); // step
  }

  // Types of the implicit captures.
  for (Value capture : captures)
    inputs.push_back(capture.getType());

  // Convert captures to a vector for later convenience.
  SmallVector<Value> capturesVector(captures.begin(), captures.end());
  return {rewriter.getFunctionType(inputs, TypeRange()), capturesVector};
}

// Create a parallel compute function from the parallel operation.
static ParallelComputeFunction createParallelComputeFunction(
    scf::ParallelOp op, const ParallelComputeFunctionBounds &bounds,
    unsigned numBlockAlignedInnerLoops, PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard guard(rewriter);
  ImplicitLocOpBuilder b(op.getLoc(), rewriter);

  ModuleOp module = op->getParentOfType<ModuleOp>();

  ParallelComputeFunctionType computeFuncType =
      getParallelComputeFunctionType(op, rewriter);

  FunctionType type = computeFuncType.type;
  func::FuncOp func = func::FuncOp::create(
      op.getLoc(),
      numBlockAlignedInnerLoops > 0 ? "parallel_compute_fn_with_aligned_loops"
                                    : "parallel_compute_fn",
      type);
  func.setPrivate();

  // Insert function into the module symbol table and assign it a unique name.
  SymbolTable symbolTable(module);
  symbolTable.insert(func);
  rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});

  // Create function entry block.
  Block *block =
      b.createBlock(&func.getBody(), func.begin(), type.getInputs(),
                    SmallVector<Location>(type.getNumInputs(), op.getLoc()));
  b.setInsertionPointToEnd(block);

  ParallelComputeFunctionArgs args = {op.getNumLoops(), func.getArguments()};

  // Block iteration position defined by the block index and size.
  BlockArgument blockIndex = args.blockIndex();
  BlockArgument blockSize = args.blockSize();

  // Constants used below.
  Value c0 = b.create<arith::ConstantIndexOp>(0);
  Value c1 = b.create<arith::ConstantIndexOp>(1);

  // Materialize known constants as constant operations in the function body.
  auto values = [&](ArrayRef<BlockArgument> args,
                    ArrayRef<IntegerAttr> attrs) {
    return llvm::to_vector(
        llvm::map_range(llvm::zip(args, attrs), [&](auto tuple) -> Value {
          if (IntegerAttr attr = std::get<1>(tuple))
            return b.create<arith::ConstantOp>(attr);
          return std::get<0>(tuple);
        }));
  };

  // Multi-dimensional parallel iteration space defined by the loop trip
  // counts.
  auto tripCounts = values(args.tripCounts(), bounds.tripCounts);

  // Parallel operation lower bound and step.
  auto lowerBounds = values(args.lowerBounds(), bounds.lowerBounds);
  auto steps = values(args.steps(), bounds.steps);

  // Remaining arguments are implicit captures of the parallel operation.
  ArrayRef<BlockArgument> captures = args.captures();

  // Compute a product of trip counts to get the size of the flattened
  // one-dimensional iteration space.
  Value tripCount = tripCounts[0];
  for (unsigned i = 1; i < tripCounts.size(); ++i)
    tripCount = b.create<arith::MulIOp>(tripCount, tripCounts[i]);

  // Find one-dimensional iteration bounds [blockFirstIndex, blockLastIndex]:
  //   blockFirstIndex = blockIndex * blockSize
  Value blockFirstIndex = b.create<arith::MulIOp>(blockIndex, blockSize);

  // The last one-dimensional index in the block defined by the `blockIndex`:
  //   blockLastIndex = min(blockFirstIndex + blockSize, tripCount) - 1
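  //
  // For example (illustrative values): tripCount = 10, blockSize = 4 and
  // blockIndex = 2 give blockFirstIndex = 8 and
  // blockLastIndex = min(8 + 4, 10) - 1 = 9, i.e. the partial trailing block.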
  Value blockEnd0 = b.create<arith::AddIOp>(blockFirstIndex, blockSize);
  Value blockEnd1 = b.create<arith::MinSIOp>(blockEnd0, tripCount);
  Value blockLastIndex = b.create<arith::SubIOp>(blockEnd1, c1);

  // Convert one-dimensional indices to multi-dimensional coordinates.
  auto blockFirstCoord = delinearize(b, blockFirstIndex, tripCounts);
  auto blockLastCoord = delinearize(b, blockLastIndex, tripCounts);

  // Compute loop upper bounds derived from the block last coordinates:
  //   blockEndCoord[i] = blockLastCoord[i] + 1
  //
  // Block first and last coordinates can be the same along the outer compute
  // dimension when the inner compute dimension contains multiple blocks.
  SmallVector<Value> blockEndCoord(op.getNumLoops());
  for (size_t i = 0; i < blockLastCoord.size(); ++i)
    blockEndCoord[i] = b.create<arith::AddIOp>(blockLastCoord[i], c1);

  // Construct a loop nest out of scf.for operations that will iterate over
  // all coordinates in the [blockFirstCoord, blockLastCoord] range.
  using LoopBodyBuilder =
      std::function<void(OpBuilder &, Location, Value, ValueRange)>;
  using LoopNestBuilder = std::function<LoopBodyBuilder(size_t loopIdx)>;

  // Parallel region induction variables computed from the multi-dimensional
  // iteration coordinate using parallel operation bounds and step:
  //
  //   computeBlockInductionVars[loopIdx] =
  //       lowerBound[loopIdx] + blockCoord[loopIdx] * step[loopIdx]
  SmallVector<Value> computeBlockInductionVars(op.getNumLoops());

  // We need to know if we are in the first or last iteration of the
  // multi-dimensional loop for each loop in the nest, so we can decide which
  // loop bounds we should use for the nested loops: bounds defined by the
  // compute block interval, or bounds defined by the parallel operation.
  //
  // Example: 2d parallel operation
  //                  i   j
  //   loop sizes:   [50, 50]
  //   first coord:  [25, 25]
  //   last coord:   [30, 30]
  //
  // If `i` is equal to 25 then iteration over `j` should start at 25; when `i`
  // is between 25 and 30 it should start at 0. The upper bound for `j` should
  // be 50, except when `i` is equal to 30, then it should also be 30.
  //
  // The value at the ith position specifies if all loops in the [0, i) range
  // of the loop nest are in their first/last iteration.
  SmallVector<Value> isBlockFirstCoord(op.getNumLoops());
  SmallVector<Value> isBlockLastCoord(op.getNumLoops());

  // Builds the inner loop nest inside the async.execute operation that does
  // all the work concurrently.
  LoopNestBuilder workLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
    return [&, loopIdx](OpBuilder &nestedBuilder, Location loc, Value iv,
                        ValueRange args) {
      ImplicitLocOpBuilder b(loc, nestedBuilder);

      // Compute the induction variable for `loopIdx`.
      computeBlockInductionVars[loopIdx] = b.create<arith::AddIOp>(
          lowerBounds[loopIdx], b.create<arith::MulIOp>(iv, steps[loopIdx]));

      // Check if we are inside the first or last iteration of the loop.
      isBlockFirstCoord[loopIdx] = b.create<arith::CmpIOp>(
          arith::CmpIPredicate::eq, iv, blockFirstCoord[loopIdx]);
      isBlockLastCoord[loopIdx] = b.create<arith::CmpIOp>(
          arith::CmpIPredicate::eq, iv, blockLastCoord[loopIdx]);

      // Check if the previous loop is in its first or last iteration.
      if (loopIdx > 0) {
        isBlockFirstCoord[loopIdx] = b.create<arith::AndIOp>(
            isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]);
        isBlockLastCoord[loopIdx] = b.create<arith::AndIOp>(
            isBlockLastCoord[loopIdx], isBlockLastCoord[loopIdx - 1]);
      }

      // Keep building the loop nest.
      if (loopIdx < op.getNumLoops() - 1) {
        if (loopIdx + 1 >= op.getNumLoops() - numBlockAlignedInnerLoops) {
          // For block aligned loops we always iterate starting from 0 up to
          // the loop trip counts.
          b.create<scf::ForOp>(c0, tripCounts[loopIdx + 1], c1, ValueRange(),
                               workLoopBuilder(loopIdx + 1));

        } else {
          // Select nested loop lower/upper bounds depending on our position in
          // the multi-dimensional iteration space.
          auto lb = b.create<arith::SelectOp>(isBlockFirstCoord[loopIdx],
                                              blockFirstCoord[loopIdx + 1], c0);

          auto ub = b.create<arith::SelectOp>(isBlockLastCoord[loopIdx],
                                              blockEndCoord[loopIdx + 1],
                                              tripCounts[loopIdx + 1]);

          b.create<scf::ForOp>(lb, ub, c1, ValueRange(),
                               workLoopBuilder(loopIdx + 1));
        }

        b.create<scf::YieldOp>();
        return;
      }

      // Copy the body of the parallel op into the inner-most loop.
      IRMapping mapping;
      mapping.map(op.getInductionVars(), computeBlockInductionVars);
      mapping.map(computeFuncType.captures, captures);

      for (auto &bodyOp : op.getRegion().front().without_terminator())
        b.clone(bodyOp, mapping);
      b.create<scf::YieldOp>();
    };
  };

  b.create<scf::ForOp>(blockFirstCoord[0], blockEndCoord[0], c1, ValueRange(),
                       workLoopBuilder(0));
  b.create<func::ReturnOp>(ValueRange());

  return {op.getNumLoops(), func, std::move(computeFuncType.captures)};
}

// Creates a recursive async dispatch function for the given parallel compute
// function. The dispatch function keeps splitting the block range into halves
// until it reaches a single block, and then executes it inline.
//
// Function pseudocode (mix of C++ and MLIR):
//
//   func @async_dispatch(%block_start : index, %block_end : index, ...) {
//
//     // Keep splitting block range until we reach a range of size 1.
//     while (%block_end - %block_start > 1) {
//       %mid_index = %block_start + (%block_end - %block_start) / 2;
//       async.execute { call @async_dispatch(%mid_index, %block_end); }
//       %block_end = %mid_index
//     }
//
//     // Call parallel compute function for a single block.
//     call @parallel_compute_fn(%block_start, %block_size, ...);
//   }
//
463 | // |
464 | static func::FuncOp |
465 | createAsyncDispatchFunction(ParallelComputeFunction &computeFunc, |
466 | PatternRewriter &rewriter) { |
467 | OpBuilder::InsertionGuard guard(rewriter); |
468 | Location loc = computeFunc.func.getLoc(); |
469 | ImplicitLocOpBuilder b(loc, rewriter); |
470 | |
471 | ModuleOp module = computeFunc.func->getParentOfType<ModuleOp>(); |
472 | |
473 | ArrayRef<Type> computeFuncInputTypes = |
474 | computeFunc.func.getFunctionType().getInputs(); |
475 | |
476 | // Compared to the parallel compute function async dispatch function takes |
477 | // additional !async.group argument. Also instead of a single `blockIndex` it |
478 | // takes `blockStart` and `blockEnd` arguments to define the range of |
479 | // dispatched blocks. |
480 | SmallVector<Type> inputTypes; |
481 | inputTypes.push_back(async::GroupType::get(rewriter.getContext())); |
482 | inputTypes.push_back(rewriter.getIndexType()); // add blockStart argument |
483 | inputTypes.append(in_start: computeFuncInputTypes.begin(), in_end: computeFuncInputTypes.end()); |
484 | |
485 | FunctionType type = rewriter.getFunctionType(inputTypes, TypeRange()); |
486 | func::FuncOp func = func::FuncOp::create(loc, "async_dispatch_fn" , type); |
487 | func.setPrivate(); |
488 | |
489 | // Insert function into the module symbol table and assign it unique name. |
490 | SymbolTable symbolTable(module); |
491 | symbolTable.insert(symbol: func); |
492 | rewriter.getListener()->notifyOperationInserted(op: func, /*previous=*/{}); |
493 | |
494 | // Create function entry block. |
495 | Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs(), |
496 | SmallVector<Location>(type.getNumInputs(), loc)); |
497 | b.setInsertionPointToEnd(block); |
498 | |
499 | Type indexTy = b.getIndexType(); |
500 | Value c1 = b.create<arith::ConstantIndexOp>(args: 1); |
501 | Value c2 = b.create<arith::ConstantIndexOp>(args: 2); |
502 | |
503 | // Get the async group that will track async dispatch completion. |
504 | Value group = block->getArgument(i: 0); |
505 | |
506 | // Get the block iteration range: [blockStart, blockEnd) |
507 | Value blockStart = block->getArgument(i: 1); |
508 | Value blockEnd = block->getArgument(i: 2); |
509 | |
510 | // Create a work splitting while loop for the [blockStart, blockEnd) range. |
511 | SmallVector<Type> types = {indexTy, indexTy}; |
512 | SmallVector<Value> operands = {blockStart, blockEnd}; |
513 | SmallVector<Location> locations = {loc, loc}; |
514 | |
515 | // Create a recursive dispatch loop. |
516 | scf::WhileOp whileOp = b.create<scf::WhileOp>(types, operands); |
517 | Block *before = b.createBlock(&whileOp.getBefore(), {}, types, locations); |
518 | Block *after = b.createBlock(&whileOp.getAfter(), {}, types, locations); |
519 | |
520 | // Setup dispatch loop condition block: decide if we need to go into the |
521 | // `after` block and launch one more async dispatch. |
522 | { |
523 | b.setInsertionPointToEnd(before); |
524 | Value start = before->getArgument(i: 0); |
525 | Value end = before->getArgument(i: 1); |
526 | Value distance = b.create<arith::SubIOp>(end, start); |
527 | Value dispatch = |
528 | b.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, distance, c1); |
529 | b.create<scf::ConditionOp>(dispatch, before->getArguments()); |
530 | } |
531 | |
532 | // Setup the async dispatch loop body: recursively call dispatch function |
533 | // for the seconds half of the original range and go to the next iteration. |
534 | { |
535 | b.setInsertionPointToEnd(after); |
536 | Value start = after->getArgument(i: 0); |
537 | Value end = after->getArgument(i: 1); |
538 | Value distance = b.create<arith::SubIOp>(end, start); |
539 | Value halfDistance = b.create<arith::DivSIOp>(distance, c2); |
540 | Value midIndex = b.create<arith::AddIOp>(start, halfDistance); |
541 | |
542 | // Call parallel compute function inside the async.execute region. |
543 | auto executeBodyBuilder = [&](OpBuilder &executeBuilder, |
544 | Location executeLoc, ValueRange executeArgs) { |
545 | // Update the original `blockStart` and `blockEnd` with new range. |
546 | SmallVector<Value> operands{block->getArguments().begin(), |
547 | block->getArguments().end()}; |
548 | operands[1] = midIndex; |
549 | operands[2] = end; |
550 | |
551 | executeBuilder.create<func::CallOp>(executeLoc, func.getSymName(), |
552 | func.getResultTypes(), operands); |
553 | executeBuilder.create<async::YieldOp>(executeLoc, ValueRange()); |
554 | }; |
555 | |
556 | // Create async.execute operation to dispatch half of the block range. |
557 | auto execute = b.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(), |
558 | executeBodyBuilder); |
559 | b.create<AddToGroupOp>(indexTy, execute.getToken(), group); |
560 | b.create<scf::YieldOp>(ValueRange({start, midIndex})); |
561 | } |
562 | |
563 | // After dispatching async operations to process the tail of the block range |
564 | // call the parallel compute function for the first block of the range. |
565 | b.setInsertionPointAfter(whileOp); |
566 | |
567 | // Drop async dispatch specific arguments: async group, block start and end. |
568 | auto forwardedInputs = block->getArguments().drop_front(N: 3); |
569 | SmallVector<Value> computeFuncOperands = {blockStart}; |
570 | computeFuncOperands.append(forwardedInputs.begin(), forwardedInputs.end()); |
571 | |
572 | b.create<func::CallOp>(computeFunc.func.getSymName(), |
573 | computeFunc.func.getResultTypes(), |
574 | computeFuncOperands); |
575 | b.create<func::ReturnOp>(ValueRange()); |
576 | |
577 | return func; |
578 | } |

// Launch async dispatch of the parallel compute function.
static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
                            ParallelComputeFunction &parallelComputeFunction,
                            scf::ParallelOp op, Value blockSize,
                            Value blockCount,
                            const SmallVector<Value> &tripCounts) {
  MLIRContext *ctx = op->getContext();

  // Add one more level of indirection to dispatch parallel compute functions
  // using async operations and recursive work splitting.
  func::FuncOp asyncDispatchFunction =
      createAsyncDispatchFunction(parallelComputeFunction, rewriter);

  Value c0 = b.create<arith::ConstantIndexOp>(0);
  Value c1 = b.create<arith::ConstantIndexOp>(1);

  // Appends operands shared by the async dispatch and parallel compute
  // functions to the given operands vector.
  auto appendBlockComputeOperands = [&](SmallVector<Value> &operands) {
    operands.append(tripCounts);
    operands.append(op.getLowerBound().begin(), op.getLowerBound().end());
    operands.append(op.getUpperBound().begin(), op.getUpperBound().end());
    operands.append(op.getStep().begin(), op.getStep().end());
    operands.append(parallelComputeFunction.captures);
  };

  // Check if the block count is one: in this case we can skip the async
  // dispatch completely. If this is known statically, then canonicalization
  // will erase the async group operations.
  Value isSingleBlock =
      b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, blockCount, c1);

  auto syncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
    ImplicitLocOpBuilder b(loc, nestedBuilder);

    // Call the parallel compute function for the single block.
    SmallVector<Value> operands = {c0, blockSize};
    appendBlockComputeOperands(operands);

    b.create<func::CallOp>(parallelComputeFunction.func.getSymName(),
                           parallelComputeFunction.func.getResultTypes(),
                           operands);
    b.create<scf::YieldOp>();
  };

  auto asyncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
    ImplicitLocOpBuilder b(loc, nestedBuilder);

    // Create an async.group to wait on all async tokens from the concurrent
    // execution of multiple parallel compute functions. The first block will
    // be executed synchronously in the caller thread.
    Value groupSize = b.create<arith::SubIOp>(blockCount, c1);
    Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);

    // Launch the async dispatch function for the [0, blockCount) range.
    SmallVector<Value> operands = {group, c0, blockCount, blockSize};
    appendBlockComputeOperands(operands);

    b.create<func::CallOp>(asyncDispatchFunction.getSymName(),
                           asyncDispatchFunction.getResultTypes(), operands);

    // Wait for the completion of all parallel compute operations.
    b.create<AwaitAllOp>(group);

    b.create<scf::YieldOp>();
  };

  // Dispatch either the single block compute function, or launch the async
  // dispatch.
  b.create<scf::IfOp>(isSingleBlock, syncDispatch, asyncDispatch);
}

// Dispatch parallel compute functions by submitting all async compute tasks
// from a simple for loop in the caller thread.
static void
doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
                     ParallelComputeFunction &parallelComputeFunction,
                     scf::ParallelOp op, Value blockSize, Value blockCount,
                     const SmallVector<Value> &tripCounts) {
  MLIRContext *ctx = op->getContext();

  func::FuncOp compute = parallelComputeFunction.func;

  Value c0 = b.create<arith::ConstantIndexOp>(0);
  Value c1 = b.create<arith::ConstantIndexOp>(1);

  // Create an async.group to wait on all async tokens from the concurrent
  // execution of multiple parallel compute functions. The first block will be
  // executed synchronously in the caller thread.
  Value groupSize = b.create<arith::SubIOp>(blockCount, c1);
  Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);

  // Call the parallel compute function for all blocks.
  using LoopBodyBuilder =
      std::function<void(OpBuilder &, Location, Value, ValueRange)>;

  // Returns parallel compute function operands to process the given block.
  auto computeFuncOperands = [&](Value blockIndex) -> SmallVector<Value> {
    SmallVector<Value> computeFuncOperands = {blockIndex, blockSize};
    computeFuncOperands.append(tripCounts);
    computeFuncOperands.append(op.getLowerBound().begin(),
                               op.getLowerBound().end());
    computeFuncOperands.append(op.getUpperBound().begin(),
                               op.getUpperBound().end());
    computeFuncOperands.append(op.getStep().begin(), op.getStep().end());
    computeFuncOperands.append(parallelComputeFunction.captures);
    return computeFuncOperands;
  };

  // Induction variable is the index of the block: [0, blockCount).
  LoopBodyBuilder loopBuilder = [&](OpBuilder &loopBuilder, Location loc,
                                    Value iv, ValueRange args) {
    ImplicitLocOpBuilder b(loc, loopBuilder);

    // Call the parallel compute function inside the async.execute region.
    auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
                                  Location executeLoc, ValueRange executeArgs) {
      executeBuilder.create<func::CallOp>(executeLoc, compute.getSymName(),
                                          compute.getResultTypes(),
                                          computeFuncOperands(iv));
      executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
    };

    // Create an async.execute operation to launch the parallel compute
    // function.
    auto execute = b.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
                                       executeBodyBuilder);
    b.create<AddToGroupOp>(rewriter.getIndexType(), execute.getToken(), group);
    b.create<scf::YieldOp>();
  };

  // Iterate over all compute blocks but the first one, and launch parallel
  // compute operations.
  b.create<scf::ForOp>(c1, blockCount, c1, ValueRange(), loopBuilder);

  // Call the parallel compute function for the first block in the caller
  // thread.
  b.create<func::CallOp>(compute.getSymName(), compute.getResultTypes(),
                         computeFuncOperands(c0));

  // Wait for the completion of all async compute operations.
  b.create<AwaitAllOp>(group);
}

LogicalResult
AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
                                         PatternRewriter &rewriter) const {
  // We do not currently support the rewrite for parallel ops with reductions.
  if (op.getNumReductions() != 0)
    return failure();

  ImplicitLocOpBuilder b(op.getLoc(), rewriter);

  // Computing minTaskSize emits IR and can be implemented as executing a cost
  // model on the body of the scf.parallel. Thus it needs to be computed before
  // the body of the scf.parallel has been manipulated.
  Value minTaskSize = computeMinTaskSize(b, op);

  // Make sure that all constants will be inside the parallel operation body to
  // reduce the number of parallel compute function arguments.
  cloneConstantsIntoTheRegion(op.getRegion(), rewriter);

  // Compute the trip count for each loop induction variable:
  //   tripCount = ceil_div(upperBound - lowerBound, step)
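  //
  // For example (illustrative values): lowerBound = 0, upperBound = 10 and
  // step = 3 give tripCount = ceil_div(10 - 0, 3) = 4, covering the induction
  // variable values 0, 3, 6 and 9.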
  SmallVector<Value> tripCounts(op.getNumLoops());
  for (size_t i = 0; i < op.getNumLoops(); ++i) {
    auto lb = op.getLowerBound()[i];
    auto ub = op.getUpperBound()[i];
    auto step = op.getStep()[i];
    auto range = b.createOrFold<arith::SubIOp>(ub, lb);
    tripCounts[i] = b.createOrFold<arith::CeilDivSIOp>(range, step);
  }

  // Compute a product of trip counts to get the 1-dimensional iteration space
  // for the scf.parallel operation.
  Value tripCount = tripCounts[0];
  for (size_t i = 1; i < tripCounts.size(); ++i)
    tripCount = b.create<arith::MulIOp>(tripCount, tripCounts[i]);

  // Short-circuit no-op parallel loops (zero iterations) that can arise from
  // memrefs with dynamic dimension(s) equal to zero.
  Value c0 = b.create<arith::ConstantIndexOp>(0);
  Value isZeroIterations =
      b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, tripCount, c0);

  // Do absolutely nothing if the trip count is zero.
  auto noOp = [&](OpBuilder &nestedBuilder, Location loc) {
    nestedBuilder.create<scf::YieldOp>(loc);
  };

  // Compute the parallel block size and dispatch concurrent tasks computing
  // results for each block.
  auto dispatch = [&](OpBuilder &nestedBuilder, Location loc) {
    ImplicitLocOpBuilder b(loc, nestedBuilder);

    // Collect statically known constants defining the loop nest in the
    // parallel compute function. LLVM can't always push constants across the
    // non-trivial async dispatch call graph. By providing these values
    // explicitly we can choose to build a more efficient loop nest, and rely
    // on better constant folding, loop unrolling and vectorization.
    ParallelComputeFunctionBounds staticBounds = {
        integerConstants(tripCounts),
        integerConstants(op.getLowerBound()),
        integerConstants(op.getUpperBound()),
        integerConstants(op.getStep()),
    };

    // Find how many inner iteration dimensions are statically known and have a
    // combined product smaller than `512`. We align the parallel compute block
    // size by the product of statically known dimensions, so that we can
    // guarantee that the inner loops execute from 0 to the loop trip counts.
    // This lets us elide dynamic loop boundaries and gives LLVM an opportunity
    // to unroll the loops. The constant `512` is arbitrary; ideally it should
    // depend on how many iterations LLVM will typically decide to unroll.
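    //
    // For example (illustrative values): for static loop trip counts
    // [128, 16, 4] the cumulative inner iteration counts are [8192, 64, 4].
    // The two innermost cumulative counts (4 and 64) do not exceed 512, so
    // numUnrollableLoops = 2.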
    static constexpr int64_t maxUnrollableIterations = 512;

    // The number of inner loops with a statically known number of iterations
    // not exceeding the `maxUnrollableIterations` value.
    int numUnrollableLoops = 0;

    auto getInt = [](IntegerAttr attr) { return attr ? attr.getInt() : 0; };

    SmallVector<int64_t> numIterations(op.getNumLoops());
    numIterations.back() = getInt(staticBounds.tripCounts.back());

    for (int i = op.getNumLoops() - 2; i >= 0; --i) {
      int64_t tripCount = getInt(staticBounds.tripCounts[i]);
      int64_t innerIterations = numIterations[i + 1];
      numIterations[i] = tripCount * innerIterations;

      // Update the number of inner loops that we can potentially unroll.
      if (innerIterations > 0 && innerIterations <= maxUnrollableIterations)
        numUnrollableLoops++;
    }

    Value numWorkerThreadsVal;
    if (numWorkerThreads >= 0)
      numWorkerThreadsVal = b.create<arith::ConstantIndexOp>(numWorkerThreads);
    else
      numWorkerThreadsVal = b.create<async::RuntimeNumWorkerThreadsOp>();

    // With a large number of threads the value of creating many compute blocks
    // is reduced, because the problem typically becomes memory bound. For this
    // reason we scale the number of workers using an equivalent of the
    // following logic:
    //
    //   float overshardingFactor = numWorkerThreads <= 4    ? 8.0
    //                              : numWorkerThreads <= 8  ? 4.0
    //                              : numWorkerThreads <= 16 ? 2.0
    //                              : numWorkerThreads <= 32 ? 1.0
    //                              : numWorkerThreads <= 64 ? 0.8
    //                                                       : 0.6;

    // Pairs of the non-inclusive lower end of a bracket and the factor that
    // the number of workers is scaled by if it falls in that bracket.
    const SmallVector<std::pair<int, float>> overshardingBrackets = {
        {4, 4.0f}, {8, 2.0f}, {16, 1.0f}, {32, 0.8f}, {64, 0.6f}};
    const float initialOvershardingFactor = 8.0f;

    Value scalingFactor = b.create<arith::ConstantFloatOp>(
        llvm::APFloat(initialOvershardingFactor), b.getF32Type());
    for (const std::pair<int, float> &p : overshardingBrackets) {
      Value bracketBegin = b.create<arith::ConstantIndexOp>(p.first);
      Value inBracket = b.create<arith::CmpIOp>(
          arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin);
      Value bracketScalingFactor = b.create<arith::ConstantFloatOp>(
          llvm::APFloat(p.second), b.getF32Type());
      scalingFactor = b.create<arith::SelectOp>(inBracket, bracketScalingFactor,
                                                scalingFactor);
    }
    Value numWorkersIndex =
        b.create<arith::IndexCastOp>(b.getI32Type(), numWorkerThreadsVal);
    Value numWorkersFloat =
        b.create<arith::SIToFPOp>(b.getF32Type(), numWorkersIndex);
    Value scaledNumWorkers =
        b.create<arith::MulFOp>(scalingFactor, numWorkersFloat);
    Value scaledNumInt =
        b.create<arith::FPToSIOp>(b.getI32Type(), scaledNumWorkers);
    Value scaledWorkers =
        b.create<arith::IndexCastOp>(b.getIndexType(), scaledNumInt);

    Value maxComputeBlocks = b.create<arith::MaxSIOp>(
        b.create<arith::ConstantIndexOp>(1), scaledWorkers);

    // Compute the parallel block size from the parallel problem size:
    //   blockSize = min(tripCount,
    //                   max(ceil_div(tripCount, maxComputeBlocks),
    //                       minTaskSize))
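    //
    // For example (illustrative values): tripCount = 4000,
    // maxComputeBlocks = 16 and minTaskSize = 100 give
    // blockSize = min(4000, max(ceil_div(4000, 16), 100)) = 250.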
    Value bs0 = b.create<arith::CeilDivSIOp>(tripCount, maxComputeBlocks);
    Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSize);
    Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1);

    // Dispatch the parallel compute function using async recursive work
    // splitting, or by submitting compute tasks sequentially from the caller
    // thread.
    auto doDispatch = asyncDispatch ? doAsyncDispatch : doSequentialDispatch;

    // Create a parallel compute function that takes a block id and computes
    // the parallel operation body for a subset of the iteration space.

    // Compute the number of parallel compute blocks.
    Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);

    // Dispatch the parallel compute function without hints to unroll inner
    // loops.
    auto dispatchDefault = [&](OpBuilder &nestedBuilder, Location loc) {
      ParallelComputeFunction compute =
          createParallelComputeFunction(op, staticBounds, 0, rewriter);

      ImplicitLocOpBuilder b(loc, nestedBuilder);
      doDispatch(b, rewriter, compute, op, blockSize, blockCount, tripCounts);
      b.create<scf::YieldOp>();
    };

    // Dispatch the parallel compute function with hints for unrolling inner
    // loops.
    auto dispatchBlockAligned = [&](OpBuilder &nestedBuilder, Location loc) {
      ParallelComputeFunction compute = createParallelComputeFunction(
          op, staticBounds, numUnrollableLoops, rewriter);

      ImplicitLocOpBuilder b(loc, nestedBuilder);
      // Align the block size to be a multiple of the statically known
      // number of iterations in the inner loops.
      Value numIters = b.create<arith::ConstantIndexOp>(
          numIterations[op.getNumLoops() - numUnrollableLoops]);
      Value alignedBlockSize = b.create<arith::MulIOp>(
          b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
      doDispatch(b, rewriter, compute, op, alignedBlockSize, blockCount,
                 tripCounts);
      b.create<scf::YieldOp>();
    };

    // Dispatch to the block aligned compute function only if the computed
    // block size is larger than the number of iterations in the unrollable
    // inner loops, because otherwise it can reduce the available parallelism.
    if (numUnrollableLoops > 0) {
      Value numIters = b.create<arith::ConstantIndexOp>(
          numIterations[op.getNumLoops() - numUnrollableLoops]);
      Value useBlockAlignedComputeFn = b.create<arith::CmpIOp>(
          arith::CmpIPredicate::sge, blockSize, numIters);

      b.create<scf::IfOp>(useBlockAlignedComputeFn, dispatchBlockAligned,
                          dispatchDefault);
      b.create<scf::YieldOp>();
    } else {
      dispatchDefault(b, loc);
    }
  };

  // Replace the `scf.parallel` operation with the parallel compute function.
  b.create<scf::IfOp>(isZeroIterations, noOp, dispatch);

  // Parallel operation was replaced with a block iteration loop.
  rewriter.eraseOp(op);

  return success();
}

void AsyncParallelForPass::runOnOperation() {
  MLIRContext *ctx = &getContext();

  RewritePatternSet patterns(ctx);
  populateAsyncParallelForPatterns(
      patterns, asyncDispatch, numWorkerThreads,
      [&](ImplicitLocOpBuilder builder, scf::ParallelOp op) {
        return builder.create<arith::ConstantIndexOp>(minTaskSize);
      });
  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}

std::unique_ptr<Pass> mlir::createAsyncParallelForPass() {
  return std::make_unique<AsyncParallelForPass>();
}

std::unique_ptr<Pass> mlir::createAsyncParallelForPass(bool asyncDispatch,
                                                       int32_t numWorkerThreads,
                                                       int32_t minTaskSize) {
  return std::make_unique<AsyncParallelForPass>(asyncDispatch,
                                                numWorkerThreads, minTaskSize);
}

void mlir::async::populateAsyncParallelForPatterns(
    RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads,
    const AsyncMinTaskSizeComputationFunction &computeMinTaskSize) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
                                        computeMinTaskSize);
}