1use crate::*;
2use core::mem;
3use core::pin::Pin;
4use core::task::{Context, Poll};
5
6pin_project_lite::pin_project! {
7 /// A stream for the [`join`](fn.join.html) function.
8 #[derive(Debug)]
9 pub struct Join<A, B>
10 where
11 A: OrderedStream,
12 B: OrderedStream<Data = A::Data, Ordering=A::Ordering>,
13 {
14 #[pin]
15 stream_a: A,
16 #[pin]
17 stream_b: B,
18 state: JoinState<A::Data, B::Data, A::Ordering>,
19 }
20}
21
22/// Join two streams while preserving the overall ordering of elements.
23///
24/// You can think of this as implementing the "merge" step of a merge sort on the two streams,
25/// producing a single stream that is sorted given two sorted streams. If the streams return
26/// [`PollResult::NoneBefore`] as intended, then the joined stream will be able to produce items
27/// when only one of the sources has unblocked.
28pub fn join<A, B>(stream_a: A, stream_b: B) -> Join<A, B>
29where
30 A: OrderedStream,
31 B: OrderedStream<Data = A::Data, Ordering = A::Ordering>,
32{
33 Join {
34 stream_a,
35 stream_b,
36 state: JoinState::None,
37 }
38}
39
/// State a [`Join`] carries from one poll to the next.
#[derive(Debug)]
enum JoinState<DA, DB, O> {
    /// Nothing is buffered; both sources are polled.
    None,
    /// An item from the first source, with its ordering, is buffered.
    A(DA, O),
    /// An item from the second source, with its ordering, is buffered.
    B(DB, O),
    /// The second source has terminated; only the first is polled.
    OnlyPollA,
    /// The first source has terminated; only the second is polled.
    OnlyPollB,
    /// Both sources have terminated.
    Terminated,
}
49
50impl<A, B, T> JoinState<A, B, T> {
51 fn take_split(&mut self) -> (PollState<A, T>, PollState<B, T>) {
52 match mem::replace(self, src:JoinState::None) {
53 JoinState::None => (PollState::Pending, PollState::Pending),
54 JoinState::A(a: A, t: T) => (PollState::Item(a, t), PollState::Pending),
55 JoinState::B(b: B, t: T) => (PollState::Pending, PollState::Item(b, t)),
56 JoinState::OnlyPollA => (PollState::Pending, PollState::Terminated),
57 JoinState::OnlyPollB => (PollState::Terminated, PollState::Pending),
58 JoinState::Terminated => (PollState::Terminated, PollState::Terminated),
59 }
60 }
61}
62
/// A flattened equivalent of `Poll<PollResult<T, I>>` that is easier to match on.
pub(crate) enum PollState<I, T> {
    /// Corresponds to `Poll::Ready(PollResult::Item { .. })`: an item and its ordering.
    Item(I, T),
    /// Corresponds to `Poll::Pending`.
    Pending,
    /// Corresponds to `Poll::Ready(PollResult::NoneBefore)`.
    NoneBefore,
    /// Corresponds to `Poll::Ready(PollResult::Terminated)`.
    Terminated,
}
70
71impl<I, T: Ord> PollState<I, T> {
72 fn ordering(&self) -> Option<&T> {
73 match self {
74 Self::Item(_, t) => Some(t),
75 _ => None,
76 }
77 }
78
79 fn update(
80 &mut self,
81 before: Option<&T>,
82 other_token: Option<&T>,
83 retry: bool,
84 run: impl FnOnce(Option<&T>) -> Poll<PollResult<T, I>>,
85 ) -> bool {
86 match self {
87 // Do not re-poll if we have an item already or if we are terminated
88 Self::Item { .. } | Self::Terminated => return false,
89
90 // No need to re-poll if we already declared no items <= before
91 Self::NoneBefore if retry => return false,
92
93 _ => {}
94 }
95
96 // Run the poll with the earlier of the two tokens to avoid transitioning to Pending (which
97 // will stall the Join) when we could have transitioned to NoneBefore.
98 let ordering = match (before, other_token) {
99 (Some(u), Some(o)) => {
100 if *u > *o {
101 // The other ordering is earlier - so a retry might let us upgrade a Pending to a
102 // NoneBefore
103 Some(o)
104 } else if retry {
105 // A retry will not improve matters, so don't bother
106 return false;
107 } else {
108 Some(u)
109 }
110 }
111 (Some(t), None) | (None, Some(t)) => Some(t),
112 (None, None) => None,
113 };
114
115 *self = run(ordering).into();
116 matches!(self, Self::Item { .. })
117 }
118}
119
120impl<I, T> From<PollState<I, T>> for Poll<PollResult<T, I>> {
121 fn from(poll: PollState<I, T>) -> Self {
122 match poll {
123 PollState::Item(data: I, ordering: T) => Poll::Ready(PollResult::Item { data, ordering }),
124 PollState::Pending => Poll::Pending,
125 PollState::NoneBefore => Poll::Ready(PollResult::NoneBefore),
126 PollState::Terminated => Poll::Ready(PollResult::Terminated),
127 }
128 }
129}
130
131impl<I, T> From<Poll<PollResult<T, I>>> for PollState<I, T> {
132 fn from(poll: Poll<PollResult<T, I>>) -> Self {
133 match poll {
134 Poll::Ready(PollResult::Item { data: I, ordering: T }) => Self::Item(data, ordering),
135 Poll::Ready(PollResult::NoneBefore) => Self::NoneBefore,
136 Poll::Ready(PollResult::Terminated) => Self::Terminated,
137 Poll::Pending => Self::Pending,
138 }
139 }
140}
141
142impl<A, B> Join<A, B>
143where
144 A: OrderedStream,
145 B: OrderedStream<Data = A::Data, Ordering = A::Ordering>,
146{
147 /// Split into the source streams.
148 ///
149 /// This method returns the source streams along with any buffered item and its
150 /// ordering.
151 pub fn into_inner(self) -> (A, B, Option<(A::Data, A::Ordering)>) {
152 let item = match self.state {
153 JoinState::A(a, o) => Some((a, o)),
154 JoinState::B(b, o) => Some((b, o)),
155 _ => None,
156 };
157
158 (self.stream_a, self.stream_b, item)
159 }
160
161 /// Provide direct access to the underlying stream.
162 ///
163 /// This may be useful if the stream provides APIs beyond [OrderedStream]. Note that the join
164 /// itself may be buffering an item from this stream, so you should consult
165 /// [Self::peek_buffered] and, if needed, [Self::take_buffered] before polling it directly.
166 pub fn stream_a(self: Pin<&mut Self>) -> Pin<&mut A> {
167 self.project().stream_a
168 }
169
170 /// Provide direct access to the underlying stream.
171 ///
172 /// This may be useful if the stream provides APIs beyond [OrderedStream]. Note that the join
173 /// itself may be buffering an item from this stream, so you should consult
174 /// [Self::peek_buffered] and, if needed, [Self::take_buffered] before polling it directly.
175 pub fn stream_b(self: Pin<&mut Self>) -> Pin<&mut B> {
176 self.project().stream_b
177 }
178
179 /// Allow access to the buffered item, if any.
180 ///
181 /// At most one of the two sides will be `Some`. The returned item is a candidate for being
182 /// the next item returned by the joined stream, but it could not be returned by the most
183 /// recent [`OrderedStream::poll_next_before`] call.
184 pub fn peek_buffered(
185 self: Pin<&mut Self>,
186 ) -> (
187 Option<(&mut A::Data, &A::Ordering)>,
188 Option<(&mut B::Data, &B::Ordering)>,
189 ) {
190 match self.project().state {
191 JoinState::A(a, o) => (Some((a, o)), None),
192 JoinState::B(b, o) => (None, Some((b, o))),
193 _ => (None, None),
194 }
195 }
196
197 /// Remove the buffered item, if one is present.
198 ///
199 /// This does not poll either underlying stream. See [Self::peek_buffered] for details on why
200 /// buffering exists.
201 pub fn take_buffered(self: Pin<&mut Self>) -> Option<(A::Data, A::Ordering)> {
202 let state = self.project().state;
203 match mem::replace(state, JoinState::None) {
204 JoinState::A(a, o) => Some((a, o)),
205 JoinState::B(b, o) => Some((b, o)),
206 other => {
207 *state = other;
208 None
209 }
210 }
211 }
212}
213
214impl<A, B> OrderedStream for Join<A, B>
215where
216 A: OrderedStream,
217 B: OrderedStream<Data = A::Data, Ordering = A::Ordering>,
218{
219 type Data = A::Data;
220 type Ordering = A::Ordering;
221
222 fn poll_next_before(
223 self: Pin<&mut Self>,
224 cx: &mut Context<'_>,
225 before: Option<&Self::Ordering>,
226 ) -> Poll<PollResult<Self::Ordering, Self::Data>> {
227 let mut this = self.project();
228 let (mut poll_a, mut poll_b) = this.state.take_split();
229
230 poll_a.update(before, poll_b.ordering(), false, |ordering| {
231 this.stream_a.as_mut().poll_next_before(cx, ordering)
232 });
233 if poll_b.update(before, poll_a.ordering(), false, |ordering| {
234 this.stream_b.as_mut().poll_next_before(cx, ordering)
235 }) {
236 // If B just got an item, it's possible that A already knows that it won't have any
237 // items before that item; we couldn't ask that question before. Ask it now.
238 poll_a.update(before, poll_b.ordering(), true, |ordering| {
239 this.stream_a.as_mut().poll_next_before(cx, ordering)
240 });
241 }
242
243 match (poll_a, poll_b) {
244 // Both are ready - we can judge ordering directly (simplest case). The first one is
245 // returned while the other one is buffered for the next poll.
246 (PollState::Item(a, ta), PollState::Item(b, tb)) => {
247 if ta <= tb {
248 *this.state = JoinState::B(b, tb);
249 Poll::Ready(PollResult::Item {
250 data: a,
251 ordering: ta,
252 })
253 } else {
254 *this.state = JoinState::A(a, ta);
255 Poll::Ready(PollResult::Item {
256 data: b,
257 ordering: tb,
258 })
259 }
260 }
261
262 // If both sides are terminated, so are we.
263 (PollState::Terminated, PollState::Terminated) => {
264 *this.state = JoinState::Terminated;
265 Poll::Ready(PollResult::Terminated)
266 }
267
268 // If one side is terminated, we can produce items directly from the other side.
269 (a, PollState::Terminated) => {
270 *this.state = JoinState::OnlyPollA;
271 a.into()
272 }
273 (PollState::Terminated, b) => {
274 *this.state = JoinState::OnlyPollB;
275 b.into()
276 }
277
278 // If one side is pending, we can't return Ready until that gets resolved. Because we
279 // have already requested that our child streams wake us when it is possible to make
280 // any kind of progress, we meet the requirements to return Poll::Pending.
281 (PollState::Item(a, t), PollState::Pending) => {
282 *this.state = JoinState::A(a, t);
283 Poll::Pending
284 }
285 (PollState::Pending, PollState::Item(b, t)) => {
286 *this.state = JoinState::B(b, t);
287 Poll::Pending
288 }
289 (PollState::Pending, PollState::Pending) => Poll::Pending,
290 (PollState::Pending, PollState::NoneBefore) => Poll::Pending,
291 (PollState::NoneBefore, PollState::Pending) => Poll::Pending,
292
293 // If both sides report NoneBefore, so can we.
294 (PollState::NoneBefore, PollState::NoneBefore) => Poll::Ready(PollResult::NoneBefore),
295
296 (PollState::Item(data, ordering), PollState::NoneBefore) => {
297 // B was polled using either the Some value of (before) or using A's ordering.
298 //
299 // If before is set and is earlier than A's ordering, then B might later produce a
300 // value with (bt >= before && bt < at), so we can't return A's item yet and must
301 // buffer it. However, we can return None because neither stream will produce
302 // items before the ordering passed in before.
303 //
304 // If before is either None or after A's ordering, B's NoneBefore return represents a
305 // promise to not produce an item before A's, so we can return A's item now.
306 match before {
307 Some(before) if ordering > *before => {
308 *this.state = JoinState::A(data, ordering);
309 Poll::Ready(PollResult::NoneBefore)
310 }
311 _ => Poll::Ready(PollResult::Item { data, ordering }),
312 }
313 }
314
315 (PollState::NoneBefore, PollState::Item(data, ordering)) => {
316 // A was polled using either the Some value of (before) or using B's ordering.
317 //
318 // By a mirror of the above argument, this NoneBefore result gives us permission to
319 // produce either B's item or NoneBefore.
320 match before {
321 Some(before) if ordering > *before => {
322 *this.state = JoinState::B(data, ordering);
323 Poll::Ready(PollResult::NoneBefore)
324 }
325 _ => Poll::Ready(PollResult::Item { data, ordering }),
326 }
327 }
328 }
329 }
330
331 fn position_hint(&self) -> Option<MaybeBorrowed<'_, Self::Ordering>> {
332 let (a, b) = match &self.state {
333 JoinState::None => (self.stream_a.position_hint(), self.stream_b.position_hint()),
334 JoinState::A(_, t) => (
335 Some(MaybeBorrowed::Borrowed(t)),
336 self.stream_b.position_hint(),
337 ),
338 JoinState::B(_, t) => (
339 self.stream_b.position_hint(),
340 Some(MaybeBorrowed::Borrowed(t)),
341 ),
342 JoinState::OnlyPollA => return self.stream_a.position_hint(),
343 JoinState::OnlyPollB => return self.stream_b.position_hint(),
344 JoinState::Terminated => return None,
345 };
346 // We can only provide a hint if we have a valid hint for both sides
347 match (a, b) {
348 (Some(a), Some(b)) if *a <= *b => Some(a),
349 (Some(_), Some(b)) => Some(b),
350 _ => None,
351 }
352 }
353
354 fn size_hint(&self) -> (usize, Option<usize>) {
355 let extra = match &self.state {
356 JoinState::None => 0,
357 JoinState::A { .. } => 1,
358 JoinState::B { .. } => 1,
359 JoinState::OnlyPollA => return self.stream_a.size_hint(),
360 JoinState::OnlyPollB => return self.stream_b.size_hint(),
361 JoinState::Terminated => return (0, Some(0)),
362 };
363 let (al, ah) = self.stream_a.size_hint();
364 let (bl, bh) = self.stream_b.size_hint();
365 let min = al.saturating_add(bl).saturating_add(extra);
366 let max = ah
367 .and_then(|a| bh.and_then(|b| a.checked_add(b)))
368 .and_then(|h| h.checked_add(extra));
369 (min, max)
370 }
371}
372
373impl<A, B> FusedOrderedStream for Join<A, B>
374where
375 A: OrderedStream,
376 B: OrderedStream<Data = A::Data, Ordering = A::Ordering>,
377{
378 fn is_terminated(&self) -> bool {
379 matches!(self.state, JoinState::Terminated)
380 }
381}
382
#[cfg(test)]
mod test {
    extern crate alloc;
    use crate::join;
    use crate::FromStream;
    use crate::OrderedStream;
    use crate::OrderedStreamExt;
    use crate::PollResult;
    use alloc::rc::Rc;
    use core::cell::Cell;
    use core::pin::Pin;
    use core::task::{Context, Poll};
    use futures_executor::block_on;
    use futures_util::pin_mut;
    use futures_util::stream::iter;

