1//
2// Copyright (c) 2022 Klemens Morgenstern (klemens.morgenstern@gmx.net)
3//
4// Distributed under the Boost Software License, Version 1.0. (See accompanying
5// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6//
7
8#ifndef BOOST_COBALT_DETAIL_JOIN_HPP
9#define BOOST_COBALT_DETAIL_JOIN_HPP
10
11#include <boost/cobalt/detail/await_result_helper.hpp>
12#include <boost/cobalt/detail/exception.hpp>
13#include <boost/cobalt/detail/fork.hpp>
14#include <boost/cobalt/detail/forward_cancellation.hpp>
15#include <boost/cobalt/detail/util.hpp>
16#include <boost/cobalt/detail/wrapper.hpp>
17#include <boost/cobalt/task.hpp>
18#include <boost/cobalt/this_thread.hpp>
19
20#include <boost/asio/associated_cancellation_slot.hpp>
21#include <boost/asio/bind_cancellation_slot.hpp>
22#include <boost/asio/cancellation_signal.hpp>
23
24
25#include <boost/core/ignore_unused.hpp>
26#include <boost/intrusive_ptr.hpp>
27#include <boost/system/result.hpp>
28#include <boost/variant2/variant.hpp>
29
30#include <array>
31#include <coroutine>
32#include <algorithm>
33
34namespace boost::cobalt::detail
35{
36
37template<typename ... Args>
38struct join_variadic_impl
39{
40 using tuple_type = std::tuple<decltype(get_awaitable_type(std::declval<Args&&>()))...>;
41
42 join_variadic_impl(Args && ... args)
43 : args{std::forward<Args>(args)...}
44 {
45 }
46
47 std::tuple<Args...> args;
48
49 constexpr static std::size_t tuple_size = sizeof...(Args);
50
51 struct awaitable : fork::static_shared_state<256 * tuple_size>
52 {
53 template<std::size_t ... Idx>
54 awaitable(std::tuple<Args...> & args, std::index_sequence<Idx...>) :
55 aws(awaitable_type_getter<Args>(std::get<Idx>(args))...)
56 {
57 }
58
59 tuple_type aws;
60
61 std::array<asio::cancellation_signal, tuple_size> cancel_;
62 template<typename > constexpr static auto make_null() {return nullptr;};
63 std::array<asio::cancellation_signal*, tuple_size> cancel = {make_null<Args>()...};
64
65 constexpr static bool all_void = (std::is_void_v<co_await_result_t<Args>> && ...);
66 template<typename T>
67 using result_store_part =
68 std::optional<void_as_monostate<co_await_result_t<T>>>;
69
70 std::conditional_t<all_void,
71 variant2::monostate,
72 std::tuple<result_store_part<Args>...>> result;
73 std::exception_ptr error;
74
75 template<std::size_t Idx>
76 void cancel_step()
77 {
78 auto &r = cancel[Idx];
79 if (r)
80 std::exchange(r, nullptr)->emit(asio::cancellation_type::all);
81 }
82
83 void cancel_all()
84 {
85 mp11::mp_for_each<mp11::mp_iota_c<sizeof...(Args)>>
86 ([&](auto idx)
87 {
88 cancel_step<idx>();
89 });
90 }
91
92
93
94 template<std::size_t Idx>
95 void interrupt_await_step()
96 {
97 using type = std::tuple_element_t<Idx, tuple_type>;
98 using t = std::conditional_t<std::is_reference_v<std::tuple_element_t<Idx, std::tuple<Args...>>>,
99 type &,
100 type &&>;
101
102 if constexpr (interruptible<t>)
103 if (this->cancel[Idx] != nullptr)
104 static_cast<t>(std::get<Idx>(aws)).interrupt_await();
105 }
106
107 void interrupt_await()
108 {
109 mp11::mp_for_each<mp11::mp_iota_c<sizeof...(Args)>>
110 ([&](auto idx)
111 {
112 interrupt_await_step<idx>();
113 });
114 }
115
116
117 // GCC doesn't like member funs
118 template<std::size_t Idx>
119 static detail::fork await_impl(awaitable & this_)
120 try
121 {
122 auto & aw = std::get<Idx>(this_.aws);
123 // check manually if we're ready
124 auto rd = aw.await_ready();
125 if (!rd)
126 {
127 this_.cancel[Idx] = &this_.cancel_[Idx];
128 co_await this_.cancel[Idx]->slot();
129 // make sure the executor is set
130 co_await detail::fork::wired_up;
131 // do the await - this doesn't call await-ready again
132
133 if constexpr (std::is_void_v<decltype(aw.await_resume())>)
134 {
135 co_await aw;
136 if constexpr (!all_void)
137 std::get<Idx>(this_.result).emplace();
138 }
139 else
140 std::get<Idx>(this_.result).emplace(co_await aw);
141 }
142 else
143 {
144 if constexpr (std::is_void_v<decltype(aw.await_resume())>)
145 {
146 aw.await_resume();
147 if constexpr (!all_void)
148 std::get<Idx>(this_.result).emplace();
149 }
150 else
151 std::get<Idx>(this_.result).emplace(aw.await_resume());
152 }
153
154 }
155 catch(...)
156 {
157 if (!this_.error)
158 this_.error = std::current_exception();
159 this_.cancel_all();
160 }
161
162 std::array<detail::fork(*)(awaitable&), tuple_size> impls {
163 []<std::size_t ... Idx>(std::index_sequence<Idx...>)
164 {
165 return std::array<detail::fork(*)(awaitable&), tuple_size>{&await_impl<Idx>...};
166 }(std::make_index_sequence<tuple_size>{})
167 };
168
169 detail::fork last_forked;
170 std::size_t last_index = 0u;
171
172 bool await_ready()
173 {
174 while (last_index < tuple_size)
175 {
176 last_forked = impls[last_index++](*this);
177 if (!last_forked.done())
178 return false; // one coro didn't immediately complete!
179 }
180 last_forked.release();
181 return true;
182 }
183
184 template<typename H>
185 auto await_suspend(
186 std::coroutine_handle<H> h
187#if defined(BOOST_ASIO_ENABLE_HANDLER_TRACKING)
188 , const boost::source_location & loc = BOOST_CURRENT_LOCATION
189#endif
190 )
191 {
192#if defined(BOOST_ASIO_ENABLE_HANDLER_TRACKING)
193 this->loc = loc;
194#endif
195 this->exec = &detail::get_executor(h);
196 last_forked.release().resume();
197 while (last_index < tuple_size)
198 impls[last_index++](*this).release();
199
200 if (error)
201 cancel_all();
202
203 if (!this->outstanding_work()) // already done, resume rightaway.
204 return false;
205
206 // arm the cancel
207 assign_cancellation(
208 h,
209 [&](asio::cancellation_type ct)
210 {
211 for (auto cs : cancel)
212 if (cs)
213 cs->emit(ct);
214 });
215
216 this->coro.reset(h.address());
217 return true;
218 }
219
220#if _MSC_VER
221 BOOST_NOINLINE
222#endif
223 auto await_resume()
224 {
225 if (error)
226 std::rethrow_exception(error);
227 if constexpr(!all_void)
228 return mp11::tuple_transform(
229 []<typename T>(std::optional<T> & var)
230 -> T
231 {
232 BOOST_ASSERT(var.has_value());
233 return std::move(*var);
234 }, result);
235 }
236
237 auto await_resume(const as_tuple_tag &)
238 {
239 using t = decltype(await_resume());
240 if constexpr(!all_void)
241 {
242 if (error)
243 return std::make_tuple(error, t{});
244 else
245 return std::make_tuple(std::current_exception(),
246 mp11::tuple_transform(
247 []<typename T>(std::optional<T> & var)
248 -> T
249 {
250 BOOST_ASSERT(var.has_value());
251 return std::move(*var);
252 }, result));
253 }
254 else
255 return std::make_tuple(args&: error);
256 }
257
258 auto await_resume(const as_result_tag &)
259 {
260 using t = decltype(await_resume());
261 using rt = system::result<t, std::exception_ptr>;
262 if (error)
263 return rt(system::in_place_error, error);
264
265 if constexpr(!all_void)
266 return mp11::tuple_transform(
267 []<typename T>(std::optional<T> & var)
268 -> rt
269 {
270 BOOST_ASSERT(var.has_value());
271 return std::move(*var);
272 }, result);
273 else
274 return rt{system::in_place_value};
275 }
276 };
277 awaitable operator co_await() &&
278 {
279 return awaitable(args, std::make_index_sequence<sizeof...(Args)>{});
280 }
281};
282
283template<typename Range>
284struct join_ranged_impl
285{
286 Range aws;
287
288 using result_type = co_await_result_t<std::decay_t<decltype(*std::begin(std::declval<Range>()))>>;
289
290 constexpr static std::size_t result_size =
291 sizeof(std::conditional_t<std::is_void_v<result_type>, variant2::monostate, result_type>);
292
293 struct awaitable : fork::shared_state
294 {
295 struct dummy
296 {
297 template<typename ... Args>
298 dummy(Args && ...) {}
299 };
300
301 using type = std::decay_t<decltype(*std::begin(std::declval<Range>()))>;
302#if !defined(BOOST_COBALT_NO_PMR)
303 pmr::polymorphic_allocator<void> alloc{&resource};
304
305 std::conditional_t<awaitable_type<type>, Range &,
306 pmr::vector<co_awaitable_type<type>>> aws;
307
308 pmr::vector<bool> ready{std::size(aws), alloc};
309 pmr::vector<asio::cancellation_signal> cancel_{std::size(aws), alloc};
310 pmr::vector<asio::cancellation_signal*> cancel{std::size(aws), alloc};
311
312
313
314 std::conditional_t<
315 std::is_void_v<result_type>,
316 dummy,
317 pmr::vector<std::optional<void_as_monostate<result_type>>>>
318 result{
319 cancel.size(),
320 alloc};
321#else
322 std::allocator<void> alloc;
323 std::conditional_t<awaitable_type<type>, Range &, std::vector<co_awaitable_type<type>>> aws;
324
325 std::vector<bool> ready{std::size(aws), alloc};
326 std::vector<asio::cancellation_signal> cancel_{std::size(aws), alloc};
327 std::vector<asio::cancellation_signal*> cancel{std::size(aws), alloc};
328
329 std::conditional_t<
330 std::is_void_v<result_type>,
331 dummy,
332 std::vector<std::optional<void_as_monostate<result_type>>>>
333 result{
334 cancel.size(),
335 alloc};
336#endif
337 std::exception_ptr error;
338
339 awaitable(Range & aws_, std::false_type /* needs operator co_await */)
340 : fork::shared_state((512 + sizeof(co_awaitable_type<type>) + result_size) * std::size(aws_))
341 , aws{alloc}
342 , ready{std::size(aws_), alloc}
343 , cancel_{std::size(aws_), alloc}
344 , cancel{std::size(aws_), alloc}
345 {
346 aws.reserve(std::size(aws_));
347 for (auto && a : aws_)
348 {
349 using a_0 = std::decay_t<decltype(a)>;
350 using a_t = std::conditional_t<
351 std::is_lvalue_reference_v<Range>, a_0 &, a_0 &&>;
352 aws.emplace_back(awaitable_type_getter<a_t>(static_cast<a_t>(a)));
353 }
354
355 std::transform(std::begin(this->aws),
356 std::end(this->aws),
357 std::begin(cont&: ready),
358 [](auto & aw) {return aw.await_ready();});
359 }
360 awaitable(Range & aws, std::true_type /* needs operator co_await */)
361 : fork::shared_state((512 + sizeof(co_awaitable_type<type>) + result_size) * std::size(aws))
362 , aws(aws)
363 {
364 std::transform(std::begin(aws), std::end(aws), std::begin(cont&: ready), [](auto & aw) {return aw.await_ready();});
365 }
366
367 awaitable(Range & aws)
368 : awaitable(aws, std::bool_constant<awaitable_type<type>>{})
369 {
370 }
371
372 void cancel_all()
373 {
374 for (auto & r : cancel)
375 if (r)
376 std::exchange(obj&: r, new_val: nullptr)->emit(type: asio::cancellation_type::all);
377 }
378
379 void interrupt_await()
380 {
381 using t = std::conditional_t<std::is_reference_v<Range>,
382 co_awaitable_type<type> &,
383 co_awaitable_type<type> &&>;
384
385 if constexpr (interruptible<t>)
386 {
387 std::size_t idx = 0u;
388 for (auto & aw : aws)
389 if (cancel[idx])
390 static_cast<t>(aw).interrupt_await();
391 }
392 }
393
394
395 static detail::fork await_impl(awaitable & this_, std::size_t idx)
396 try
397 {
398 auto & aw = *std::next(std::begin(this_.aws), idx);
399 auto rd = aw.await_ready();
400 if (!rd)
401 {
402 this_.cancel[idx] = &this_.cancel_[idx];
403 co_await this_.cancel[idx]->slot();
404 co_await detail::fork::wired_up;
405 if constexpr (std::is_void_v<decltype(aw.await_resume())>)
406 co_await aw;
407 else
408 this_.result[idx].emplace(co_await aw);
409 }
410 else
411 {
412 if constexpr (std::is_void_v<decltype(aw.await_resume())>)
413 aw.await_resume();
414 else
415 this_.result[idx].emplace(aw.await_resume());
416 }
417 }
418 catch(...)
419 {
420 if (!this_.error)
421 this_.error = std::current_exception();
422 this_.cancel_all();
423 }
424
425 detail::fork last_forked;
426 std::size_t last_index = 0u;
427
428 bool await_ready()
429 {
430 while (last_index < cancel.size())
431 {
432 last_forked = await_impl(this_&: *this, idx: last_index++);
433 if (!last_forked.done())
434 return false; // one coro didn't immediately complete!
435 }
436 last_forked.release();
437 return true;
438 }
439
440
441 template<typename H>
442 auto await_suspend(
443 std::coroutine_handle<H> h
444#if defined(BOOST_ASIO_ENABLE_HANDLER_TRACKING)
445 , const boost::source_location & loc = BOOST_CURRENT_LOCATION
446#endif
447 )
448 {
449#if defined(BOOST_ASIO_ENABLE_HANDLER_TRACKING)
450 this->loc = loc;
451#endif
452 exec = &detail::get_executor(h);
453
454 last_forked.release().resume();
455 while (last_index < cancel.size())
456 await_impl(this_&: *this, idx: last_index++).release();
457
458 if (error)
459 cancel_all();
460
461 if (!this->outstanding_work()) // already done, resume right away.
462 return false;
463
464 // arm the cancel
465 assign_cancellation(
466 h,
467 [&](asio::cancellation_type ct)
468 {
469 for (auto cs : cancel)
470 if (cs)
471 cs->emit(type: ct);
472 });
473
474
475 this->coro.reset(h.address());
476 return true;
477 }
478
479 auto await_resume(const as_tuple_tag & )
480 {
481#if defined(BOOST_COBALT_NO_PMR)
482 std::vector<result_type> rr;
483#else
484 pmr::vector<result_type> rr{this_thread::get_allocator()};
485#endif
486
487 if (error)
488 return std::make_tuple(error, rr);
489 if constexpr (!std::is_void_v<result_type>)
490 {
491 rr.reserve(result.size());
492 for (auto & t : result)
493 rr.push_back(*std::move(t));
494 return std::make_tuple(std::exception_ptr(), std::move(rr));
495 }
496 }
497
498 auto await_resume(const as_result_tag & )
499 {
500#if defined(BOOST_COBALT_NO_PMR)
501 std::vector<result_type> rr;
502#else
503 pmr::vector<result_type> rr{this_thread::get_allocator()};
504#endif
505
506 if (error)
507 return system::result<decltype(rr), std::exception_ptr>(error);
508 if constexpr (!std::is_void_v<result_type>)
509 {
510 rr.reserve(result.size());
511 for (auto & t : result)
512 rr.push_back(*std::move(t));
513 return rr;
514 }
515 }
516
517#if _MSC_VER
518 BOOST_NOINLINE
519#endif
520 auto await_resume()
521 {
522 if (error)
523 std::rethrow_exception(error);
524 if constexpr (!std::is_void_v<result_type>)
525 {
526#if defined(BOOST_COBALT_NO_PMR)
527 std::vector<result_type> rr;
528#else
529 pmr::vector<result_type> rr{this_thread::get_allocator()};
530#endif
531 rr.reserve(result.size());
532 for (auto & t : result)
533 rr.push_back(*std::move(t));
534 return rr;
535 }
536 }
537 };
538 awaitable operator co_await() && {return awaitable{aws};}
539};
540
541}
542
543
544#endif //BOOST_COBALT_DETAIL_JOIN_HPP
545

// source code of boost/libs/cobalt/include/boost/cobalt/detail/join.hpp