1 | //===- Threading.h - MLIR Threading Utilities -------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines various utilities for multithreaded processing within MLIR. |
10 | // These utilities automatically handle many of the necessary threading |
11 | // conditions, such as properly ordering diagnostics, observing if threading is |
12 | // disabled, etc. These utilities should be used over other threading utilities |
13 | // whenever feasible. |
14 | // |
15 | //===----------------------------------------------------------------------===// |
16 | |
17 | #ifndef MLIR_IR_THREADING_H |
18 | #define MLIR_IR_THREADING_H |
19 | |
20 | #include "mlir/IR/Diagnostics.h" |
21 | #include "llvm/ADT/Sequence.h" |
22 | #include "llvm/Support/ThreadPool.h" |
23 | #include <atomic> |
24 | |
25 | namespace mlir { |
26 | |
27 | /// Invoke the given function on the elements between [begin, end) |
28 | /// asynchronously. If the given function returns a failure when processing any |
29 | /// of the elements, execution is stopped and a failure is returned from this |
30 | /// function. This means that in the case of failure, not all elements of the |
31 | /// range will be processed. Diagnostics emitted during processing are ordered |
32 | /// relative to the element's position within [begin, end). If the provided |
33 | /// context does not have multi-threading enabled, this function always |
34 | /// processes elements sequentially. |
35 | template <typename IteratorT, typename FuncT> |
36 | LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin, |
37 | IteratorT end, FuncT &&func) { |
38 | unsigned numElements = static_cast<unsigned>(std::distance(begin, end)); |
39 | if (numElements == 0) |
40 | return success(); |
41 | |
42 | // If multithreading is disabled or there is a small number of elements, |
43 | // process the elements directly on this thread. |
44 | if (!context->isMultithreadingEnabled() || numElements <= 1) { |
45 | for (; begin != end; ++begin) |
46 | if (failed(func(*begin))) |
47 | return failure(); |
48 | return success(); |
49 | } |
50 | |
51 | // Build a wrapper processing function that properly initializes a parallel |
52 | // diagnostic handler. |
53 | ParallelDiagnosticHandler handler(context); |
54 | std::atomic<unsigned> curIndex(0); |
55 | std::atomic<bool> processingFailed(false); |
56 | auto processFn = [&] { |
57 | while (!processingFailed) { |
58 | unsigned index = curIndex++; |
59 | if (index >= numElements) |
60 | break; |
61 | handler.setOrderIDForThread(index); |
62 | if (failed(func(*std::next(begin, index)))) |
63 | processingFailed = true; |
64 | handler.eraseOrderIDForThread(); |
65 | } |
66 | }; |
67 | |
68 | // Otherwise, process the elements in parallel. |
69 | llvm::ThreadPoolInterface &threadPool = context->getThreadPool(); |
70 | llvm::ThreadPoolTaskGroup tasksGroup(threadPool); |
71 | size_t numActions = std::min(a: numElements, b: threadPool.getMaxConcurrency()); |
72 | for (unsigned i = 0; i < numActions; ++i) |
73 | tasksGroup.async(processFn); |
74 | // If the current thread is a worker thread from the pool, then waiting for |
75 | // the task group allows the current thread to also participate in processing |
76 | // tasks from the group, which avoid any deadlock/starvation. |
77 | tasksGroup.wait(); |
78 | return failure(isFailure: processingFailed); |
79 | } |
80 | |
81 | /// Invoke the given function on the elements in the provided range |
82 | /// asynchronously. If the given function returns a failure when processing any |
83 | /// of the elements, execution is stopped and a failure is returned from this |
84 | /// function. This means that in the case of failure, not all elements of the |
85 | /// range will be processed. Diagnostics emitted during processing are ordered |
86 | /// relative to the element's position within the range. If the provided context |
87 | /// does not have multi-threading enabled, this function always processes |
88 | /// elements sequentially. |
89 | template <typename RangeT, typename FuncT> |
90 | LogicalResult failableParallelForEach(MLIRContext *context, RangeT &&range, |
91 | FuncT &&func) { |
92 | return failableParallelForEach(context, std::begin(range), std::end(range), |
93 | std::forward<FuncT>(func)); |
94 | } |
95 | |
96 | /// Invoke the given function on the elements between [begin, end) |
97 | /// asynchronously. If the given function returns a failure when processing any |
98 | /// of the elements, execution is stopped and a failure is returned from this |
99 | /// function. This means that in the case of failure, not all elements of the |
100 | /// range will be processed. Diagnostics emitted during processing are ordered |
101 | /// relative to the element's position within [begin, end). If the provided |
102 | /// context does not have multi-threading enabled, this function always |
103 | /// processes elements sequentially. |
104 | template <typename FuncT> |
105 | LogicalResult failableParallelForEachN(MLIRContext *context, size_t begin, |
106 | size_t end, FuncT &&func) { |
107 | return failableParallelForEach(context, llvm::seq(Begin: begin, End: end), |
108 | std::forward<FuncT>(func)); |
109 | } |
110 | |
111 | /// Invoke the given function on the elements between [begin, end) |
112 | /// asynchronously. Diagnostics emitted during processing are ordered relative |
113 | /// to the element's position within [begin, end). If the provided context does |
114 | /// not have multi-threading enabled, this function always processes elements |
115 | /// sequentially. |
116 | template <typename IteratorT, typename FuncT> |
117 | void parallelForEach(MLIRContext *context, IteratorT begin, IteratorT end, |
118 | FuncT &&func) { |
119 | (void)failableParallelForEach(context, begin, end, [&](auto &&value) { |
120 | return func(std::forward<decltype(value)>(value)), success(); |
121 | }); |
122 | } |
123 | |
124 | /// Invoke the given function on the elements in the provided range |
125 | /// asynchronously. Diagnostics emitted during processing are ordered relative |
126 | /// to the element's position within the range. If the provided context does not |
127 | /// have multi-threading enabled, this function always processes elements |
128 | /// sequentially. |
129 | template <typename RangeT, typename FuncT> |
130 | void parallelForEach(MLIRContext *context, RangeT &&range, FuncT &&func) { |
131 | parallelForEach(context, std::begin(range), std::end(range), |
132 | std::forward<FuncT>(func)); |
133 | } |
134 | |
135 | /// Invoke the given function on the elements between [begin, end) |
136 | /// asynchronously. Diagnostics emitted during processing are ordered relative |
137 | /// to the element's position within [begin, end). If the provided context does |
138 | /// not have multi-threading enabled, this function always processes elements |
139 | /// sequentially. |
140 | template <typename FuncT> |
141 | void parallelFor(MLIRContext *context, size_t begin, size_t end, FuncT &&func) { |
142 | parallelForEach(context, llvm::seq(Begin: begin, End: end), std::forward<FuncT>(func)); |
143 | } |
144 | |
145 | } // namespace mlir |
146 | |
147 | #endif // MLIR_IR_THREADING_H |
148 | |