| 1 | // |
| 2 | // Copyright (c) 2022 Klemens Morgenstern (klemens.morgenstern@gmx.net) |
| 3 | // |
| 4 | // Distributed under the Boost Software License, Version 1.0. (See accompanying |
| 5 | // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) |
| 6 | // |
| 7 | |
| 8 | #ifndef BOOST_COBALT_DETAIL_THREAD_HPP |
| 9 | #define BOOST_COBALT_DETAIL_THREAD_HPP |
| 10 | |
| 11 | #include <boost/cobalt/config.hpp> |
| 12 | #include <boost/cobalt/detail/forward_cancellation.hpp> |
| 13 | #include <boost/cobalt/detail/handler.hpp> |
| 14 | #include <boost/cobalt/concepts.hpp> |
| 15 | #include <boost/cobalt/this_coro.hpp> |
| 16 | |
| 17 | #include <boost/asio/cancellation_signal.hpp> |
| 18 | |
| 19 | #include <thread> |
| 20 | |
| 21 | namespace boost::cobalt |
| 22 | { |
| 23 | |
// Forward declarations used by the awaitable / promise below.
struct thread;

// Tag types selecting the tuple- and result-returning await_resume overloads.
struct as_tuple_tag;
struct as_result_tag;

namespace detail { struct thread_promise; }
| 33 | |
| 34 | namespace detail |
| 35 | { |
| 36 | |
| 37 | |
| 38 | struct signal_helper_2 |
| 39 | { |
| 40 | asio::cancellation_signal signal; |
| 41 | }; |
| 42 | |
| 43 | |
| 44 | struct thread_state |
| 45 | { |
| 46 | asio::io_context ctx{1u}; |
| 47 | asio::cancellation_signal signal; |
| 48 | std::mutex mtx; |
| 49 | std::optional<completion_handler<std::exception_ptr>> waitor; |
| 50 | std::atomic<bool> done = false; |
| 51 | }; |
| 52 | |
| 53 | struct thread_promise : signal_helper_2, |
| 54 | promise_cancellation_base<asio::cancellation_slot, asio::enable_total_cancellation>, |
| 55 | promise_throw_if_cancelled_base, |
| 56 | enable_awaitables<thread_promise>, |
| 57 | enable_await_allocator<thread_promise>, |
| 58 | enable_await_executor<thread_promise> |
| 59 | { |
| 60 | BOOST_COBALT_DECL thread_promise(); |
| 61 | |
| 62 | struct initial_awaitable |
| 63 | { |
| 64 | bool await_ready() const {return false;} |
| 65 | void await_suspend(std::coroutine_handle<thread_promise> h) |
| 66 | { |
| 67 | h.promise().mtx.unlock(); |
| 68 | } |
| 69 | |
| 70 | void await_resume() {} |
| 71 | }; |
| 72 | |
| 73 | auto initial_suspend() noexcept |
| 74 | { |
| 75 | return initial_awaitable{}; |
| 76 | } |
| 77 | std::suspend_never final_suspend() noexcept |
| 78 | { |
| 79 | wexec_.reset(); |
| 80 | return {}; |
| 81 | } |
| 82 | |
| 83 | void unhandled_exception() { throw; } |
| 84 | void return_void() { } |
| 85 | |
| 86 | using executor_type = typename cobalt::executor; |
| 87 | const executor_type & get_executor() const {return *exec_;} |
| 88 | |
| 89 | #if !defined(BOOST_COBALT_NO_PMR) |
| 90 | using allocator_type = pmr::polymorphic_allocator<void>; |
| 91 | using resource_type = pmr::unsynchronized_pool_resource; |
| 92 | |
| 93 | resource_type * resource; |
| 94 | allocator_type get_allocator() const { return allocator_type(resource); } |
| 95 | #endif |
| 96 | |
| 97 | using promise_cancellation_base<asio::cancellation_slot, asio::enable_total_cancellation>::await_transform; |
| 98 | using promise_throw_if_cancelled_base::await_transform; |
| 99 | using enable_awaitables<thread_promise>::await_transform; |
| 100 | using enable_await_allocator<thread_promise>::await_transform; |
| 101 | using enable_await_executor<thread_promise>::await_transform; |
| 102 | |
| 103 | BOOST_COBALT_DECL |
| 104 | boost::cobalt::thread get_return_object(); |
| 105 | |
| 106 | void set_executor(asio::io_context::executor_type exec) |
| 107 | { |
| 108 | wexec_.emplace(args&: exec); |
| 109 | exec_.emplace(args&: exec); |
| 110 | } |
| 111 | |
| 112 | std::mutex mtx; |
| 113 | private: |
| 114 | |
| 115 | std::optional<asio::executor_work_guard<asio::io_context::executor_type>> wexec_; |
| 116 | std::optional<cobalt::executor> exec_; |
| 117 | }; |
| 118 | |
| 119 | struct thread_awaitable |
| 120 | { |
| 121 | asio::cancellation_slot cl; |
| 122 | std::optional<std::tuple<std::exception_ptr>> res; |
| 123 | bool await_ready(const boost::source_location & loc = BOOST_CURRENT_LOCATION) const |
| 124 | { |
| 125 | if (state_ == nullptr) |
| 126 | boost::throw_exception(e: std::invalid_argument("Thread expired" ), loc); |
| 127 | std::lock_guard<std::mutex> lock{state_->mtx}; |
| 128 | return state_->done; |
| 129 | } |
| 130 | |
| 131 | template<typename Promise> |
| 132 | bool await_suspend(std::coroutine_handle<Promise> h) |
| 133 | { |
| 134 | BOOST_ASSERT(state_); |
| 135 | |
| 136 | std::lock_guard<std::mutex> lock{state_->mtx}; |
| 137 | if (state_->done) |
| 138 | return false; |
| 139 | |
| 140 | if constexpr (requires (Promise p) {p.get_cancellation_slot();}) |
| 141 | if ((cl = h.promise().get_cancellation_slot()).is_connected()) |
| 142 | { |
| 143 | cl.assign( |
| 144 | [st = state_](asio::cancellation_type type) |
| 145 | { |
| 146 | std::lock_guard<std::mutex> lock{st->mtx}; |
| 147 | asio::post(st->ctx, |
| 148 | [st, type] |
| 149 | { |
| 150 | BOOST_ASIO_HANDLER_LOCATION((__FILE__, __LINE__, __func__)); |
| 151 | st->signal.emit(type); |
| 152 | }); |
| 153 | }); |
| 154 | |
| 155 | } |
| 156 | |
| 157 | state_->waitor.emplace(h, res); |
| 158 | return true; |
| 159 | } |
| 160 | |
| 161 | void await_resume() |
| 162 | { |
| 163 | if (cl.is_connected()) |
| 164 | cl.clear(); |
| 165 | if (thread_) |
| 166 | thread_->join(); |
| 167 | if (!res) // await_ready |
| 168 | return; |
| 169 | if (auto ee = std::get<0>(t&: *res)) |
| 170 | std::rethrow_exception(ee); |
| 171 | } |
| 172 | |
| 173 | system::result<void, std::exception_ptr> await_resume(const as_result_tag &) |
| 174 | { |
| 175 | if (cl.is_connected()) |
| 176 | cl.clear(); |
| 177 | if (thread_) |
| 178 | thread_->join(); |
| 179 | if (!res) // await_ready |
| 180 | return {system::in_place_value}; |
| 181 | if (auto ee = std::get<0>(t&: *res)) |
| 182 | return {system::in_place_error, std::move(ee)}; |
| 183 | |
| 184 | return {system::in_place_value}; |
| 185 | } |
| 186 | |
| 187 | std::tuple<std::exception_ptr> await_resume(const as_tuple_tag &) |
| 188 | { |
| 189 | if (cl.is_connected()) |
| 190 | cl.clear(); |
| 191 | if (thread_) |
| 192 | thread_->join(); |
| 193 | |
| 194 | return std::get<0>(t&: *res); |
| 195 | } |
| 196 | |
| 197 | explicit thread_awaitable(std::shared_ptr<detail::thread_state> state) |
| 198 | : state_(std::move(state)) {} |
| 199 | |
| 200 | explicit thread_awaitable(std::thread thread, |
| 201 | std::shared_ptr<detail::thread_state> state) |
| 202 | : thread_(std::move(thread)), state_(std::move(state)) {} |
| 203 | private: |
| 204 | std::optional<std::thread> thread_; |
| 205 | std::shared_ptr<detail::thread_state> state_; |
| 206 | }; |
| 207 | } |
| 208 | |
| 209 | } |
| 210 | |
| 211 | #endif //BOOST_COBALT_DETAIL_THREAD_HPP |
| 212 | |