1//===- Threading.h - MLIR Threading Utilities -------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file defines various utilities for multithreaded processing within MLIR.
10// These utilities automatically handle many of the necessary threading
11// conditions, such as properly ordering diagnostics, observing if threading is
12// disabled, etc. These utilities should be used over other threading utilities
13// whenever feasible.
14//
15//===----------------------------------------------------------------------===//
16
17#ifndef MLIR_IR_THREADING_H
18#define MLIR_IR_THREADING_H
19
#include "mlir/IR/Diagnostics.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/Support/ThreadPool.h"
#include <algorithm>
#include <atomic>
24
25namespace mlir {
26
27/// Invoke the given function on the elements between [begin, end)
28/// asynchronously. If the given function returns a failure when processing any
29/// of the elements, execution is stopped and a failure is returned from this
30/// function. This means that in the case of failure, not all elements of the
31/// range will be processed. Diagnostics emitted during processing are ordered
32/// relative to the element's position within [begin, end). If the provided
33/// context does not have multi-threading enabled, this function always
34/// processes elements sequentially.
35template <typename IteratorT, typename FuncT>
36LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin,
37 IteratorT end, FuncT &&func) {
38 unsigned numElements = static_cast<unsigned>(std::distance(begin, end));
39 if (numElements == 0)
40 return success();
41
42 // If multithreading is disabled or there is a small number of elements,
43 // process the elements directly on this thread.
44 if (!context->isMultithreadingEnabled() || numElements <= 1) {
45 for (; begin != end; ++begin)
46 if (failed(func(*begin)))
47 return failure();
48 return success();
49 }
50
51 // Build a wrapper processing function that properly initializes a parallel
52 // diagnostic handler.
53 ParallelDiagnosticHandler handler(context);
54 std::atomic<unsigned> curIndex(0);
55 std::atomic<bool> processingFailed(false);
56 auto processFn = [&] {
57 while (!processingFailed) {
58 unsigned index = curIndex++;
59 if (index >= numElements)
60 break;
61 handler.setOrderIDForThread(index);
62 if (failed(func(*std::next(begin, index))))
63 processingFailed = true;
64 handler.eraseOrderIDForThread();
65 }
66 };
67
68 // Otherwise, process the elements in parallel.
69 llvm::ThreadPoolInterface &threadPool = context->getThreadPool();
70 llvm::ThreadPoolTaskGroup tasksGroup(threadPool);
71 size_t numActions = std::min(a: numElements, b: threadPool.getMaxConcurrency());
72 for (unsigned i = 0; i < numActions; ++i)
73 tasksGroup.async(processFn);
74 // If the current thread is a worker thread from the pool, then waiting for
75 // the task group allows the current thread to also participate in processing
76 // tasks from the group, which avoid any deadlock/starvation.
77 tasksGroup.wait();
78 return failure(isFailure: processingFailed);
79}
80
81/// Invoke the given function on the elements in the provided range
82/// asynchronously. If the given function returns a failure when processing any
83/// of the elements, execution is stopped and a failure is returned from this
84/// function. This means that in the case of failure, not all elements of the
85/// range will be processed. Diagnostics emitted during processing are ordered
86/// relative to the element's position within the range. If the provided context
87/// does not have multi-threading enabled, this function always processes
88/// elements sequentially.
89template <typename RangeT, typename FuncT>
90LogicalResult failableParallelForEach(MLIRContext *context, RangeT &&range,
91 FuncT &&func) {
92 return failableParallelForEach(context, std::begin(range), std::end(range),
93 std::forward<FuncT>(func));
94}
95
96/// Invoke the given function on the elements between [begin, end)
97/// asynchronously. If the given function returns a failure when processing any
98/// of the elements, execution is stopped and a failure is returned from this
99/// function. This means that in the case of failure, not all elements of the
100/// range will be processed. Diagnostics emitted during processing are ordered
101/// relative to the element's position within [begin, end). If the provided
102/// context does not have multi-threading enabled, this function always
103/// processes elements sequentially.
104template <typename FuncT>
105LogicalResult failableParallelForEachN(MLIRContext *context, size_t begin,
106 size_t end, FuncT &&func) {
107 return failableParallelForEach(context, llvm::seq(Begin: begin, End: end),
108 std::forward<FuncT>(func));
109}
110
111/// Invoke the given function on the elements between [begin, end)
112/// asynchronously. Diagnostics emitted during processing are ordered relative
113/// to the element's position within [begin, end). If the provided context does
114/// not have multi-threading enabled, this function always processes elements
115/// sequentially.
116template <typename IteratorT, typename FuncT>
117void parallelForEach(MLIRContext *context, IteratorT begin, IteratorT end,
118 FuncT &&func) {
119 (void)failableParallelForEach(context, begin, end, [&](auto &&value) {
120 return func(std::forward<decltype(value)>(value)), success();
121 });
122}
123
124/// Invoke the given function on the elements in the provided range
125/// asynchronously. Diagnostics emitted during processing are ordered relative
126/// to the element's position within the range. If the provided context does not
127/// have multi-threading enabled, this function always processes elements
128/// sequentially.
129template <typename RangeT, typename FuncT>
130void parallelForEach(MLIRContext *context, RangeT &&range, FuncT &&func) {
131 parallelForEach(context, std::begin(range), std::end(range),
132 std::forward<FuncT>(func));
133}
134
135/// Invoke the given function on the elements between [begin, end)
136/// asynchronously. Diagnostics emitted during processing are ordered relative
137/// to the element's position within [begin, end). If the provided context does
138/// not have multi-threading enabled, this function always processes elements
139/// sequentially.
140template <typename FuncT>
141void parallelFor(MLIRContext *context, size_t begin, size_t end, FuncT &&func) {
142 parallelForEach(context, llvm::seq(Begin: begin, End: end), std::forward<FuncT>(func));
143}
144
145} // namespace mlir
146
147#endif // MLIR_IR_THREADING_H
148

// source code of mlir/include/mlir/IR/Threading.h