1 | //===----------------------------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | // UNSUPPORTED: c++03, c++11, c++14, c++17 |
10 | |
11 | #include <cassert> |
12 | #include <coroutine> |
13 | #include <memory> |
14 | |
15 | #include "test_macros.h" |
16 | |
// NOTE(review): this flag is not read or written anywhere in the visible part
// of this file -- confirm whether anything still depends on it.
bool cancel = false;
18 | |
19 | struct goroutine |
20 | { |
21 | static int const N = 10; |
22 | static int count; |
23 | static std::coroutine_handle<> stack[N]; |
24 | |
25 | static void schedule(std::coroutine_handle<>& rh) |
26 | { |
27 | assert(count < N); |
28 | stack[count++] = rh; |
29 | rh = nullptr; |
30 | } |
31 | |
32 | goroutine() = default; |
33 | goroutine(const goroutine&) = default; |
34 | goroutine& operator=(const goroutine&) = default; |
35 | ~goroutine() {} |
36 | |
37 | static void run_one() |
38 | { |
39 | assert(count > 0); |
40 | stack[--count](); |
41 | } |
42 | |
43 | struct promise_type |
44 | { |
45 | std::suspend_never initial_suspend() { |
46 | return {}; |
47 | } |
48 | std::suspend_never final_suspend() noexcept { return {}; } |
49 | void return_void() {} |
50 | goroutine get_return_object() { |
51 | return{}; |
52 | } |
53 | void unhandled_exception() {} |
54 | }; |
55 | }; |
56 | int goroutine::count; |
57 | std::coroutine_handle<> goroutine::stack[N]; |
58 | |
// NOTE(review): never read or written in the visible part of this file;
// presumably kept as a compiler workaround, as the name suggests -- confirm
// before removing.
std::coroutine_handle<goroutine::promise_type> workaround;

// Forward declaration: the awaiter types declared next hold a `channel*`
// before the class itself is defined.
class channel;
62 | |
63 | struct push_awaiter { |
64 | channel* ch; |
65 | bool await_ready() {return false; } |
66 | void await_suspend(std::coroutine_handle<> rh); |
67 | void await_resume() {} |
68 | }; |
69 | |
70 | struct pull_awaiter { |
71 | channel * ch; |
72 | |
73 | bool await_ready(); |
74 | void await_suspend(std::coroutine_handle<> rh); |
75 | int await_resume(); |
76 | }; |
77 | |
78 | class channel |
79 | { |
80 | using T = int; |
81 | |
82 | friend struct push_awaiter; |
83 | friend struct pull_awaiter; |
84 | |
85 | T const* pvalue = nullptr; |
86 | std::coroutine_handle<> reader = nullptr; |
87 | std::coroutine_handle<> writer = nullptr; |
88 | public: |
89 | push_awaiter push(T const& value) |
90 | { |
91 | assert(pvalue == nullptr); |
92 | assert(!writer); |
93 | pvalue = &value; |
94 | |
95 | return { .ch: this }; |
96 | } |
97 | |
98 | pull_awaiter pull() |
99 | { |
100 | assert(!reader); |
101 | |
102 | return { .ch: this }; |
103 | } |
104 | |
105 | void sync_push(T const& value) |
106 | { |
107 | assert(!pvalue); |
108 | pvalue = &value; |
109 | assert(reader); |
110 | reader(); |
111 | assert(!pvalue); |
112 | reader = nullptr; |
113 | } |
114 | |
115 | auto sync_pull() |
116 | { |
117 | while (!pvalue) goroutine::run_one(); |
118 | auto result = *pvalue; |
119 | pvalue = nullptr; |
120 | if (writer) |
121 | { |
122 | auto wr = writer; |
123 | writer = nullptr; |
124 | wr(); |
125 | } |
126 | return result; |
127 | } |
128 | }; |
129 | |
130 | void push_awaiter::await_suspend(std::coroutine_handle<> rh) |
131 | { |
132 | ch->writer = rh; |
133 | if (ch->reader) goroutine::schedule(ch->reader); |
134 | } |
135 | |
136 | |
137 | bool pull_awaiter::await_ready() { |
138 | return !!ch->writer; |
139 | } |
140 | void pull_awaiter::await_suspend(std::coroutine_handle<> rh) { |
141 | ch->reader = rh; |
142 | } |
143 | int pull_awaiter::await_resume() { |
144 | auto result = *ch->pvalue; |
145 | ch->pvalue = nullptr; |
146 | if (ch->writer) { |
147 | //goroutine::schedule(ch->writer); |
148 | auto wr = ch->writer; |
149 | ch->writer = nullptr; |
150 | wr(); |
151 | } |
152 | return result; |
153 | } |
154 | |
155 | goroutine pusher(channel& left, channel& right) |
156 | { |
157 | for (;;) { |
158 | auto val = co_await left.pull(); |
159 | co_await right.push(val + 1); |
160 | } |
161 | } |
162 | |
163 | const int N = 100; |
164 | channel c[N + 1]; |
165 | |
166 | int main(int, char**) { |
167 | for (int i = 0; i < N; ++i) |
168 | pusher(left&: c[i], right&: c[i + 1]); |
169 | |
170 | c[0].sync_push(value: 0); |
171 | int result = c[N].sync_pull(); |
172 | |
173 | assert(result == 100); |
174 | |
175 | return 0; |
176 | } |
177 | |