1 | use crate::*; |
2 | use core::mem; |
3 | use core::pin::Pin; |
4 | use core::task::{Context, Poll}; |
5 | |
6 | pin_project_lite::pin_project! { |
7 | /// A stream for the [`join`](fn.join.html) function. |
8 | #[derive(Debug)] |
9 | pub struct Join<A, B> |
10 | where |
11 | A: OrderedStream, |
12 | B: OrderedStream<Data = A::Data, Ordering=A::Ordering>, |
13 | { |
14 | #[pin] |
15 | stream_a: A, |
16 | #[pin] |
17 | stream_b: B, |
18 | state: JoinState<A::Data, B::Data, A::Ordering>, |
19 | } |
20 | } |
21 | |
22 | /// Join two streams while preserving the overall ordering of elements. |
23 | /// |
24 | /// You can think of this as implementing the "merge" step of a merge sort on the two streams, |
25 | /// producing a single stream that is sorted given two sorted streams. If the streams return |
26 | /// [`PollResult::NoneBefore`] as intended, then the joined stream will be able to produce items |
27 | /// when only one of the sources has unblocked. |
28 | pub fn join<A, B>(stream_a: A, stream_b: B) -> Join<A, B> |
29 | where |
30 | A: OrderedStream, |
31 | B: OrderedStream<Data = A::Data, Ordering = A::Ordering>, |
32 | { |
33 | Join { |
34 | stream_a, |
35 | stream_b, |
36 | state: JoinState::None, |
37 | } |
38 | } |
39 | |
/// Bookkeeping that [`Join`] carries from one poll to the next.
#[derive(Debug)]
enum JoinState<A, B, T> {
    /// Nothing is buffered; both sources will be polled.
    None,
    /// An item taken from stream A, together with its ordering, is held back.
    A(A, T),
    /// An item taken from stream B, together with its ordering, is held back.
    B(B, T),
    /// Stream B has terminated, so only stream A is still polled.
    OnlyPollA,
    /// Stream A has terminated, so only stream B is still polled.
    OnlyPollB,
    /// Both sources have terminated.
    Terminated,
}
49 | |
50 | impl<A, B, T> JoinState<A, B, T> { |
51 | fn take_split(&mut self) -> (PollState<A, T>, PollState<B, T>) { |
52 | match mem::replace(self, src:JoinState::None) { |
53 | JoinState::None => (PollState::Pending, PollState::Pending), |
54 | JoinState::A(a: A, t: T) => (PollState::Item(a, t), PollState::Pending), |
55 | JoinState::B(b: B, t: T) => (PollState::Pending, PollState::Item(b, t)), |
56 | JoinState::OnlyPollA => (PollState::Pending, PollState::Terminated), |
57 | JoinState::OnlyPollB => (PollState::Terminated, PollState::Pending), |
58 | JoinState::Terminated => (PollState::Terminated, PollState::Terminated), |
59 | } |
60 | } |
61 | } |
62 | |
/// Outcome of polling one source stream, flattened for easy matching.
///
/// Carries the same information as `Poll<PollResult<T, I>>`, but as a single enum so that the
/// outcomes of two sources can be matched together as a pair.
pub(crate) enum PollState<I, T> {
    /// An item (`I`) together with its ordering (`T`).
    Item(I, T),
    /// The source returned `Poll::Pending`, or has not been polled yet.
    Pending,
    /// The source reported `PollResult::NoneBefore` for the bound it was polled with.
    NoneBefore,
    /// The source has terminated.
    Terminated,
}
70 | |
71 | impl<I, T: Ord> PollState<I, T> { |
72 | fn ordering(&self) -> Option<&T> { |
73 | match self { |
74 | Self::Item(_, t) => Some(t), |
75 | _ => None, |
76 | } |
77 | } |
78 | |
79 | fn update( |
80 | &mut self, |
81 | before: Option<&T>, |
82 | other_token: Option<&T>, |
83 | retry: bool, |
84 | run: impl FnOnce(Option<&T>) -> Poll<PollResult<T, I>>, |
85 | ) -> bool { |
86 | match self { |
87 | // Do not re-poll if we have an item already or if we are terminated |
88 | Self::Item { .. } | Self::Terminated => return false, |
89 | |
90 | // No need to re-poll if we already declared no items <= before |
91 | Self::NoneBefore if retry => return false, |
92 | |
93 | _ => {} |
94 | } |
95 | |
96 | // Run the poll with the earlier of the two tokens to avoid transitioning to Pending (which |
97 | // will stall the Join) when we could have transitioned to NoneBefore. |
98 | let ordering = match (before, other_token) { |
99 | (Some(u), Some(o)) => { |
100 | if *u > *o { |
101 | // The other ordering is earlier - so a retry might let us upgrade a Pending to a |
102 | // NoneBefore |
103 | Some(o) |
104 | } else if retry { |
105 | // A retry will not improve matters, so don't bother |
106 | return false; |
107 | } else { |
108 | Some(u) |
109 | } |
110 | } |
111 | (Some(t), None) | (None, Some(t)) => Some(t), |
112 | (None, None) => None, |
113 | }; |
114 | |
115 | *self = run(ordering).into(); |
116 | matches!(self, Self::Item { .. }) |
117 | } |
118 | } |
119 | |
120 | impl<I, T> From<PollState<I, T>> for Poll<PollResult<T, I>> { |
121 | fn from(poll: PollState<I, T>) -> Self { |
122 | match poll { |
123 | PollState::Item(data: I, ordering: T) => Poll::Ready(PollResult::Item { data, ordering }), |
124 | PollState::Pending => Poll::Pending, |
125 | PollState::NoneBefore => Poll::Ready(PollResult::NoneBefore), |
126 | PollState::Terminated => Poll::Ready(PollResult::Terminated), |
127 | } |
128 | } |
129 | } |
130 | |
131 | impl<I, T> From<Poll<PollResult<T, I>>> for PollState<I, T> { |
132 | fn from(poll: Poll<PollResult<T, I>>) -> Self { |
133 | match poll { |
134 | Poll::Ready(PollResult::Item { data: I, ordering: T }) => Self::Item(data, ordering), |
135 | Poll::Ready(PollResult::NoneBefore) => Self::NoneBefore, |
136 | Poll::Ready(PollResult::Terminated) => Self::Terminated, |
137 | Poll::Pending => Self::Pending, |
138 | } |
139 | } |
140 | } |
141 | |
142 | impl<A, B> Join<A, B> |
143 | where |
144 | A: OrderedStream, |
145 | B: OrderedStream<Data = A::Data, Ordering = A::Ordering>, |
146 | { |
147 | /// Split into the source streams. |
148 | /// |
149 | /// This method returns the source streams along with any buffered item and its |
150 | /// ordering. |
151 | pub fn into_inner(self) -> (A, B, Option<(A::Data, A::Ordering)>) { |
152 | let item = match self.state { |
153 | JoinState::A(a, o) => Some((a, o)), |
154 | JoinState::B(b, o) => Some((b, o)), |
155 | _ => None, |
156 | }; |
157 | |
158 | (self.stream_a, self.stream_b, item) |
159 | } |
160 | |
161 | /// Provide direct access to the underlying stream. |
162 | /// |
163 | /// This may be useful if the stream provides APIs beyond [OrderedStream]. Note that the join |
164 | /// itself may be buffering an item from this stream, so you should consult |
165 | /// [Self::peek_buffered] and, if needed, [Self::take_buffered] before polling it directly. |
166 | pub fn stream_a(self: Pin<&mut Self>) -> Pin<&mut A> { |
167 | self.project().stream_a |
168 | } |
169 | |
170 | /// Provide direct access to the underlying stream. |
171 | /// |
172 | /// This may be useful if the stream provides APIs beyond [OrderedStream]. Note that the join |
173 | /// itself may be buffering an item from this stream, so you should consult |
174 | /// [Self::peek_buffered] and, if needed, [Self::take_buffered] before polling it directly. |
175 | pub fn stream_b(self: Pin<&mut Self>) -> Pin<&mut B> { |
176 | self.project().stream_b |
177 | } |
178 | |
179 | /// Allow access to the buffered item, if any. |
180 | /// |
181 | /// At most one of the two sides will be `Some`. The returned item is a candidate for being |
182 | /// the next item returned by the joined stream, but it could not be returned by the most |
183 | /// recent [`OrderedStream::poll_next_before`] call. |
184 | pub fn peek_buffered( |
185 | self: Pin<&mut Self>, |
186 | ) -> ( |
187 | Option<(&mut A::Data, &A::Ordering)>, |
188 | Option<(&mut B::Data, &B::Ordering)>, |
189 | ) { |
190 | match self.project().state { |
191 | JoinState::A(a, o) => (Some((a, o)), None), |
192 | JoinState::B(b, o) => (None, Some((b, o))), |
193 | _ => (None, None), |
194 | } |
195 | } |
196 | |
197 | /// Remove the buffered item, if one is present. |
198 | /// |
199 | /// This does not poll either underlying stream. See [Self::peek_buffered] for details on why |
200 | /// buffering exists. |
201 | pub fn take_buffered(self: Pin<&mut Self>) -> Option<(A::Data, A::Ordering)> { |
202 | let state = self.project().state; |
203 | match mem::replace(state, JoinState::None) { |
204 | JoinState::A(a, o) => Some((a, o)), |
205 | JoinState::B(b, o) => Some((b, o)), |
206 | other => { |
207 | *state = other; |
208 | None |
209 | } |
210 | } |
211 | } |
212 | } |
213 | |
214 | impl<A, B> OrderedStream for Join<A, B> |
215 | where |
216 | A: OrderedStream, |
217 | B: OrderedStream<Data = A::Data, Ordering = A::Ordering>, |
218 | { |
219 | type Data = A::Data; |
220 | type Ordering = A::Ordering; |
221 | |
222 | fn poll_next_before( |
223 | self: Pin<&mut Self>, |
224 | cx: &mut Context<'_>, |
225 | before: Option<&Self::Ordering>, |
226 | ) -> Poll<PollResult<Self::Ordering, Self::Data>> { |
227 | let mut this = self.project(); |
228 | let (mut poll_a, mut poll_b) = this.state.take_split(); |
229 | |
230 | poll_a.update(before, poll_b.ordering(), false, |ordering| { |
231 | this.stream_a.as_mut().poll_next_before(cx, ordering) |
232 | }); |
233 | if poll_b.update(before, poll_a.ordering(), false, |ordering| { |
234 | this.stream_b.as_mut().poll_next_before(cx, ordering) |
235 | }) { |
236 | // If B just got an item, it's possible that A already knows that it won't have any |
237 | // items before that item; we couldn't ask that question before. Ask it now. |
238 | poll_a.update(before, poll_b.ordering(), true, |ordering| { |
239 | this.stream_a.as_mut().poll_next_before(cx, ordering) |
240 | }); |
241 | } |
242 | |
243 | match (poll_a, poll_b) { |
244 | // Both are ready - we can judge ordering directly (simplest case). The first one is |
245 | // returned while the other one is buffered for the next poll. |
246 | (PollState::Item(a, ta), PollState::Item(b, tb)) => { |
247 | if ta <= tb { |
248 | *this.state = JoinState::B(b, tb); |
249 | Poll::Ready(PollResult::Item { |
250 | data: a, |
251 | ordering: ta, |
252 | }) |
253 | } else { |
254 | *this.state = JoinState::A(a, ta); |
255 | Poll::Ready(PollResult::Item { |
256 | data: b, |
257 | ordering: tb, |
258 | }) |
259 | } |
260 | } |
261 | |
262 | // If both sides are terminated, so are we. |
263 | (PollState::Terminated, PollState::Terminated) => { |
264 | *this.state = JoinState::Terminated; |
265 | Poll::Ready(PollResult::Terminated) |
266 | } |
267 | |
268 | // If one side is terminated, we can produce items directly from the other side. |
269 | (a, PollState::Terminated) => { |
270 | *this.state = JoinState::OnlyPollA; |
271 | a.into() |
272 | } |
273 | (PollState::Terminated, b) => { |
274 | *this.state = JoinState::OnlyPollB; |
275 | b.into() |
276 | } |
277 | |
278 | // If one side is pending, we can't return Ready until that gets resolved. Because we |
279 | // have already requested that our child streams wake us when it is possible to make |
280 | // any kind of progress, we meet the requirements to return Poll::Pending. |
281 | (PollState::Item(a, t), PollState::Pending) => { |
282 | *this.state = JoinState::A(a, t); |
283 | Poll::Pending |
284 | } |
285 | (PollState::Pending, PollState::Item(b, t)) => { |
286 | *this.state = JoinState::B(b, t); |
287 | Poll::Pending |
288 | } |
289 | (PollState::Pending, PollState::Pending) => Poll::Pending, |
290 | (PollState::Pending, PollState::NoneBefore) => Poll::Pending, |
291 | (PollState::NoneBefore, PollState::Pending) => Poll::Pending, |
292 | |
293 | // If both sides report NoneBefore, so can we. |
294 | (PollState::NoneBefore, PollState::NoneBefore) => Poll::Ready(PollResult::NoneBefore), |
295 | |
296 | (PollState::Item(data, ordering), PollState::NoneBefore) => { |
297 | // B was polled using either the Some value of (before) or using A's ordering. |
298 | // |
299 | // If before is set and is earlier than A's ordering, then B might later produce a |
300 | // value with (bt >= before && bt < at), so we can't return A's item yet and must |
301 | // buffer it. However, we can return None because neither stream will produce |
302 | // items before the ordering passed in before. |
303 | // |
304 | // If before is either None or after A's ordering, B's NoneBefore return represents a |
305 | // promise to not produce an item before A's, so we can return A's item now. |
306 | match before { |
307 | Some(before) if ordering > *before => { |
308 | *this.state = JoinState::A(data, ordering); |
309 | Poll::Ready(PollResult::NoneBefore) |
310 | } |
311 | _ => Poll::Ready(PollResult::Item { data, ordering }), |
312 | } |
313 | } |
314 | |
315 | (PollState::NoneBefore, PollState::Item(data, ordering)) => { |
316 | // A was polled using either the Some value of (before) or using B's ordering. |
317 | // |
318 | // By a mirror of the above argument, this NoneBefore result gives us permission to |
319 | // produce either B's item or NoneBefore. |
320 | match before { |
321 | Some(before) if ordering > *before => { |
322 | *this.state = JoinState::B(data, ordering); |
323 | Poll::Ready(PollResult::NoneBefore) |
324 | } |
325 | _ => Poll::Ready(PollResult::Item { data, ordering }), |
326 | } |
327 | } |
328 | } |
329 | } |
330 | |
331 | fn position_hint(&self) -> Option<MaybeBorrowed<'_, Self::Ordering>> { |
332 | let (a, b) = match &self.state { |
333 | JoinState::None => (self.stream_a.position_hint(), self.stream_b.position_hint()), |
334 | JoinState::A(_, t) => ( |
335 | Some(MaybeBorrowed::Borrowed(t)), |
336 | self.stream_b.position_hint(), |
337 | ), |
338 | JoinState::B(_, t) => ( |
339 | self.stream_b.position_hint(), |
340 | Some(MaybeBorrowed::Borrowed(t)), |
341 | ), |
342 | JoinState::OnlyPollA => return self.stream_a.position_hint(), |
343 | JoinState::OnlyPollB => return self.stream_b.position_hint(), |
344 | JoinState::Terminated => return None, |
345 | }; |
346 | // We can only provide a hint if we have a valid hint for both sides |
347 | match (a, b) { |
348 | (Some(a), Some(b)) if *a <= *b => Some(a), |
349 | (Some(_), Some(b)) => Some(b), |
350 | _ => None, |
351 | } |
352 | } |
353 | |
354 | fn size_hint(&self) -> (usize, Option<usize>) { |
355 | let extra = match &self.state { |
356 | JoinState::None => 0, |
357 | JoinState::A { .. } => 1, |
358 | JoinState::B { .. } => 1, |
359 | JoinState::OnlyPollA => return self.stream_a.size_hint(), |
360 | JoinState::OnlyPollB => return self.stream_b.size_hint(), |
361 | JoinState::Terminated => return (0, Some(0)), |
362 | }; |
363 | let (al, ah) = self.stream_a.size_hint(); |
364 | let (bl, bh) = self.stream_b.size_hint(); |
365 | let min = al.saturating_add(bl).saturating_add(extra); |
366 | let max = ah |
367 | .and_then(|a| bh.and_then(|b| a.checked_add(b))) |
368 | .and_then(|h| h.checked_add(extra)); |
369 | (min, max) |
370 | } |
371 | } |
372 | |
373 | impl<A, B> FusedOrderedStream for Join<A, B> |
374 | where |
375 | A: OrderedStream, |
376 | B: OrderedStream<Data = A::Data, Ordering = A::Ordering>, |
377 | { |
378 | fn is_terminated(&self) -> bool { |
379 | matches!(self.state, JoinState::Terminated) |
380 | } |
381 | } |
382 | |
#[cfg(test)]
mod test {
    extern crate alloc;
    use crate::join;
    use crate::FromStream;
    use crate::OrderedStream;
    use crate::OrderedStreamExt;
    use crate::PollResult;
    use alloc::rc::Rc;
    use core::cell::Cell;
    use core::pin::Pin;
    use core::task::{Context, Poll};
    use futures_executor::block_on;
    use futures_util::pin_mut;
    use futures_util::stream::iter;

