1//
2// Copyright (c) 2022 Klemens Morgenstern (klemens.morgenstern@gmx.net)
3//
4// Distributed under the Boost Software License, Version 1.0. (See accompanying
5// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6//
7
8#ifndef BOOST_COBALT_DETAIL_RACE_HPP
9#define BOOST_COBALT_DETAIL_RACE_HPP
10
11#include <boost/cobalt/detail/await_result_helper.hpp>
12#include <boost/cobalt/detail/fork.hpp>
13#include <boost/cobalt/detail/handler.hpp>
14#include <boost/cobalt/detail/forward_cancellation.hpp>
15#include <boost/cobalt/result.hpp>
16#include <boost/cobalt/this_thread.hpp>
17#include <boost/cobalt/detail/util.hpp>
18
19#include <boost/asio/bind_allocator.hpp>
20#include <boost/asio/bind_cancellation_slot.hpp>
21#include <boost/asio/bind_executor.hpp>
22#include <boost/asio/cancellation_signal.hpp>
23#include <boost/asio/associated_cancellation_slot.hpp>
24
25
26#include <boost/intrusive_ptr.hpp>
27#include <boost/core/demangle.hpp>
28#include <boost/core/span.hpp>
29#include <boost/variant2/variant.hpp>
30
31#include <coroutine>
32#include <optional>
33#include <algorithm>
34
35
36namespace boost::cobalt::detail
37{
38
// Tag type used in place of a uniform random bit generator (URBG).
// When the URBG template argument of a race implementation is this tag,
// the launch order is not shuffled (see the `std::is_same_v<URBG, left_race_tag>`
// checks below), i.e. the awaitables are started in the order they were given.
struct left_race_tag {};
40
41// helpers it determining the type of things;
42template<typename Base, // range of aw
43 typename Awaitable = Base>
44struct race_traits
45{
46 // for a ranges race this is based on the range, not the AW in it.
47 constexpr static bool is_lvalue = std::is_lvalue_reference_v<Base>;
48
49 // what the value is supposed to be cast to before the co_await_operator
50 using awaitable = std::conditional_t<is_lvalue, std::decay_t<Awaitable> &, Awaitable &&>;
51
52 // do we need operator co_await
53 constexpr static bool is_actual = awaitable_type<awaitable>;
54
55 // the type with .await_ functions & interrupt_await
56 using actual_awaitable
57 = std::conditional_t<
58 is_actual,
59 awaitable,
60 decltype(get_awaitable_type(std::declval<awaitable>()))>;
61
62 // the type to be used with interruptible
63 using interruptible_type
64 = std::conditional_t<
65 std::is_lvalue_reference_v<Base>,
66 std::decay_t<actual_awaitable> &,
67 std::decay_t<actual_awaitable> &&>;
68
69 constexpr static bool interruptible =
70 cobalt::interruptible<interruptible_type>;
71
72 static void do_interrupt(std::decay_t<actual_awaitable> & aw)
73 {
74 if constexpr (interruptible)
75 static_cast<interruptible_type>(aw).interrupt_await();
76 }
77
78};
79
// Type-erased handle used to interrupt an awaitable that is currently being
// awaited. Implementations forward interrupt_await() to the concrete awaitable.
//
// Objects are only ever referenced through non-owning pointers and are never
// destroyed through this base, so the destructor is protected and non-virtual:
// deleting through an `interruptible_base*` is rejected at compile time
// instead of being undefined behaviour.
struct interruptible_base
{
  // Ask the underlying awaitable to interrupt its pending wait.
  virtual void interrupt_await() = 0;

 protected:
  ~interruptible_base() = default;
};
84
85template<asio::cancellation_type Ct, typename URBG, typename ... Args>
86struct race_variadic_impl
87{
88
89 template<typename URBG_>
90 race_variadic_impl(URBG_ && g, Args && ... args)
91 : args{std::forward<Args>(args)...}, g(std::forward<URBG_>(g))
92 {
93 }
94
95 std::tuple<Args...> args;
96 URBG g;
97
98 constexpr static std::size_t tuple_size = sizeof...(Args);
99
100 struct awaitable : fork::static_shared_state<256 * tuple_size>
101 {
102
103#if !defined(BOOST_ASIO_ENABLE_HANDLER_TRACKING)
104 boost::source_location loc;
105#endif
106
107 template<std::size_t ... Idx>
108 awaitable(std::tuple<Args...> & args, URBG & g, std::index_sequence<Idx...>) :
109 aws{args}
110 {
111 if constexpr (!std::is_same_v<URBG, left_race_tag>)
112 std::shuffle(impls.begin(), impls.end(), g);
113 std::fill(working.begin(), working.end(), nullptr);
114 }
115
116 std::tuple<Args...> & aws;
117 std::array<asio::cancellation_signal, tuple_size> cancel_;
118
119 template<typename > constexpr static auto make_null() {return nullptr;};
120 std::array<asio::cancellation_signal*, tuple_size> cancel = {make_null<Args>()...};
121
122 std::array<interruptible_base*, tuple_size> working;
123
124 std::size_t index{std::numeric_limits<std::size_t>::max()};
125
126 constexpr static bool all_void = (std::is_void_v<co_await_result_t<Args>> && ... );
127 std::optional<variant2::variant<void_as_monostate<co_await_result_t<Args>>...>> result;
128 std::exception_ptr error;
129
130 bool has_result() const
131 {
132 return index != std::numeric_limits<std::size_t>::max();
133 }
134
135 void cancel_all()
136 {
137 interrupt_await();
138 for (auto i = 0u; i < tuple_size; i++)
139 if (auto &r = cancel[i]; r)
140 std::exchange(r, nullptr)->emit(Ct);
141 }
142
143 void interrupt_await()
144 {
145 for (auto i : working)
146 if (i)
147 i->interrupt_await();
148 }
149
150 template<typename T, typename Error>
151 void assign_error(system::result<T, Error> & res)
152 try
153 {
154 std::move(res).value(loc);
155 }
156 catch(...)
157 {
158 error = std::current_exception();
159 }
160
161 template<typename T>
162 void assign_error(system::result<T, std::exception_ptr> & res)
163 {
164 error = std::move(res).error();
165 }
166
167 template<std::size_t Idx>
168 static detail::fork await_impl(awaitable & this_)
169 try
170 {
171 using traits = race_traits<mp11::mp_at_c<mp11::mp_list<Args...>, Idx>>;
172
173 typename traits::actual_awaitable aw_{
174 get_awaitable_type(
175 static_cast<typename traits::awaitable>(std::get<Idx>(this_.aws))
176 )
177 };
178
179 as_result_t aw{aw_};
180
181
182 struct interruptor final : interruptible_base
183 {
184 std::decay_t<typename traits::actual_awaitable> & aw;
185 interruptor(std::decay_t<typename traits::actual_awaitable> & aw) : aw(aw) {}
186 void interrupt_await() override
187 {
188 traits::do_interrupt(aw);
189 }
190 };
191 interruptor in{aw_};
192 //if constexpr (traits::interruptible)
193 this_.working[Idx] = &in;
194
195 auto transaction = [&this_, idx = Idx] {
196 if (this_.has_result())
197 boost::throw_exception(e: std::runtime_error("Another transaction already started"));
198 this_.cancel[idx] = nullptr;
199 // reserve the index early bc
200 this_.index = idx;
201 this_.cancel_all();
202 };
203
204 co_await fork::set_transaction_function(transaction);
205 // check manually if we're ready
206 auto rd = aw.await_ready();
207 if (!rd)
208 {
209 this_.cancel[Idx] = &this_.cancel_[Idx];
210 co_await this_.cancel[Idx]->slot();
211 // make sure the executor is set
212 co_await detail::fork::wired_up;
213
214 // do the await - this doesn't call await-ready again
215 if constexpr (std::is_void_v<decltype(aw_.await_resume())>)
216 {
217 auto res = co_await aw;
218 if (!this_.has_result())
219 {
220 this_.index = Idx;
221 if (res.has_error())
222 this_.assign_error(res);
223 }
224 if constexpr(!all_void)
225 if (this_.index == Idx && !res.has_error())
226 this_.result.emplace(variant2::in_place_index<Idx>);
227 }
228 else
229 {
230 auto val = co_await aw;
231 if (!this_.has_result())
232 this_.index = Idx;
233 if (this_.index == Idx)
234 {
235 if (val.has_error())
236 this_.assign_error(val);
237 else
238 this_.result.emplace(variant2::in_place_index<Idx>, *std::move(val));
239 }
240 }
241 this_.cancel[Idx] = nullptr;
242 }
243 else
244 {
245 if (!this_.has_result())
246 this_.index = Idx;
247 if constexpr (std::is_void_v<decltype(aw_.await_resume())>)
248 {
249 auto res = aw.await_resume();
250 if (this_.index == Idx)
251 {
252 if (res.has_error())
253 this_.assign_error(res);
254 else
255 this_.result.emplace(variant2::in_place_index<Idx>);
256 }
257 }
258 else
259 {
260 if (this_.index == Idx)
261 {
262 auto res = aw.await_resume();
263 if (res.has_error())
264 this_.assign_error(res);
265 else
266 this_.result.emplace(variant2::in_place_index<Idx>, *std::move(res));
267 }
268 else
269 aw.await_resume();
270 }
271 this_.cancel[Idx] = nullptr;
272 }
273 this_.cancel_all();
274 this_.working[Idx] = nullptr;
275 }
276 catch(...)
277 {
278 if (!this_.has_result())
279 this_.index = Idx;
280 if (this_.index == Idx)
281 this_.error = std::current_exception();
282 this_.working[Idx] = nullptr;
283 }
284
285 std::array<detail::fork(*)(awaitable&), tuple_size> impls {
286 []<std::size_t ... Idx>(std::index_sequence<Idx...>)
287 {
288 return std::array<detail::fork(*)(awaitable&), tuple_size>{&await_impl<Idx>...};
289 }(std::make_index_sequence<tuple_size>{})
290 };
291
292 detail::fork last_forked;
293
294 bool await_ready()
295 {
296 last_forked = impls[0](*this);
297 return last_forked.done();
298 }
299
300 template<typename H>
301 auto await_suspend(
302 std::coroutine_handle<H> h,
303 const boost::source_location & loc = BOOST_CURRENT_LOCATION)
304 {
305 this->loc = loc;
306
307 this->exec = &cobalt::detail::get_executor(h);
308 last_forked.release().resume();
309
310 if (!this->outstanding_work()) // already done, resume rightaway.
311 return false;
312
313 for (std::size_t idx = 1u;
314 idx < tuple_size; idx++) // we'
315 {
316 auto l = impls[idx](*this);
317 const auto d = l.done();
318 l.release();
319 if (d)
320 break;
321 }
322
323 if (!this->outstanding_work()) // already done, resume rightaway.
324 return false;
325
326 // arm the cancel
327 assign_cancellation(
328 h,
329 [&](asio::cancellation_type ct)
330 {
331 for (auto & cs : cancel)
332 if (cs)
333 cs->emit(ct);
334 });
335
336 this->coro.reset(h.address());
337 return true;
338 }
339
340#if _MSC_VER
341 BOOST_NOINLINE
342#endif
343 auto await_resume()
344 {
345 if (error)
346 std::rethrow_exception(error);
347 if constexpr (all_void)
348 return index;
349 else
350 return std::move(*result);
351 }
352
353 auto await_resume(const as_tuple_tag &)
354 {
355 if constexpr (all_void)
356 return std::make_tuple(args&: error, args&: index);
357 else
358 return std::make_tuple(error, std::move(*result));
359 }
360
361 auto await_resume(const as_result_tag & )
362 -> system::result<std::conditional_t<all_void, std::size_t, variant2::variant<void_as_monostate<co_await_result_t<Args>>...>>, std::exception_ptr>
363 {
364 if (error)
365 return {system::in_place_error, error};
366 if constexpr (all_void)
367 return {system::in_place_value, index};
368 else
369 return {system::in_place_value, std::move(*result)};
370 }
371 };
372 awaitable operator co_await() &&
373 {
374 return awaitable{args, g, std::make_index_sequence<tuple_size>{}};
375 }
376};
377
378
379template<asio::cancellation_type Ct, typename URBG, typename Range>
380struct race_ranged_impl
381{
382
383 using result_type = co_await_result_t<std::decay_t<decltype(*std::begin(std::declval<Range>()))>>;
384 template<typename URBG_>
385 race_ranged_impl(URBG_ && g, Range && rng)
386 : range{std::forward<Range>(rng)}, g(std::forward<URBG_>(g))
387 {
388 }
389
390 Range range;
391 URBG g;
392
393 struct awaitable : fork::shared_state
394 {
395
396#if !defined(BOOST_ASIO_ENABLE_HANDLER_TRACKING)
397 boost::source_location loc;
398#endif
399
400 using type = std::decay_t<decltype(*std::begin(std::declval<Range>()))>;
401 using traits = race_traits<Range, type>;
402
403 std::size_t index{std::numeric_limits<std::size_t>::max()};
404
405 std::conditional_t<
406 std::is_void_v<result_type>,
407 variant2::monostate,
408 std::optional<result_type>> result;
409
410 std::exception_ptr error;
411
412#if !defined(BOOST_COBALT_NO_PMR)
413 pmr::monotonic_buffer_resource res;
414 pmr::polymorphic_allocator<void> alloc{&resource};
415
416 Range &aws;
417
418 struct dummy
419 {
420 template<typename ... Args>
421 dummy(Args && ...) {}
422 };
423
424 std::conditional_t<traits::interruptible,
425 pmr::vector<std::decay_t<typename traits::actual_awaitable>*>,
426 dummy> working{std::size(aws), alloc};
427
428 /* all below `reorder` is reordered
429 *
430 * cancel[idx] is for aws[reorder[idx]]
431 */
432 pmr::vector<std::size_t> reorder{std::size(aws), alloc};
433 pmr::vector<asio::cancellation_signal> cancel_{std::size(aws), alloc};
434 pmr::vector<asio::cancellation_signal*> cancel{std::size(aws), alloc};
435
436#else
437 Range &aws;
438
439 struct dummy
440 {
441 template<typename ... Args>
442 dummy(Args && ...) {}
443 };
444
445 std::conditional_t<traits::interruptible,
446 std::vector<std::decay_t<typename traits::actual_awaitable>*>,
447 dummy> working{std::size(aws), std::allocator<void>()};
448
449 /* all below `reorder` is reordered
450 *
451 * cancel[idx] is for aws[reorder[idx]]
452 */
453 std::vector<std::size_t> reorder{std::size(aws), std::allocator<void>()};
454 std::vector<asio::cancellation_signal> cancel_{std::size(aws), std::allocator<void>()};
455 std::vector<asio::cancellation_signal*> cancel{std::size(aws), std::allocator<void>()};
456
457#endif
458
459 bool has_result() const {return index != std::numeric_limits<std::size_t>::max(); }
460
461
462 awaitable(Range & aws, URBG & g)
463 : fork::shared_state((256 + sizeof(co_awaitable_type<type>) + sizeof(std::size_t)) * std::size(aws))
464 , aws(aws)
465 {
466 std::generate(reorder.begin(), reorder.end(), [i = std::size_t(0u)]() mutable {return i++;});
467 if constexpr (traits::interruptible)
468 std::fill(working.begin(), working.end(), nullptr);
469 if constexpr (!std::is_same_v<URBG, left_race_tag>)
470 std::shuffle(reorder.begin(), reorder.end(), g);
471 }
472
473 void cancel_all()
474 {
475 interrupt_await();
476 for (auto & r : cancel)
477 if (r)
478 std::exchange(obj&: r, new_val: nullptr)->emit(type: Ct);
479 }
480 void interrupt_await()
481 {
482 if constexpr (traits::interruptible)
483 for (auto aw : working)
484 if (aw)
485 traits::do_interrupt(*aw);
486 }
487
488
489 template<typename T, typename Error>
490 void assign_error(system::result<T, Error> & res)
491 try
492 {
493 std::move(res).value(loc);
494 }
495 catch(...)
496 {
497 error = std::current_exception();
498 }
499
500 template<typename T>
501 void assign_error(system::result<T, std::exception_ptr> & res)
502 {
503 error = std::move(res).error();
504 }
505
506 static detail::fork await_impl(awaitable & this_, std::size_t idx)
507 try
508 {
509 typename traits::actual_awaitable aw_{
510 get_awaitable_type(
511 static_cast<typename traits::awaitable>(*std::next(std::begin(this_.aws), idx))
512 )};
513
514 as_result_t aw{aw_};
515
516 if constexpr (traits::interruptible)
517 this_.working[idx] = &aw_;
518
519 auto transaction = [&this_, idx = idx] {
520 if (this_.has_result())
521 boost::throw_exception(e: std::runtime_error("Another transaction already started"));
522 this_.cancel[idx] = nullptr;
523 // reserve the index early bc
524 this_.index = idx;
525 this_.cancel_all();
526 };
527
528 co_await fork::set_transaction_function(transaction);
529 // check manually if we're ready
530 auto rd = aw.await_ready();
531 if (!rd)
532 {
533 this_.cancel[idx] = &this_.cancel_[idx];
534 co_await this_.cancel[idx]->slot();
535 // make sure the executor is set
536 co_await detail::fork::wired_up;
537
538 // do the await - this doesn't call await-ready again
539 if constexpr (std::is_void_v<result_type>)
540 {
541 auto res = co_await aw;
542 if (!this_.has_result())
543 {
544 if (res.has_error())
545 this_.assign_error(res);
546 this_.index = idx;
547 }
548 }
549 else
550 {
551 auto val = co_await aw;
552 if (!this_.has_result())
553 this_.index = idx;
554 if (this_.index == idx)
555 {
556 if (val.has_error())
557 this_.assign_error(val);
558 else
559 this_.result.emplace(*std::move(val));
560 }
561 }
562 this_.cancel[idx] = nullptr;
563 }
564 else
565 {
566
567 if (!this_.has_result())
568 this_.index = idx;
569 if constexpr (std::is_void_v<decltype(aw_.await_resume())>)
570 {
571 auto val = aw.await_resume();
572 if (val.has_error())
573 this_.assign_error(val);
574 }
575 else
576 {
577 if (this_.index == idx)
578 {
579 auto val = aw.await_resume();
580 if (val.has_error())
581 this_.assign_error(val);
582 else
583 this_.result.emplace(*std::move(val));
584 }
585 else
586 aw.await_resume();
587 }
588 this_.cancel[idx] = nullptr;
589 }
590 this_.cancel_all();
591 if constexpr (traits::interruptible)
592 this_.working[idx] = nullptr;
593 }
594 catch(...)
595 {
596 if (!this_.has_result())
597 this_.index = idx;
598 if (this_.index == idx)
599 this_.error = std::current_exception();
600 if constexpr (traits::interruptible)
601 this_.working[idx] = nullptr;
602 }
603
604 detail::fork last_forked;
605
606 bool await_ready()
607 {
608 last_forked = await_impl(this_&: *this, idx: reorder.front());
609 return last_forked.done();
610 }
611
612 template<typename H>
613 auto await_suspend(std::coroutine_handle<H> h,
614 const boost::source_location & loc = BOOST_CURRENT_LOCATION)
615 {
616 this->loc = loc;
617 this->exec = &detail::get_executor(h);
618 last_forked.release().resume();
619
620 if (!this->outstanding_work()) // already done, resume rightaway.
621 return false;
622
623 for (auto itr = std::next(x: reorder.begin());
624 itr < reorder.end(); std::advance(i&: itr, n: 1)) // we'
625 {
626 auto l = await_impl(this_&: *this, idx: *itr);
627 auto d = l.done();
628 l.release();
629 if (d)
630 break;
631 }
632
633 if (!this->outstanding_work()) // already done, resume rightaway.
634 return false;
635
636 // arm the cancel
637 assign_cancellation(
638 h,
639 [&](asio::cancellation_type ct)
640 {
641 for (auto & cs : cancel)
642 if (cs)
643 cs->emit(type: ct);
644 });
645
646 this->coro.reset(h.address());
647 return true;
648 }
649
650#if _MSC_VER
651 BOOST_NOINLINE
652#endif
653 auto await_resume()
654 {
655 if (error)
656 std::rethrow_exception(error);
657 if constexpr (std::is_void_v<result_type>)
658 return index;
659 else
660 return std::make_pair(index, *result);
661 }
662
663 auto await_resume(const as_tuple_tag &)
664 {
665 if constexpr (std::is_void_v<result_type>)
666 return std::make_tuple(args&: error, args&: index);
667 else
668 return std::make_tuple(error, std::make_pair(index, std::move(*result)));
669 }
670
671 auto await_resume(const as_result_tag & )
672 -> system::result<result_type, std::exception_ptr>
673 {
674 if (error)
675 return {system::in_place_error, error};
676 if constexpr (std::is_void_v<result_type>)
677 return {system::in_place_value, index};
678 else
679 return {system::in_place_value, std::make_pair(index, std::move(*result))};
680 }
681
682 };
683 awaitable operator co_await() &&
684 {
685 return awaitable{range, g};
686 }
687};
688
689}
690
691#endif //BOOST_COBALT_DETAIL_RACE_HPP
692

source code of boost/libs/cobalt/include/boost/cobalt/detail/race.hpp