    #[derive(Debug, PartialEq)]
    pub struct Message {
        serial: u32,
    }

    /// Expected poll result for a ready message whose ordering equals its serial.
    fn item_ready(serial: u32) -> Poll<PollResult<u32, Message>> {
        Poll::Ready(PollResult::Item {
            data: Message { serial },
            ordering: serial,
        })
    }

    #[test]
    fn join_two() {
        block_on(async {
            let left = FromStream::with_ordering(
                iter([
                    Message { serial: 1 },
                    Message { serial: 4 },
                    Message { serial: 5 },
                ]),
                |m| m.serial,
            );
            let right = FromStream::with_ordering(
                iter([
                    Message { serial: 2 },
                    Message { serial: 3 },
                    Message { serial: 6 },
                ]),
                |m| m.serial,
            );
            let mut joined = join(left, right);
            for expected in 1..=6u32 {
                let msg = joined.next().await.unwrap();
                assert_eq!(msg.serial, expected);
            }
        });
    }

    #[test]
    fn join_one_slow() {
        futures_executor::block_on(async {
            /// A stream whose answers are driven by an externally controlled phase counter.
            pub struct DelayStream(Rc<Cell<u8>>);

            impl OrderedStream for DelayStream {
                type Ordering = u32;
                type Data = Message;
                fn poll_next_before(
                    self: Pin<&mut Self>,
                    _: &mut Context<'_>,
                    before: Option<&Self::Ordering>,
                ) -> Poll<PollResult<Self::Ordering, Self::Data>> {
                    let phase = self.0.get();
                    match phase {
                        // No information at all yet.
                        0 => Poll::Pending,
                        // Can only rule out items before serial 1.
                        1 if before == Some(&1) => Poll::Ready(PollResult::NoneBefore),
                        1 => Poll::Pending,
                        // Emits its single item, then terminates.
                        2 => {
                            self.0.set(3);
                            Poll::Ready(PollResult::Item {
                                data: Message { serial: 4 },
                                ordering: 4,
                            })
                        }
                        _ => Poll::Ready(PollResult::Terminated),
                    }
                }
            }

            let fast = FromStream::with_ordering(
                iter([
                    Message { serial: 1 },
                    Message { serial: 3 },
                    Message { serial: 5 },
                ]),
                |m| m.serial,
            );
            let phase = Rc::new(Cell::new(0));
            let slow = DelayStream(Rc::clone(&phase));

            let joined = join(fast, slow);
            pin_mut!(joined);

            let waker = futures_util::task::noop_waker();
            let mut cx = Context::from_waker(&waker);
            let mut poll = || joined.as_mut().poll_next_before(&mut cx, None);

            // With no information from the slow stream, a serial-0 message could still come
            // from it, so the join has to report Pending.
            assert_eq!(poll(), Poll::Pending);

            // The slow stream now rules out anything before serial 1, which releases serial 1...
            phase.set(1);
            assert_eq!(poll(), item_ready(1));
            // ...but it says nothing yet about serial 3.
            assert_eq!(poll(), Poll::Pending);

            // Once the slow stream produces its item, everything drains in order.
            phase.set(2);
            for serial in [3, 4, 5] {
                assert_eq!(poll(), item_ready(serial));
            }
            assert_eq!(poll(), Poll::Ready(PollResult::Terminated));
        });
    }
}
527