    #[derive(Debug, PartialEq)]
    pub struct Message {
        serial: u32,
    }

    /// The `Poll` value of a ready item whose ordering equals its serial number.
    fn ready_item(serial: u32) -> Poll<PollResult<u32, Message>> {
        Poll::Ready(PollResult::Item {
            data: Message { serial },
            ordering: serial,
        })
    }

    /// A stream driven by a shared counter, simulating a source that unblocks late.
    ///
    /// Counter `0`: pending. Counter `1`: `NoneBefore` when asked about serial 1, otherwise
    /// pending. Counter `2`: yields serial 4 and advances to `3`. Any other value: terminated.
    pub struct DelayStream(Rc<Cell<u8>>);

    impl OrderedStream for DelayStream {
        type Ordering = u32;
        type Data = Message;

        fn poll_next_before(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
            before: Option<&Self::Ordering>,
        ) -> Poll<PollResult<Self::Ordering, Self::Data>> {
            match (self.0.get(), before) {
                (0, _) => Poll::Pending,
                (1, Some(&1)) => Poll::Ready(PollResult::NoneBefore),
                (1, _) => Poll::Pending,
                (2, _) => {
                    self.0.set(3);
                    ready_item(4)
                }
                _ => Poll::Ready(PollResult::Terminated),
            }
        }
    }

