//===- AsyncParallelFor.cpp - Implementation of Async Parallel For --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the scf.parallel to scf.for + async.execute conversion
// pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Async/Passes.h"

#include "PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_ASYNCPARALLELFORPASS
#include "mlir/Dialect/Async/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::async;

#define DEBUG_TYPE "async-parallel-for"

namespace {

// Rewrite scf.parallel operation into multiple concurrent async.execute
// operations over non-overlapping subranges of the original loop.
//
// Example:
//
//   scf.parallel (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) {
//     "do_some_compute"(%i, %j): () -> ()
//   }
//
// Converted to:
//
//   // Parallel compute function that executes the parallel body region for
//   // a subset of the parallel iteration space defined by the one-dimensional
//   // compute block index.
//   func parallel_compute_function(%block_index : index, %block_size : index,
//                                  <parallel operation properties>, ...) {
//     // Compute multi-dimensional loop bounds for %block_index.
//     %block_lbi, %block_lbj = ...
//     %block_ubi, %block_ubj = ...
//
//     // Clone parallel operation body into the scf.for loop nest.
//     scf.for %i = %block_lbi to %block_ubi {
//       scf.for %j = %block_lbj to %block_ubj {
65 | // "do_some_compute"(%i, %j): () -> () |
66 | // } |
67 | // } |
68 | // } |
69 | // |
70 | // And a dispatch function depending on the `asyncDispatch` option. |
71 | // |
72 | // When async dispatch is on: (pseudocode) |
73 | // |
74 | // %block_size = ... compute parallel compute block size |
75 | // %block_count = ... compute the number of compute blocks |
76 | // |
77 | // func @async_dispatch(%block_start : index, %block_end : index, ...) { |
//     // Keep splitting block range until we reach a range of size 1.
//     while (%block_end - %block_start > 1) {
//       %mid_index = block_start + (block_end - block_start) / 2;
//       async.execute { call @async_dispatch(%mid_index, %block_end); }
//       %block_end = %mid_index
//     }
//
//     // Call parallel compute function for a single block.
//     call @parallel_compute_fn(%block_start, %block_size, ...);
//   }
//
//   // Launch async dispatch for [0, block_count) range.
//   call @async_dispatch(%c0, %block_count);
//
// When async dispatch is off:
//
//   %block_size = ... compute parallel compute block size
//   %block_count = ... compute the number of compute blocks
//
//   scf.for %block_index = %c0 to %block_count {
//     call @parallel_compute_fn(%block_index, %block_size, ...)
//   }
//
struct AsyncParallelForPass
    : public impl::AsyncParallelForPassBase<AsyncParallelForPass> {
  using Base::Base;

  void runOnOperation() override;
};

struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
public:
  AsyncParallelForRewrite(
      MLIRContext *ctx, bool asyncDispatch, int32_t numWorkerThreads,
      AsyncMinTaskSizeComputationFunction computeMinTaskSize)
      : OpRewritePattern(ctx), asyncDispatch(asyncDispatch),
        numWorkerThreads(numWorkerThreads),
        computeMinTaskSize(std::move(computeMinTaskSize)) {}

  LogicalResult matchAndRewrite(scf::ParallelOp op,
                                PatternRewriter &rewriter) const override;

private:
  bool asyncDispatch;
  int32_t numWorkerThreads;
  AsyncMinTaskSizeComputationFunction computeMinTaskSize;
};

struct ParallelComputeFunctionType {
  FunctionType type;
  SmallVector<Value> captures;
};

// Helper struct to parse parallel compute function argument list.
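// The arguments are packed contiguously; the layout implied by the accessor
// offsets below is (sketch, not a verbatim signature):
//   [blockIndex, blockSize,
//    tripCounts[numLoops],
//    lowerBounds[numLoops], upperBounds[numLoops], steps[numLoops],
//    captures...]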
struct ParallelComputeFunctionArgs {
  BlockArgument blockIndex();
  BlockArgument blockSize();
  ArrayRef<BlockArgument> tripCounts();
  ArrayRef<BlockArgument> lowerBounds();
  ArrayRef<BlockArgument> steps();
  ArrayRef<BlockArgument> captures();

  unsigned numLoops;
  ArrayRef<BlockArgument> args;
};

struct ParallelComputeFunctionBounds {
  SmallVector<IntegerAttr> tripCounts;
  SmallVector<IntegerAttr> lowerBounds;
  SmallVector<IntegerAttr> upperBounds;
  SmallVector<IntegerAttr> steps;
};

struct ParallelComputeFunction {
  unsigned numLoops;
  func::FuncOp func;
  llvm::SmallVector<Value> captures;
};

} // namespace

BlockArgument ParallelComputeFunctionArgs::blockIndex() { return args[0]; }
BlockArgument ParallelComputeFunctionArgs::blockSize() { return args[1]; }

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::tripCounts() {
  return args.drop_front(2).take_front(numLoops);
}

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::lowerBounds() {
  return args.drop_front(2 + 1 * numLoops).take_front(numLoops);
}

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::steps() {
  return args.drop_front(2 + 3 * numLoops).take_front(numLoops);
}

ArrayRef<BlockArgument> ParallelComputeFunctionArgs::captures() {
  return args.drop_front(2 + 4 * numLoops);
}

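// Returns an IntegerAttr for every value in `values` that is defined by a
// constant operation; entries for non-constant values are left as null
// attributes.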
template <typename ValueRange>
static SmallVector<IntegerAttr> integerConstants(ValueRange values) {
  SmallVector<IntegerAttr> attrs(values.size());
  for (unsigned i = 0; i < values.size(); ++i)
    matchPattern(values[i], m_Constant(&attrs[i]));
  return attrs;
}

// Converts a one-dimensional iteration index in the [0, tripCount) interval
// into a multidimensional iteration coordinate.
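// For example, with tripCounts = [5, 3] the flattened index 7 maps to the
// coordinate [2, 1], because 7 % 3 == 1 and (7 / 3) % 5 == 2.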
static SmallVector<Value> delinearize(ImplicitLocOpBuilder &b, Value index,
                                      ArrayRef<Value> tripCounts) {
  SmallVector<Value> coords(tripCounts.size());
  assert(!tripCounts.empty() && "tripCounts must not be empty");

  for (ssize_t i = tripCounts.size() - 1; i >= 0; --i) {
    coords[i] = b.create<arith::RemSIOp>(index, tripCounts[i]);
    index = b.create<arith::DivSIOp>(index, tripCounts[i]);
  }

  return coords;
}

// Returns a function type and implicit captures for a parallel compute
// function. We'll need a list of implicit captures to set up block and value
// mapping when we clone the body of the parallel operation.
static ParallelComputeFunctionType
getParallelComputeFunctionType(scf::ParallelOp op, PatternRewriter &rewriter) {
  // Values implicitly captured by the parallel operation.
  llvm::SetVector<Value> captures;
  getUsedValuesDefinedAbove(op.getRegion(), op.getRegion(), captures);

  SmallVector<Type> inputs;
  inputs.reserve(2 + 4 * op.getNumLoops() + captures.size());

  Type indexTy = rewriter.getIndexType();

  // One-dimensional iteration space defined by the block index and size.
  inputs.push_back(indexTy); // blockIndex
  inputs.push_back(indexTy); // blockSize

  // Multi-dimensional parallel iteration space defined by the loop trip
  // counts.
  for (unsigned i = 0; i < op.getNumLoops(); ++i)
    inputs.push_back(indexTy); // loop tripCount

  // Parallel operation lower bound, upper bound and step. Lower bound, upper
  // bound and step passed as contiguous arguments:
  //   call @compute(%lb0, %lb1, ..., %ub0, %ub1, ..., %step0, %step1, ...)
  for (unsigned i = 0; i < op.getNumLoops(); ++i) {
    inputs.push_back(indexTy); // lower bound
    inputs.push_back(indexTy); // upper bound
    inputs.push_back(indexTy); // step
  }

  // Types of the implicit captures.
  for (Value capture : captures)
    inputs.push_back(capture.getType());

  // Convert captures to vector for later convenience.
  SmallVector<Value> capturesVector(captures.begin(), captures.end());
  return {rewriter.getFunctionType(inputs, TypeRange()), capturesVector};
}

// Create a parallel compute function from the parallel operation.
static ParallelComputeFunction createParallelComputeFunction(
    scf::ParallelOp op, const ParallelComputeFunctionBounds &bounds,
    unsigned numBlockAlignedInnerLoops, PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard guard(rewriter);
  ImplicitLocOpBuilder b(op.getLoc(), rewriter);

  ModuleOp module = op->getParentOfType<ModuleOp>();

  ParallelComputeFunctionType computeFuncType =
      getParallelComputeFunctionType(op, rewriter);

  FunctionType type = computeFuncType.type;
  func::FuncOp func = func::FuncOp::create(
      op.getLoc(),
      numBlockAlignedInnerLoops > 0 ? "parallel_compute_fn_with_aligned_loops"
                                    : "parallel_compute_fn",
      type);
  func.setPrivate();

  // Insert function into the module symbol table and assign it a unique name.
  SymbolTable symbolTable(module);
  symbolTable.insert(func);
  rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});

  // Create function entry block.
  Block *block =
      b.createBlock(&func.getBody(), func.begin(), type.getInputs(),
                    SmallVector<Location>(type.getNumInputs(), op.getLoc()));
  b.setInsertionPointToEnd(block);

  ParallelComputeFunctionArgs args = {op.getNumLoops(), func.getArguments()};

  // Block iteration position defined by the block index and size.
  BlockArgument blockIndex = args.blockIndex();
  BlockArgument blockSize = args.blockSize();

  // Constants used below.
  Value c0 = b.create<arith::ConstantIndexOp>(0);
  Value c1 = b.create<arith::ConstantIndexOp>(1);

  // Materialize known constants as constant operations in the function body.
  auto values = [&](ArrayRef<BlockArgument> args,
                    ArrayRef<IntegerAttr> attrs) {
    return llvm::to_vector(
        llvm::map_range(llvm::zip(args, attrs), [&](auto tuple) -> Value {
          if (IntegerAttr attr = std::get<1>(tuple))
            return b.create<arith::ConstantOp>(attr);
          return std::get<0>(tuple);
        }));
  };

  // Multi-dimensional parallel iteration space defined by the loop trip
  // counts.
  auto tripCounts = values(args.tripCounts(), bounds.tripCounts);

  // Parallel operation lower bound and step.
  auto lowerBounds = values(args.lowerBounds(), bounds.lowerBounds);
  auto steps = values(args.steps(), bounds.steps);

  // Remaining arguments are implicit captures of the parallel operation.
  ArrayRef<BlockArgument> captures = args.captures();

  // Compute a product of trip counts to get the size of the flattened
  // one-dimensional iteration space.
  Value tripCount = tripCounts[0];
  for (unsigned i = 1; i < tripCounts.size(); ++i)
    tripCount = b.create<arith::MulIOp>(tripCount, tripCounts[i]);

  // Find one-dimensional iteration bounds: [blockFirstIndex, blockLastIndex]:
  //   blockFirstIndex = blockIndex * blockSize
  Value blockFirstIndex = b.create<arith::MulIOp>(blockIndex, blockSize);

  // The last one-dimensional index in the block defined by the `blockIndex`:
  //   blockLastIndex = min(blockFirstIndex + blockSize, tripCount) - 1
  Value blockEnd0 = b.create<arith::AddIOp>(blockFirstIndex, blockSize);
  Value blockEnd1 = b.create<arith::MinSIOp>(blockEnd0, tripCount);
  Value blockLastIndex = b.create<arith::SubIOp>(blockEnd1, c1);
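  // As an illustration with assumed values: blockIndex = 2, blockSize = 10
  // and tripCount = 25 give blockFirstIndex = 20 and
  // blockLastIndex = min(30, 25) - 1 = 24.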

  // Convert one-dimensional indices to multi-dimensional coordinates.
  auto blockFirstCoord = delinearize(b, blockFirstIndex, tripCounts);
  auto blockLastCoord = delinearize(b, blockLastIndex, tripCounts);

  // Compute loop upper bounds derived from the block last coordinates:
  //   blockEndCoord[i] = blockLastCoord[i] + 1
  //
  // Block first and last coordinates can be the same along the outer compute
  // dimension when the inner compute dimension contains multiple blocks.
  SmallVector<Value> blockEndCoord(op.getNumLoops());
  for (size_t i = 0; i < blockLastCoord.size(); ++i)
    blockEndCoord[i] = b.create<arith::AddIOp>(blockLastCoord[i], c1);

  // Construct a loop nest out of scf.for operations that will iterate over
  // all coordinates in the [blockFirstCoord, blockLastCoord] range.
  using LoopBodyBuilder =
      std::function<void(OpBuilder &, Location, Value, ValueRange)>;
  using LoopNestBuilder = std::function<LoopBodyBuilder(size_t loopIdx)>;

  // Parallel region induction variables computed from the multi-dimensional
  // iteration coordinate using parallel operation bounds and step:
  //
  //   computeBlockInductionVars[loopIdx] =
  //     lowerBound[loopIdx] + blockCoord[loopIdx] * step[loopIdx]
  SmallVector<Value> computeBlockInductionVars(op.getNumLoops());

  // We need to know if we are in the first or last iteration of the
  // multi-dimensional loop for each loop in the nest, so we can decide which
  // loop bounds to use for the nested loops: bounds defined by the compute
  // block interval, or bounds defined by the parallel operation.
  //
  // Example: 2d parallel operation
  //                  i   j
  //   loop sizes:  [50, 50]
  //   first coord: [25, 25]
  //   last coord:  [30, 30]
354 | // |
355 | // If `i` is equal to 25 then iteration over `j` should start at 25, when `i` |
356 | // is between 25 and 30 it should start at 0. The upper bound for `j` should |
357 | // be 50, except when `i` is equal to 30, then it should also be 30. |
358 | // |
359 | // Value at ith position specifies if all loops in [0, i) range of the loop |
360 | // nest are in the first/last iteration. |
  SmallVector<Value> isBlockFirstCoord(op.getNumLoops());
  SmallVector<Value> isBlockLastCoord(op.getNumLoops());

  // Builds the inner loop nest inside the async.execute operation that does
  // all the work concurrently.
  LoopNestBuilder workLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
    return [&, loopIdx](OpBuilder &nestedBuilder, Location loc, Value iv,
                        ValueRange args) {
      ImplicitLocOpBuilder b(loc, nestedBuilder);

      // Compute induction variable for `loopIdx`.
      computeBlockInductionVars[loopIdx] = b.create<arith::AddIOp>(
          lowerBounds[loopIdx], b.create<arith::MulIOp>(iv, steps[loopIdx]));

      // Check if we are inside the first or last iteration of the loop.
      isBlockFirstCoord[loopIdx] = b.create<arith::CmpIOp>(
          arith::CmpIPredicate::eq, iv, blockFirstCoord[loopIdx]);
      isBlockLastCoord[loopIdx] = b.create<arith::CmpIOp>(
          arith::CmpIPredicate::eq, iv, blockLastCoord[loopIdx]);

      // Check if the previous loop is in its first or last iteration.
      if (loopIdx > 0) {
        isBlockFirstCoord[loopIdx] = b.create<arith::AndIOp>(
            isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]);
        isBlockLastCoord[loopIdx] = b.create<arith::AndIOp>(
            isBlockLastCoord[loopIdx], isBlockLastCoord[loopIdx - 1]);
      }

      // Keep building loop nest.
      if (loopIdx < op.getNumLoops() - 1) {
        if (loopIdx + 1 >= op.getNumLoops() - numBlockAlignedInnerLoops) {
          // For block aligned loops we always iterate starting from 0 up to
          // the loop trip counts.
          b.create<scf::ForOp>(c0, tripCounts[loopIdx + 1], c1, ValueRange(),
                               workLoopBuilder(loopIdx + 1));

        } else {
          // Select nested loop lower/upper bounds depending on our position in
          // the multi-dimensional iteration space.
          auto lb = b.create<arith::SelectOp>(isBlockFirstCoord[loopIdx],
                                              blockFirstCoord[loopIdx + 1], c0);

          auto ub = b.create<arith::SelectOp>(isBlockLastCoord[loopIdx],
                                              blockEndCoord[loopIdx + 1],
                                              tripCounts[loopIdx + 1]);

          b.create<scf::ForOp>(lb, ub, c1, ValueRange(),
                               workLoopBuilder(loopIdx + 1));
        }

        b.create<scf::YieldOp>(loc);
        return;
      }

      // Copy the body of the parallel op into the inner-most loop.
      IRMapping mapping;
      mapping.map(op.getInductionVars(), computeBlockInductionVars);
      mapping.map(computeFuncType.captures, captures);

      for (auto &bodyOp : op.getRegion().front().without_terminator())
        b.clone(bodyOp, mapping);
      b.create<scf::YieldOp>(loc);
    };
  };

  b.create<scf::ForOp>(blockFirstCoord[0], blockEndCoord[0], c1, ValueRange(),
                       workLoopBuilder(0));
  b.create<func::ReturnOp>(ValueRange());

  return {op.getNumLoops(), func, std::move(computeFuncType.captures)};
}

// Creates a recursive async dispatch function for the given parallel compute
// function. The dispatch function keeps splitting the block range into halves
// until it reaches a single block, and then executes it inline.
//
// Function pseudocode (mix of C++ and MLIR):
//
//   func @async_dispatch(%block_start : index, %block_end : index, ...) {
//
//     // Keep splitting block range until we reach a range of size 1.
//     while (%block_end - %block_start > 1) {
//       %mid_index = block_start + (block_end - block_start) / 2;
//       async.execute { call @async_dispatch(%mid_index, %block_end); }
//       %block_end = %mid_index
//     }
//
//     // Call parallel compute function for a single block.
//     call @parallel_compute_fn(%block_start, %block_size, ...);
//   }
//
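// As a hypothetical trace for a block range of [0, 8): the first call
// dispatches [4, 8), then [2, 4), then [1, 2) into async.execute regions, and
// finally calls @parallel_compute_fn for block 0 inline, so every block is
// eventually processed by exactly one task.
//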
static func::FuncOp
createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
                            PatternRewriter &rewriter) {
  OpBuilder::InsertionGuard guard(rewriter);
  Location loc = computeFunc.func.getLoc();
  ImplicitLocOpBuilder b(loc, rewriter);

  ModuleOp module = computeFunc.func->getParentOfType<ModuleOp>();

  ArrayRef<Type> computeFuncInputTypes =
      computeFunc.func.getFunctionType().getInputs();

  // Compared to the parallel compute function, the async dispatch function
  // takes an additional !async.group argument. Also, instead of a single
  // `blockIndex` it takes `blockStart` and `blockEnd` arguments to define the
  // range of dispatched blocks.
  SmallVector<Type> inputTypes;
  inputTypes.push_back(async::GroupType::get(rewriter.getContext()));
  inputTypes.push_back(rewriter.getIndexType()); // add blockStart argument
  inputTypes.append(computeFuncInputTypes.begin(),
                    computeFuncInputTypes.end());

  FunctionType type = rewriter.getFunctionType(inputTypes, TypeRange());
  func::FuncOp func = func::FuncOp::create(loc, "async_dispatch_fn", type);
  func.setPrivate();

  // Insert function into the module symbol table and assign it a unique name.
  SymbolTable symbolTable(module);
  symbolTable.insert(func);
  rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});

  // Create function entry block.
  Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs(),
                               SmallVector<Location>(type.getNumInputs(), loc));
  b.setInsertionPointToEnd(block);

  Type indexTy = b.getIndexType();
  Value c1 = b.create<arith::ConstantIndexOp>(1);
  Value c2 = b.create<arith::ConstantIndexOp>(2);

  // Get the async group that will track async dispatch completion.
  Value group = block->getArgument(0);

  // Get the block iteration range: [blockStart, blockEnd)
  Value blockStart = block->getArgument(1);
  Value blockEnd = block->getArgument(2);

  // Create a work splitting while loop for the [blockStart, blockEnd) range.
  SmallVector<Type> types = {indexTy, indexTy};
  SmallVector<Value> operands = {blockStart, blockEnd};
  SmallVector<Location> locations = {loc, loc};

  // Create a recursive dispatch loop.
  scf::WhileOp whileOp = b.create<scf::WhileOp>(types, operands);
  Block *before = b.createBlock(&whileOp.getBefore(), {}, types, locations);
  Block *after = b.createBlock(&whileOp.getAfter(), {}, types, locations);

  // Set up the dispatch loop condition block: decide if we need to go into the
  // `after` block and launch one more async dispatch.
  {
    b.setInsertionPointToEnd(before);
    Value start = before->getArgument(0);
    Value end = before->getArgument(1);
    Value distance = b.create<arith::SubIOp>(end, start);
    Value dispatch =
        b.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, distance, c1);
    b.create<scf::ConditionOp>(dispatch, before->getArguments());
  }

  // Set up the async dispatch loop body: recursively call the dispatch
  // function for the second half of the original range and go to the next
  // iteration.
  {
    b.setInsertionPointToEnd(after);
    Value start = after->getArgument(0);
    Value end = after->getArgument(1);
    Value distance = b.create<arith::SubIOp>(end, start);
    Value halfDistance = b.create<arith::DivSIOp>(distance, c2);
    Value midIndex = b.create<arith::AddIOp>(start, halfDistance);

    // Call parallel compute function inside the async.execute region.
    auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
                                  Location executeLoc, ValueRange executeArgs) {
      // Update the original `blockStart` and `blockEnd` with the new range.
      SmallVector<Value> operands{block->getArguments().begin(),
                                  block->getArguments().end()};
      operands[1] = midIndex;
      operands[2] = end;

      executeBuilder.create<func::CallOp>(executeLoc, func.getSymName(),
                                          func.getResultTypes(), operands);
      executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
    };

    // Create async.execute operation to dispatch half of the block range.
    auto execute = b.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
                                       executeBodyBuilder);
    b.create<AddToGroupOp>(indexTy, execute.getToken(), group);
    b.create<scf::YieldOp>(ValueRange({start, midIndex}));
  }

  // After dispatching async operations to process the tail of the block range,
  // call the parallel compute function for the first block of the range.
  b.setInsertionPointAfter(whileOp);

  // Drop async dispatch specific arguments: async group, block start and end.
  auto forwardedInputs = block->getArguments().drop_front(3);
  SmallVector<Value> computeFuncOperands = {blockStart};
  computeFuncOperands.append(forwardedInputs.begin(), forwardedInputs.end());

  b.create<func::CallOp>(computeFunc.func.getSymName(),
                         computeFunc.func.getResultTypes(),
                         computeFuncOperands);
  b.create<func::ReturnOp>(ValueRange());

  return func;
}

// Launch async dispatch of the parallel compute function.
static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
                            ParallelComputeFunction &parallelComputeFunction,
                            scf::ParallelOp op, Value blockSize,
                            Value blockCount,
                            const SmallVector<Value> &tripCounts) {
  MLIRContext *ctx = op->getContext();

  // Add one more level of indirection to dispatch parallel compute functions
  // using async operations and recursive work splitting.
  func::FuncOp asyncDispatchFunction =
      createAsyncDispatchFunction(parallelComputeFunction, rewriter);

  Value c0 = b.create<arith::ConstantIndexOp>(0);
  Value c1 = b.create<arith::ConstantIndexOp>(1);

  // Appends operands shared by async dispatch and parallel compute functions
  // to the given operands vector.
  auto appendBlockComputeOperands = [&](SmallVector<Value> &operands) {
    operands.append(tripCounts);
    operands.append(op.getLowerBound().begin(), op.getLowerBound().end());
    operands.append(op.getUpperBound().begin(), op.getUpperBound().end());
    operands.append(op.getStep().begin(), op.getStep().end());
    operands.append(parallelComputeFunction.captures);
  };

  // Check if the block count is one; in this case we can skip the async
  // dispatch completely. If this is known statically, then canonicalization
  // will erase async group operations.
  Value isSingleBlock =
      b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, blockCount, c1);

  auto syncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
    ImplicitLocOpBuilder b(loc, nestedBuilder);

    // Call parallel compute function for the single block.
    SmallVector<Value> operands = {c0, blockSize};
    appendBlockComputeOperands(operands);

    b.create<func::CallOp>(parallelComputeFunction.func.getSymName(),
                           parallelComputeFunction.func.getResultTypes(),
                           operands);
    b.create<scf::YieldOp>();
  };

  auto asyncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
    ImplicitLocOpBuilder b(loc, nestedBuilder);

    // Create an async.group to wait on all async tokens from the concurrent
    // execution of multiple parallel compute functions. The first block will
    // be executed synchronously in the caller thread.
    Value groupSize = b.create<arith::SubIOp>(blockCount, c1);
    Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);

    // Launch async dispatch function for [0, blockCount) range.
    SmallVector<Value> operands = {group, c0, blockCount, blockSize};
    appendBlockComputeOperands(operands);

    b.create<func::CallOp>(asyncDispatchFunction.getSymName(),
                           asyncDispatchFunction.getResultTypes(), operands);

    // Wait for the completion of all parallel compute operations.
    b.create<AwaitAllOp>(group);

    b.create<scf::YieldOp>();
  };

  // Either dispatch the single block compute function, or launch the async
  // dispatch.
  b.create<scf::IfOp>(isSingleBlock, syncDispatch, asyncDispatch);
}

// Dispatch parallel compute functions by submitting all async compute tasks
// from a simple for loop in the caller thread.
static void
doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
                     ParallelComputeFunction &parallelComputeFunction,
                     scf::ParallelOp op, Value blockSize, Value blockCount,
                     const SmallVector<Value> &tripCounts) {
  MLIRContext *ctx = op->getContext();

  func::FuncOp compute = parallelComputeFunction.func;

  Value c0 = b.create<arith::ConstantIndexOp>(0);
  Value c1 = b.create<arith::ConstantIndexOp>(1);

  // Create an async.group to wait on all async tokens from the concurrent
  // execution of multiple parallel compute functions. The first block will be
  // executed synchronously in the caller thread.
  Value groupSize = b.create<arith::SubIOp>(blockCount, c1);
  Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);

  // Call parallel compute function for all blocks.
  using LoopBodyBuilder =
      std::function<void(OpBuilder &, Location, Value, ValueRange)>;

  // Returns parallel compute function operands to process the given block.
  auto computeFuncOperands = [&](Value blockIndex) -> SmallVector<Value> {
    SmallVector<Value> computeFuncOperands = {blockIndex, blockSize};
    computeFuncOperands.append(tripCounts);
    computeFuncOperands.append(op.getLowerBound().begin(),
                               op.getLowerBound().end());
    computeFuncOperands.append(op.getUpperBound().begin(),
                               op.getUpperBound().end());
    computeFuncOperands.append(op.getStep().begin(), op.getStep().end());
    computeFuncOperands.append(parallelComputeFunction.captures);
    return computeFuncOperands;
  };

  // Induction variable is the index of the block: [0, blockCount).
  LoopBodyBuilder loopBuilder = [&](OpBuilder &loopBuilder, Location loc,
                                    Value iv, ValueRange args) {
    ImplicitLocOpBuilder b(loc, loopBuilder);

    // Call parallel compute function inside the async.execute region.
    auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
                                  Location executeLoc, ValueRange executeArgs) {
      executeBuilder.create<func::CallOp>(executeLoc, compute.getSymName(),
                                          compute.getResultTypes(),
                                          computeFuncOperands(iv));
      executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
    };

    // Create async.execute operation to launch the parallel compute function.
    auto execute = b.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
                                       executeBodyBuilder);
    b.create<AddToGroupOp>(rewriter.getIndexType(), execute.getToken(), group);
    b.create<scf::YieldOp>();
  };

  // Iterate over all compute blocks and launch parallel compute operations.
  b.create<scf::ForOp>(c1, blockCount, c1, ValueRange(), loopBuilder);

  // Call parallel compute function for the first block in the caller thread.
  b.create<func::CallOp>(compute.getSymName(), compute.getResultTypes(),
                         computeFuncOperands(c0));

  // Wait for the completion of all async compute operations.
  b.create<AwaitAllOp>(group);
}

LogicalResult
AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
                                         PatternRewriter &rewriter) const {
  // We do not currently support rewriting parallel ops with reductions.
  if (op.getNumReductions() != 0)
    return failure();

  ImplicitLocOpBuilder b(op.getLoc(), rewriter);

  // Computing minTaskSize emits IR and can be implemented as executing a cost
  // model on the body of the scf.parallel. Thus it must be computed before the
  // body of the scf.parallel has been manipulated.
  Value minTaskSize = computeMinTaskSize(b, op);

  // Make sure that all constants are inside the parallel operation body to
  // reduce the number of parallel compute function arguments.
  cloneConstantsIntoTheRegion(op.getRegion(), rewriter);

  // Compute trip count for each loop induction variable:
  //   tripCount = ceil_div(upperBound - lowerBound, step);
  SmallVector<Value> tripCounts(op.getNumLoops());
  for (size_t i = 0; i < op.getNumLoops(); ++i) {
    auto lb = op.getLowerBound()[i];
    auto ub = op.getUpperBound()[i];
    auto step = op.getStep()[i];
    auto range = b.createOrFold<arith::SubIOp>(ub, lb);
    tripCounts[i] = b.createOrFold<arith::CeilDivSIOp>(range, step);
  }
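  // For example, lowerBound = 0, upperBound = 10 and step = 3 give
  // tripCount = ceil_div(10 - 0, 3) = 4, covering the induction variable
  // values {0, 3, 6, 9}.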

  // Compute a product of trip counts to get the 1-dimensional iteration space
  // for the scf.parallel operation.
  Value tripCount = tripCounts[0];
  for (size_t i = 1; i < tripCounts.size(); ++i)
    tripCount = b.create<arith::MulIOp>(tripCount, tripCounts[i]);

  // Short circuit no-op parallel loops (zero iterations) that can arise from
  // memrefs with dynamic dimension(s) equal to zero.
  Value c0 = b.create<arith::ConstantIndexOp>(0);
  Value isZeroIterations =
      b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, tripCount, c0);

  // Do absolutely nothing if the trip count is zero.
  auto noOp = [&](OpBuilder &nestedBuilder, Location loc) {
    nestedBuilder.create<scf::YieldOp>(loc);
  };

  // Compute the parallel block size and dispatch concurrent tasks computing
  // results for each block.
  auto dispatch = [&](OpBuilder &nestedBuilder, Location loc) {
    ImplicitLocOpBuilder b(loc, nestedBuilder);

    // Collect statically known constants defining the loop nest in the
    // parallel compute function. LLVM can't always push constants across the
    // non-trivial async dispatch call graph; by providing these values
    // explicitly we can build a more efficient loop nest and rely on better
    // constant folding, loop unrolling and vectorization.
    ParallelComputeFunctionBounds staticBounds = {
        integerConstants(tripCounts),
        integerConstants(op.getLowerBound()),
        integerConstants(op.getUpperBound()),
        integerConstants(op.getStep()),
    };

    // Find how many inner iteration dimensions are statically known and have a
    // product smaller than `512`. We align the parallel compute block size to
    // the product of these statically known dimensions, so that the inner
    // loops are guaranteed to execute from 0 to their trip counts; this elides
    // dynamic loop boundaries and gives LLVM an opportunity to unroll the
    // loops. The constant `512` is arbitrary; it should depend on how many
    // iterations LLVM will typically decide to unroll.
    static constexpr int64_t maxUnrollableIterations = 512;

    // The number of inner loops whose statically known iteration count is at
    // most `maxUnrollableIterations`.
    int numUnrollableLoops = 0;

    auto getInt = [](IntegerAttr attr) { return attr ? attr.getInt() : 0; };

    SmallVector<int64_t> numIterations(op.getNumLoops());
    numIterations.back() = getInt(staticBounds.tripCounts.back());

    for (int i = op.getNumLoops() - 2; i >= 0; --i) {
      int64_t tripCount = getInt(staticBounds.tripCounts[i]);
      int64_t innerIterations = numIterations[i + 1];
      numIterations[i] = tripCount * innerIterations;

      // Update the number of inner loops that we can potentially unroll.
      if (innerIterations > 0 && innerIterations <= maxUnrollableIterations)
        numUnrollableLoops++;
    }
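    // As an illustration with assumed static trip counts [1024, 16, 4]:
    // numIterations evaluates to [65536, 64, 4] and numUnrollableLoops to 2,
    // because only the two innermost products (64 and 4) are within the
    // `maxUnrollableIterations` limit.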

    Value numWorkerThreadsVal;
    if (numWorkerThreads >= 0)
      numWorkerThreadsVal = b.create<arith::ConstantIndexOp>(numWorkerThreads);
    else
      numWorkerThreadsVal = b.create<async::RuntimeNumWorkerThreadsOp>();

    // With a large number of threads the value of creating many compute blocks
    // is reduced because the problem typically becomes memory bound. For this
    // reason we scale the number of workers using an equivalent of the
    // following logic:
    //   float overshardingFactor = numWorkerThreads <= 4    ? 8.0
    //                              : numWorkerThreads <= 8  ? 4.0
    //                              : numWorkerThreads <= 16 ? 2.0
    //                              : numWorkerThreads <= 32 ? 1.0
    //                              : numWorkerThreads <= 64 ? 0.8
    //                                                       : 0.6;

    // Pairs of the exclusive lower end of a worker-count bracket and the
    // factor that the number of workers is scaled by when it falls in that
    // bracket.
    const SmallVector<std::pair<int, float>> overshardingBrackets = {
        {4, 4.0f}, {8, 2.0f}, {16, 1.0f}, {32, 0.8f}, {64, 0.6f}};
    const float initialOvershardingFactor = 8.0f;
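    // As an illustration with an assumed runtime value of 16 worker threads,
    // the brackets above select a scaling factor of 2.0, so maxComputeBlocks
    // below evaluates to max(1, 32) = 32.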

    Value scalingFactor = b.create<arith::ConstantFloatOp>(
        llvm::APFloat(initialOvershardingFactor), b.getF32Type());
    for (const std::pair<int, float> &p : overshardingBrackets) {
      Value bracketBegin = b.create<arith::ConstantIndexOp>(p.first);
      Value inBracket = b.create<arith::CmpIOp>(
          arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin);
      Value bracketScalingFactor = b.create<arith::ConstantFloatOp>(
          llvm::APFloat(p.second), b.getF32Type());
      scalingFactor = b.create<arith::SelectOp>(inBracket, bracketScalingFactor,
                                                scalingFactor);
    }
    Value numWorkersIndex =
        b.create<arith::IndexCastOp>(b.getI32Type(), numWorkerThreadsVal);
    Value numWorkersFloat =
        b.create<arith::SIToFPOp>(b.getF32Type(), numWorkersIndex);
    Value scaledNumWorkers =
        b.create<arith::MulFOp>(scalingFactor, numWorkersFloat);
    Value scaledNumInt =
        b.create<arith::FPToSIOp>(b.getI32Type(), scaledNumWorkers);
    Value scaledWorkers =
        b.create<arith::IndexCastOp>(b.getIndexType(), scaledNumInt);

    Value maxComputeBlocks = b.create<arith::MaxSIOp>(
        b.create<arith::ConstantIndexOp>(1), scaledWorkers);

    // Compute parallel block size from the parallel problem size:
    //   blockSize = min(tripCount,
    //                   max(ceil_div(tripCount, maxComputeBlocks),
    //                       minTaskSize))
    Value bs0 = b.create<arith::CeilDivSIOp>(tripCount, maxComputeBlocks);
    Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSize);
    Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1);
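    // Worked example with assumed values: tripCount = 4096,
    // maxComputeBlocks = 32 and minTaskSize = 256 give
    // blockSize = min(4096, max(128, 256)) = 256, and the block count computed
    // below is ceil_div(4096, 256) = 16.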

    // Dispatch the parallel compute function using async recursive work
    // splitting, or by submitting compute tasks sequentially from the caller
    // thread.
    auto doDispatch = asyncDispatch ? doAsyncDispatch : doSequentialDispatch;

    // Create a parallel compute function that takes a block id and computes
    // the parallel operation body for a subset of iteration space.

    // Compute the number of parallel compute blocks.
    Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);

    // Dispatch parallel compute function without hints to unroll inner loops.
    auto dispatchDefault = [&](OpBuilder &nestedBuilder, Location loc) {
      ParallelComputeFunction compute =
          createParallelComputeFunction(op, staticBounds, 0, rewriter);

      ImplicitLocOpBuilder b(loc, nestedBuilder);
      doDispatch(b, rewriter, compute, op, blockSize, blockCount, tripCounts);
      b.create<scf::YieldOp>();
    };

    // Dispatch parallel compute function with hints for unrolling inner loops.
    auto dispatchBlockAligned = [&](OpBuilder &nestedBuilder, Location loc) {
      ParallelComputeFunction compute = createParallelComputeFunction(
          op, staticBounds, numUnrollableLoops, rewriter);

      ImplicitLocOpBuilder b(loc, nestedBuilder);
      // Align the block size to be a multiple of the statically known
      // number of iterations in the inner loops.
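      // E.g., with assumed values blockSize = 100 and 64 statically known
      // inner iterations, alignedBlockSize = ceil_div(100, 64) * 64 = 128.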
      Value numIters = b.create<arith::ConstantIndexOp>(
          numIterations[op.getNumLoops() - numUnrollableLoops]);
      Value alignedBlockSize = b.create<arith::MulIOp>(
          b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
      doDispatch(b, rewriter, compute, op, alignedBlockSize, blockCount,
                 tripCounts);
      b.create<scf::YieldOp>();
    };

    // Dispatch to block aligned compute function only if the computed block
    // size is larger than the number of iterations in the unrollable inner
    // loops, because otherwise it can reduce the available parallelism.
    if (numUnrollableLoops > 0) {
      Value numIters = b.create<arith::ConstantIndexOp>(
          numIterations[op.getNumLoops() - numUnrollableLoops]);
      Value useBlockAlignedComputeFn = b.create<arith::CmpIOp>(
          arith::CmpIPredicate::sge, blockSize, numIters);

      b.create<scf::IfOp>(useBlockAlignedComputeFn, dispatchBlockAligned,
                          dispatchDefault);
      b.create<scf::YieldOp>();
    } else {
      dispatchDefault(b, loc);
    }
  };

  // Replace the `scf.parallel` operation with the parallel compute function.
  b.create<scf::IfOp>(isZeroIterations, noOp, dispatch);

  // Parallel operation was replaced with a block iteration loop.
  rewriter.eraseOp(op);

  return success();
}

void AsyncParallelForPass::runOnOperation() {
  MLIRContext *ctx = &getContext();

  RewritePatternSet patterns(ctx);
  populateAsyncParallelForPatterns(
      patterns, asyncDispatch, numWorkerThreads,
      [&](ImplicitLocOpBuilder builder, scf::ParallelOp op) {
        return builder.create<arith::ConstantIndexOp>(minTaskSize);
      });
  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();
}

void mlir::async::populateAsyncParallelForPatterns(
    RewritePatternSet &patterns, bool asyncDispatch, int32_t numWorkerThreads,
    const AsyncMinTaskSizeComputationFunction &computeMinTaskSize) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<AsyncParallelForRewrite>(ctx, asyncDispatch, numWorkerThreads,
                                        computeMinTaskSize);
}