1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9// UNSUPPORTED: c++03, c++11, c++14, c++17
10
11#include <cassert>
12#include <coroutine>
13#include <memory>
14
15#include "test_macros.h"
16
// Global flag initialised to false; it is not read or written anywhere else in this file.
bool cancel{false};
18
19struct goroutine
20{
21 static int const N = 10;
22 static int count;
23 static std::coroutine_handle<> stack[N];
24
25 static void schedule(std::coroutine_handle<>& rh)
26 {
27 assert(count < N);
28 stack[count++] = rh;
29 rh = nullptr;
30 }
31
32 goroutine() = default;
33 goroutine(const goroutine&) = default;
34 goroutine& operator=(const goroutine&) = default;
35 ~goroutine() {}
36
37 static void run_one()
38 {
39 assert(count > 0);
40 stack[--count]();
41 }
42
43 struct promise_type
44 {
45 std::suspend_never initial_suspend() {
46 return {};
47 }
48 std::suspend_never final_suspend() noexcept { return {}; }
49 void return_void() {}
50 goroutine get_return_object() {
51 return{};
52 }
53 void unhandled_exception() {}
54 };
55};
56int goroutine::count;
57std::coroutine_handle<> goroutine::stack[N];
58
59std::coroutine_handle<goroutine::promise_type> workaround;
60
61class channel;
62
63struct push_awaiter {
64 channel* ch;
65 bool await_ready() {return false; }
66 void await_suspend(std::coroutine_handle<> rh);
67 void await_resume() {}
68};
69
70struct pull_awaiter {
71 channel * ch;
72
73 bool await_ready();
74 void await_suspend(std::coroutine_handle<> rh);
75 int await_resume();
76};
77
78class channel
79{
80 using T = int;
81
82 friend struct push_awaiter;
83 friend struct pull_awaiter;
84
85 T const* pvalue = nullptr;
86 std::coroutine_handle<> reader = nullptr;
87 std::coroutine_handle<> writer = nullptr;
88public:
89 push_awaiter push(T const& value)
90 {
91 assert(pvalue == nullptr);
92 assert(!writer);
93 pvalue = &value;
94
95 return { .ch: this };
96 }
97
98 pull_awaiter pull()
99 {
100 assert(!reader);
101
102 return { .ch: this };
103 }
104
105 void sync_push(T const& value)
106 {
107 assert(!pvalue);
108 pvalue = &value;
109 assert(reader);
110 reader();
111 assert(!pvalue);
112 reader = nullptr;
113 }
114
115 auto sync_pull()
116 {
117 while (!pvalue) goroutine::run_one();
118 auto result = *pvalue;
119 pvalue = nullptr;
120 if (writer)
121 {
122 auto wr = writer;
123 writer = nullptr;
124 wr();
125 }
126 return result;
127 }
128};
129
130void push_awaiter::await_suspend(std::coroutine_handle<> rh)
131{
132 ch->writer = rh;
133 if (ch->reader) goroutine::schedule(ch->reader);
134}
135
136
137bool pull_awaiter::await_ready() {
138 return !!ch->writer;
139}
140void pull_awaiter::await_suspend(std::coroutine_handle<> rh) {
141 ch->reader = rh;
142}
143int pull_awaiter::await_resume() {
144 auto result = *ch->pvalue;
145 ch->pvalue = nullptr;
146 if (ch->writer) {
147 //goroutine::schedule(ch->writer);
148 auto wr = ch->writer;
149 ch->writer = nullptr;
150 wr();
151 }
152 return result;
153}
154
155goroutine pusher(channel& left, channel& right)
156{
157 for (;;) {
158 auto val = co_await left.pull();
159 co_await right.push(val + 1);
160 }
161}
162
163const int N = 100;
164channel c[N + 1];
165
166int main(int, char**) {
167 for (int i = 0; i < N; ++i)
168 pusher(left&: c[i], right&: c[i + 1]);
169
170 c[0].sync_push(value: 0);
171 int result = c[N].sync_pull();
172
173 assert(result == 100);
174
175 return 0;
176}
177

source code of libcxx/test/std/language.support/support.coroutines/end.to.end/go.pass.cpp