    #[test]
    fn join_two() {
        block_on(async {
            let first = iter([
                Message { serial: 1 },
                Message { serial: 4 },
                Message { serial: 5 },
            ]);
            let second = iter([
                Message { serial: 2 },
                Message { serial: 3 },
                Message { serial: 6 },
            ]);

            let mut joined = join(
                FromStream::with_ordering(first, |m| m.serial),
                FromStream::with_ordering(second, |m| m.serial),
            );
            // The merged output must be exactly 1..=6 in order.
            for expected in 1..=6 {
                let msg = joined.next().await.unwrap();
                assert_eq!(msg.serial, expected);
            }
        });
    }

    #[test]
    fn join_one_slow() {
        block_on(async {
            let fast = FromStream::with_ordering(
                iter([
                    Message { serial: 1 },
                    Message { serial: 3 },
                    Message { serial: 5 },
                ]),
                |m| m.serial,
            );
            let gate = Rc::new(Cell::new(0));
            let slow = DelayStream(gate.clone());

            let joined = join(fast, slow);
            pin_mut!(joined);
            let waker = futures_util::task::noop_waker();
            let mut ctx = Context::from_waker(&waker);
            let mut next = || joined.as_mut().poll_next_before(&mut ctx, None);

            // With no information from the slow side, a serial-0 message could still arrive
            // there, so the join must wait.
            assert_eq!(next(), Poll::Pending);

            // The slow side now promises nothing before serial 1 ...
            gate.set(1);
            assert_eq!(next(), ready_item(1));
            // ... but makes no such promise for serial 3 yet.
            assert_eq!(next(), Poll::Pending);

            // Once the slow side produces its item, everything drains in order.
            gate.set(2);
            for serial in 3..=5 {
                assert_eq!(next(), ready_item(serial));
            }
            assert_eq!(next(), Poll::Ready(PollResult::Terminated));
        });
    }
}
527 | |