1use super::assert_stream;
2use core::{fmt, pin::Pin};
3use futures_core::stream::{FusedStream, Stream};
4use futures_core::task::{Context, Poll};
5use pin_project_lite::pin_project;
6
/// Type to tell [`SelectWithStrategy`] which stream to poll next.
///
/// Returned by the strategy closure passed to [`select_with_strategy()`].
/// `Left` refers to the first stream argument and `Right` to the second.
#[derive(Debug, PartialEq, Eq, Copy, Clone, Hash)]
pub enum PollNext {
    /// Poll the first stream.
    Left,
    /// Poll the second stream.
    Right,
}
15
16impl PollNext {
17 /// Toggle the value and return the old one.
18 pub fn toggle(&mut self) -> Self {
19 let old = *self;
20 *self = self.other();
21 old
22 }
23
24 fn other(&self) -> PollNext {
25 match self {
26 PollNext::Left => PollNext::Right,
27 PollNext::Right => PollNext::Left,
28 }
29 }
30}
31
32impl Default for PollNext {
33 fn default() -> Self {
34 PollNext::Left
35 }
36}
37
/// Tracks which of the two inner streams have already returned
/// `Poll::Ready(None)`.
///
/// The strategy closure is consulted only in `Start`. In the `*Finished`
/// states the remaining stream is polled directly. In `BothFinished` the
/// combinator returns `Poll::Ready(None)` without polling either stream.
enum InternalState {
    /// Neither stream has finished yet.
    Start,
    /// The first stream has finished; only the second stream is polled.
    LeftFinished,
    /// The second stream has finished; only the first stream is polled.
    RightFinished,
    /// Both streams have finished; the combinator is terminated.
    BothFinished,
}
44
45impl InternalState {
46 fn finish(&mut self, ps: PollNext) {
47 match (&self, ps) {
48 (InternalState::Start, PollNext::Left) => {
49 *self = InternalState::LeftFinished;
50 }
51 (InternalState::Start, PollNext::Right) => {
52 *self = InternalState::RightFinished;
53 }
54 (InternalState::LeftFinished, PollNext::Right)
55 | (InternalState::RightFinished, PollNext::Left) => {
56 *self = InternalState::BothFinished;
57 }
58 _ => {}
59 }
60 }
61}
62
pin_project! {
    /// Stream for the [`select_with_strategy()`] function. See function docs for details.
    #[must_use = "streams do nothing unless polled"]
    #[project = SelectWithStrategyProj]
    pub struct SelectWithStrategy<St1, St2, Clos, State> {
        // First input stream; polled when the strategy returns `PollNext::Left`.
        #[pin]
        stream1: St1,
        // Second input stream; polled when the strategy returns `PollNext::Right`.
        #[pin]
        stream2: St2,
        // Which inputs have already returned `Poll::Ready(None)`.
        internal_state: InternalState,
        // Strategy state, passed to `clos` as `&mut State` on every call.
        state: State,
        // User-supplied strategy closure choosing which side to poll first.
        clos: Clos,
    }
}
77
78/// This function will attempt to pull items from both streams. You provide a
79/// closure to tell [`SelectWithStrategy`] which stream to poll. The closure can
80/// store state on `SelectWithStrategy` to which it will receive a `&mut` on every
81/// invocation. This allows basing the strategy on prior choices.
82///
83/// After one of the two input streams completes, the remaining one will be
84/// polled exclusively. The returned stream completes when both input
85/// streams have completed.
86///
87/// Note that this function consumes both streams and returns a wrapped
88/// version of them.
89///
90/// ## Examples
91///
92/// ### Priority
93/// This example shows how to always prioritize the left stream.
94///
95/// ```rust
96/// # futures::executor::block_on(async {
97/// use futures::stream::{ repeat, select_with_strategy, PollNext, StreamExt };
98///
99/// let left = repeat(1);
100/// let right = repeat(2);
101///
102/// // We don't need any state, so let's make it an empty tuple.
103/// // We must provide some type here, as there is no way for the compiler
104/// // to infer it. As we don't need to capture variables, we can just
105/// // use a function pointer instead of a closure.
106/// fn prio_left(_: &mut ()) -> PollNext { PollNext::Left }
107///
108/// let mut out = select_with_strategy(left, right, prio_left);
109///
110/// for _ in 0..100 {
111/// // Whenever we poll out, we will alwas get `1`.
112/// assert_eq!(1, out.select_next_some().await);
113/// }
114/// # });
115/// ```
116///
117/// ### Round Robin
118/// This example shows how to select from both streams round robin.
119/// Note: this special case is provided by [`futures-util::stream::select`].
120///
121/// ```rust
122/// # futures::executor::block_on(async {
123/// use futures::stream::{ repeat, select_with_strategy, PollNext, StreamExt };
124///
125/// let left = repeat(1);
126/// let right = repeat(2);
127///
128/// let rrobin = |last: &mut PollNext| last.toggle();
129///
130/// let mut out = select_with_strategy(left, right, rrobin);
131///
132/// for _ in 0..100 {
133/// // We should be alternating now.
134/// assert_eq!(1, out.select_next_some().await);
135/// assert_eq!(2, out.select_next_some().await);
136/// }
137/// # });
138/// ```
139pub fn select_with_strategy<St1, St2, Clos, State>(
140 stream1: St1,
141 stream2: St2,
142 which: Clos,
143) -> SelectWithStrategy<St1, St2, Clos, State>
144where
145 St1: Stream,
146 St2: Stream<Item = St1::Item>,
147 Clos: FnMut(&mut State) -> PollNext,
148 State: Default,
149{
150 assert_stream::<St1::Item, _>(SelectWithStrategy {
151 stream1,
152 stream2,
153 state: Default::default(),
154 internal_state: InternalState::Start,
155 clos: which,
156 })
157}
158
159impl<St1, St2, Clos, State> SelectWithStrategy<St1, St2, Clos, State> {
160 /// Acquires a reference to the underlying streams that this combinator is
161 /// pulling from.
162 pub fn get_ref(&self) -> (&St1, &St2) {
163 (&self.stream1, &self.stream2)
164 }
165
166 /// Acquires a mutable reference to the underlying streams that this
167 /// combinator is pulling from.
168 ///
169 /// Note that care must be taken to avoid tampering with the state of the
170 /// stream which may otherwise confuse this combinator.
171 pub fn get_mut(&mut self) -> (&mut St1, &mut St2) {
172 (&mut self.stream1, &mut self.stream2)
173 }
174
175 /// Acquires a pinned mutable reference to the underlying streams that this
176 /// combinator is pulling from.
177 ///
178 /// Note that care must be taken to avoid tampering with the state of the
179 /// stream which may otherwise confuse this combinator.
180 pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut St1>, Pin<&mut St2>) {
181 let this = self.project();
182 (this.stream1, this.stream2)
183 }
184
185 /// Consumes this combinator, returning the underlying streams.
186 ///
187 /// Note that this may discard intermediate state of this combinator, so
188 /// care should be taken to avoid losing resources when this is called.
189 pub fn into_inner(self) -> (St1, St2) {
190 (self.stream1, self.stream2)
191 }
192}
193
194impl<St1, St2, Clos, State> FusedStream for SelectWithStrategy<St1, St2, Clos, State>
195where
196 St1: Stream,
197 St2: Stream<Item = St1::Item>,
198 Clos: FnMut(&mut State) -> PollNext,
199{
200 fn is_terminated(&self) -> bool {
201 match self.internal_state {
202 InternalState::BothFinished => true,
203 _ => false,
204 }
205 }
206}
207
208#[inline]
209fn poll_side<St1, St2, Clos, State>(
210 select: &mut SelectWithStrategyProj<'_, St1, St2, Clos, State>,
211 side: PollNext,
212 cx: &mut Context<'_>,
213) -> Poll<Option<St1::Item>>
214where
215 St1: Stream,
216 St2: Stream<Item = St1::Item>,
217{
218 match side {
219 PollNext::Left => select.stream1.as_mut().poll_next(cx),
220 PollNext::Right => select.stream2.as_mut().poll_next(cx),
221 }
222}
223
224#[inline]
225fn poll_inner<St1, St2, Clos, State>(
226 select: &mut SelectWithStrategyProj<'_, St1, St2, Clos, State>,
227 side: PollNext,
228 cx: &mut Context<'_>,
229) -> Poll<Option<St1::Item>>
230where
231 St1: Stream,
232 St2: Stream<Item = St1::Item>,
233{
234 let first_done = match poll_side(select, side, cx) {
235 Poll::Ready(Some(item)) => return Poll::Ready(Some(item)),
236 Poll::Ready(None) => {
237 select.internal_state.finish(side);
238 true
239 }
240 Poll::Pending => false,
241 };
242 let other = side.other();
243 match poll_side(select, other, cx) {
244 Poll::Ready(None) => {
245 select.internal_state.finish(other);
246 if first_done {
247 Poll::Ready(None)
248 } else {
249 Poll::Pending
250 }
251 }
252 a => a,
253 }
254}
255
256impl<St1, St2, Clos, State> Stream for SelectWithStrategy<St1, St2, Clos, State>
257where
258 St1: Stream,
259 St2: Stream<Item = St1::Item>,
260 Clos: FnMut(&mut State) -> PollNext,
261{
262 type Item = St1::Item;
263
264 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<St1::Item>> {
265 let mut this = self.project();
266
267 match this.internal_state {
268 InternalState::Start => {
269 let next_side = (this.clos)(this.state);
270 poll_inner(&mut this, next_side, cx)
271 }
272 InternalState::LeftFinished => match this.stream2.poll_next(cx) {
273 Poll::Ready(None) => {
274 *this.internal_state = InternalState::BothFinished;
275 Poll::Ready(None)
276 }
277 a => a,
278 },
279 InternalState::RightFinished => match this.stream1.poll_next(cx) {
280 Poll::Ready(None) => {
281 *this.internal_state = InternalState::BothFinished;
282 Poll::Ready(None)
283 }
284 a => a,
285 },
286 InternalState::BothFinished => Poll::Ready(None),
287 }
288 }
289}
290
291impl<St1, St2, Clos, State> fmt::Debug for SelectWithStrategy<St1, St2, Clos, State>
292where
293 St1: fmt::Debug,
294 St2: fmt::Debug,
295 State: fmt::Debug,
296{
297 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
298 f.debug_struct("SelectWithStrategy")
299 .field("stream1", &self.stream1)
300 .field("stream2", &self.stream2)
301 .field("state", &self.state)
302 .finish()
303 }
304}
305