1 | // Copyright 2009-2021 Intel Corporation |
2 | // SPDX-License-Identifier: Apache-2.0 |
3 | |
4 | #pragma once |
5 | |
6 | #include "../sys/platform.h" |
7 | #include "../sys/alloc.h" |
8 | #include "../sys/barrier.h" |
9 | #include "../sys/thread.h" |
10 | #include "../sys/mutex.h" |
11 | #include "../sys/condition.h" |
12 | #include "../sys/ref.h" |
13 | #include "../sys/atomic.h" |
14 | #include "../math/range.h" |
15 | #include "../../include/embree3/rtcore.h" |
16 | |
17 | #include <list> |
18 | |
19 | namespace embree |
20 | { |
21 | |
/* The tasking system exports some symbols to be used by the tutorials. Thus we
   hide it also in the API namespace when requested. */
24 | RTC_NAMESPACE_BEGIN |
25 | |
26 | struct TaskScheduler : public RefCount |
27 | { |
28 | ALIGNED_STRUCT_(64); |
29 | friend class Device; |
30 | |
31 | static const size_t TASK_STACK_SIZE = 4*1024; //!< task structure stack |
32 | static const size_t CLOSURE_STACK_SIZE = 512*1024; //!< stack for task closures |
33 | |
34 | struct Thread; |
35 | |
36 | /*! virtual interface for all tasks */ |
37 | struct TaskFunction { |
38 | virtual void execute() = 0; |
39 | }; |
40 | |
41 | /*! builds a task interface from a closure */ |
42 | template<typename Closure> |
43 | struct ClosureTaskFunction : public TaskFunction |
44 | { |
45 | Closure closure; |
46 | __forceinline ClosureTaskFunction (const Closure& closure) : closure(closure) {} |
47 | void execute() { closure(); }; |
48 | }; |
49 | |
50 | struct __aligned(64) Task |
51 | { |
52 | /*! states a task can be in */ |
53 | enum { DONE, INITIALIZED }; |
54 | |
55 | /*! switch from one state to another */ |
56 | __forceinline void switch_state(int from, int to) |
57 | { |
58 | __memory_barrier(); |
59 | MAYBE_UNUSED bool success = state.compare_exchange_strong(i1&: from,i2: to); |
60 | assert(success); |
61 | } |
62 | |
63 | /*! try to switch from one state to another */ |
64 | __forceinline bool try_switch_state(int from, int to) { |
65 | __memory_barrier(); |
66 | return state.compare_exchange_strong(i1&: from,i2: to); |
67 | } |
68 | |
69 | /*! increment/decrement dependency counter */ |
70 | void add_dependencies(int n) { |
71 | dependencies+=n; |
72 | } |
73 | |
74 | /*! initialize all tasks to DONE state by default */ |
75 | __forceinline Task() |
76 | : state(DONE) {} |
77 | |
78 | /*! construction of new task */ |
79 | __forceinline Task (TaskFunction* closure, Task* parent, size_t stackPtr, size_t N) |
80 | : dependencies(1), stealable(true), closure(closure), parent(parent), stackPtr(stackPtr), N(N) |
81 | { |
82 | if (parent) parent->add_dependencies(n: +1); |
83 | switch_state(from: DONE,to: INITIALIZED); |
84 | } |
85 | |
86 | /*! construction of stolen task, stealing thread will decrement initial dependency */ |
87 | __forceinline Task (TaskFunction* closure, Task* parent) |
88 | : dependencies(1), stealable(false), closure(closure), parent(parent), stackPtr(-1), N(1) |
89 | { |
90 | switch_state(from: DONE,to: INITIALIZED); |
91 | } |
92 | |
93 | /*! try to steal this task */ |
94 | bool try_steal(Task& child) |
95 | { |
96 | if (!stealable) return false; |
97 | if (!try_switch_state(from: INITIALIZED,to: DONE)) return false; |
98 | new (&child) Task(closure, this); |
99 | return true; |
100 | } |
101 | |
102 | /*! run this task */ |
103 | dll_export void run(Thread& thread); |
104 | |
105 | void run_internal(Thread& thread); |
106 | |
107 | public: |
108 | std::atomic<int> state; //!< state this task is in |
109 | std::atomic<int> dependencies; //!< dependencies to wait for |
110 | std::atomic<bool> stealable; //!< true if task can be stolen |
111 | TaskFunction* closure; //!< the closure to execute |
112 | Task* parent; //!< parent task to signal when we are finished |
113 | size_t stackPtr; //!< stack location where closure is stored |
114 | size_t N; //!< approximative size of task |
115 | }; |
116 | |
117 | struct TaskQueue |
118 | { |
119 | TaskQueue () |
120 | : left(0), right(0), stackPtr(0) {} |
121 | |
122 | __forceinline void* alloc(size_t bytes, size_t align = 64) |
123 | { |
124 | size_t ofs = bytes + ((align - stackPtr) & (align-1)); |
125 | if (stackPtr + ofs > CLOSURE_STACK_SIZE) |
126 | throw std::runtime_error("closure stack overflow" ); |
127 | stackPtr += ofs; |
128 | return &stack[stackPtr-bytes]; |
129 | } |
130 | |
131 | template<typename Closure> |
132 | __forceinline void push_right(Thread& thread, const size_t size, const Closure& closure) |
133 | { |
134 | if (right >= TASK_STACK_SIZE) |
135 | throw std::runtime_error("task stack overflow" ); |
136 | |
137 | /* allocate new task on right side of stack */ |
138 | size_t oldStackPtr = stackPtr; |
139 | TaskFunction* func = new (alloc(bytes: sizeof(ClosureTaskFunction<Closure>))) ClosureTaskFunction<Closure>(closure); |
140 | new (&tasks[right]) Task(func,thread.task,oldStackPtr,size); |
141 | right++; |
142 | |
143 | /* also move left pointer */ |
144 | if (left >= right-1) left = right-1; |
145 | } |
146 | |
147 | dll_export bool execute_local(Thread& thread, Task* parent); |
148 | bool execute_local_internal(Thread& thread, Task* parent); |
149 | bool steal(Thread& thread); |
150 | size_t getTaskSizeAtLeft(); |
151 | |
152 | bool empty() { return right == 0; } |
153 | |
154 | public: |
155 | |
156 | /* task stack */ |
157 | Task tasks[TASK_STACK_SIZE]; |
158 | __aligned(64) std::atomic<size_t> left; //!< threads steal from left |
159 | __aligned(64) std::atomic<size_t> right; //!< new tasks are added to the right |
160 | |
161 | /* closure stack */ |
162 | __aligned(64) char stack[CLOSURE_STACK_SIZE]; |
163 | size_t stackPtr; |
164 | }; |
165 | |
166 | /*! thread local structure for each thread */ |
167 | struct Thread |
168 | { |
169 | ALIGNED_STRUCT_(64); |
170 | |
171 | Thread (size_t threadIndex, const Ref<TaskScheduler>& scheduler) |
172 | : threadIndex(threadIndex), task(nullptr), scheduler(scheduler) {} |
173 | |
174 | __forceinline size_t threadCount() { |
175 | return scheduler->threadCounter; |
176 | } |
177 | |
178 | size_t threadIndex; //!< ID of this thread |
179 | TaskQueue tasks; //!< local task queue |
180 | Task* task; //!< current active task |
181 | Ref<TaskScheduler> scheduler; //!< pointer to task scheduler |
182 | }; |
183 | |
184 | /*! pool of worker threads */ |
185 | struct ThreadPool |
186 | { |
187 | ThreadPool (bool set_affinity); |
188 | ~ThreadPool (); |
189 | |
190 | /*! starts the threads */ |
191 | dll_export void startThreads(); |
192 | |
193 | /*! sets number of threads to use */ |
194 | void setNumThreads(size_t numThreads, bool startThreads = false); |
195 | |
196 | /*! adds a task scheduler object for scheduling */ |
197 | dll_export void add(const Ref<TaskScheduler>& scheduler); |
198 | |
199 | /*! remove the task scheduler object again */ |
200 | dll_export void remove(const Ref<TaskScheduler>& scheduler); |
201 | |
202 | /*! returns number of threads of the thread pool */ |
203 | size_t size() const { return numThreads; } |
204 | |
205 | /*! main loop for all threads */ |
206 | void thread_loop(size_t threadIndex); |
207 | |
208 | private: |
209 | std::atomic<size_t> numThreads; |
210 | std::atomic<size_t> numThreadsRunning; |
211 | bool set_affinity; |
212 | std::atomic<bool> running; |
213 | std::vector<thread_t> threads; |
214 | |
215 | private: |
216 | MutexSys mutex; |
217 | ConditionSys condition; |
218 | std::list<Ref<TaskScheduler> > schedulers; |
219 | }; |
220 | |
221 | TaskScheduler (); |
222 | ~TaskScheduler (); |
223 | |
224 | /*! initializes the task scheduler */ |
225 | static void create(size_t numThreads, bool set_affinity, bool start_threads); |
226 | |
227 | /*! destroys the task scheduler again */ |
228 | static void destroy(); |
229 | |
230 | /*! lets new worker threads join the tasking system */ |
231 | void join(); |
232 | void reset(); |
233 | |
234 | /*! let a worker thread allocate a thread index */ |
235 | dll_export ssize_t allocThreadIndex(); |
236 | |
237 | /*! wait for some number of threads available (threadCount includes main thread) */ |
238 | void wait_for_threads(size_t threadCount); |
239 | |
240 | /*! thread loop for all worker threads */ |
241 | std::exception_ptr thread_loop(size_t threadIndex); |
242 | |
243 | /*! steals a task from a different thread */ |
244 | bool steal_from_other_threads(Thread& thread); |
245 | |
246 | template<typename Predicate, typename Body> |
247 | static void steal_loop(Thread& thread, const Predicate& pred, const Body& body); |
248 | |
249 | /* spawn a new task at the top of the threads task stack */ |
250 | template<typename Closure> |
251 | void spawn_root(const Closure& closure, size_t size = 1, bool useThreadPool = true) |
252 | { |
253 | if (useThreadPool) startThreads(); |
254 | |
255 | size_t threadIndex = allocThreadIndex(); |
256 | std::unique_ptr<Thread> mthread(new Thread(threadIndex,this)); // too large for stack allocation |
257 | Thread& thread = *mthread; |
258 | assert(threadLocal[threadIndex].load() == nullptr); |
259 | threadLocal[threadIndex] = &thread; |
260 | Thread* oldThread = swapThread(thread: &thread); |
261 | thread.tasks.push_right(thread,size,closure); |
262 | { |
263 | Lock<MutexSys> lock(mutex); |
264 | anyTasksRunning++; |
265 | hasRootTask = true; |
266 | condition.notify_all(); |
267 | } |
268 | |
269 | if (useThreadPool) addScheduler(scheduler: this); |
270 | |
271 | while (thread.tasks.execute_local(thread,parent: nullptr)); |
272 | anyTasksRunning--; |
273 | if (useThreadPool) removeScheduler(scheduler: this); |
274 | |
275 | threadLocal[threadIndex] = nullptr; |
276 | swapThread(thread: oldThread); |
277 | |
278 | /* remember exception to throw */ |
279 | std::exception_ptr except = nullptr; |
280 | if (cancellingException != nullptr) except = cancellingException; |
281 | |
282 | /* wait for all threads to terminate */ |
283 | threadCounter--; |
284 | while (threadCounter > 0) yield(); |
285 | cancellingException = nullptr; |
286 | |
287 | /* re-throw proper exception */ |
288 | if (except != nullptr) |
289 | std::rethrow_exception(except); |
290 | } |
291 | |
292 | /* spawn a new task at the top of the threads task stack */ |
293 | template<typename Closure> |
294 | static __forceinline void spawn(size_t size, const Closure& closure) |
295 | { |
296 | Thread* thread = TaskScheduler::thread(); |
297 | if (likely(thread != nullptr)) thread->tasks.push_right(*thread,size,closure); |
298 | else instance()->spawn_root(closure,size); |
299 | } |
300 | |
301 | /* spawn a new task at the top of the threads task stack */ |
302 | template<typename Closure> |
303 | static __forceinline void spawn(const Closure& closure) { |
304 | spawn(1,closure); |
305 | } |
306 | |
307 | /* spawn a new task set */ |
308 | template<typename Index, typename Closure> |
309 | static void spawn(const Index begin, const Index end, const Index blockSize, const Closure& closure) |
310 | { |
311 | spawn(end-begin, [=]() |
312 | { |
313 | if (end-begin <= blockSize) { |
314 | return closure(range<Index>(begin,end)); |
315 | } |
316 | const Index center = (begin+end)/2; |
317 | spawn(begin,center,blockSize,closure); |
318 | spawn(center,end ,blockSize,closure); |
319 | wait(); |
320 | }); |
321 | } |
322 | |
323 | /* work on spawned subtasks and wait until all have finished */ |
324 | dll_export static bool wait(); |
325 | |
326 | /* returns the ID of the current thread */ |
327 | dll_export static size_t threadID(); |
328 | |
329 | /* returns the index (0..threadCount-1) of the current thread */ |
330 | dll_export static size_t threadIndex(); |
331 | |
332 | /* returns the total number of threads */ |
333 | dll_export static size_t threadCount(); |
334 | |
335 | private: |
336 | |
337 | /* returns the thread local task list of this worker thread */ |
338 | dll_export static Thread* thread(); |
339 | |
340 | /* sets the thread local task list of this worker thread */ |
341 | dll_export static Thread* swapThread(Thread* thread); |
342 | |
343 | /*! returns the taskscheduler object to be used by the master thread */ |
344 | dll_export static TaskScheduler* instance(); |
345 | |
346 | /*! starts the threads */ |
347 | dll_export static void startThreads(); |
348 | |
349 | /*! adds a task scheduler object for scheduling */ |
350 | dll_export static void addScheduler(const Ref<TaskScheduler>& scheduler); |
351 | |
352 | /*! remove the task scheduler object again */ |
353 | dll_export static void removeScheduler(const Ref<TaskScheduler>& scheduler); |
354 | |
355 | private: |
356 | std::vector<atomic<Thread*>> threadLocal; |
357 | std::atomic<size_t> threadCounter; |
358 | std::atomic<size_t> anyTasksRunning; |
359 | std::atomic<bool> hasRootTask; |
360 | std::exception_ptr cancellingException; |
361 | MutexSys mutex; |
362 | ConditionSys condition; |
363 | |
364 | private: |
365 | static size_t g_numThreads; |
366 | static __thread TaskScheduler* g_instance; |
367 | static __thread Thread* thread_local_thread; |
368 | static ThreadPool* threadPool; |
369 | }; |
370 | |
371 | RTC_NAMESPACE_END |
372 | |
373 | #if defined(RTC_NAMESPACE) |
374 | using RTC_NAMESPACE::TaskScheduler; |
375 | #endif |
376 | } |
377 | |