1use crate::*;
2use core::ops::DerefMut;
3use core::pin::Pin;
4use core::task::{Context, Poll};
5
6fn poll_multiple_step<I, P, S>(
7 streams: I,
8 cx: &mut Context<'_>,
9 before: Option<&S::Ordering>,
10 mut retry: Option<&mut Option<S::Ordering>>,
11) -> Poll<PollResult<S::Ordering, S::Data>>
12where
13 I: IntoIterator<Item = Pin<P>>,
14 P: DerefMut<Target = Peekable<S>>,
15 S: OrderedStream,
16 S::Ordering: Clone,
17{
18 // The stream with the earliest item that is actually before the given point
19 let mut best: Option<Pin<P>> = None;
20 // true if we have a stream that has not terminated
21 let mut has_data = false;
22 let mut has_pending = false;
23 let mut skip_retry = false;
24 for mut stream in streams {
25 let best_before = best.as_ref().and_then(|p| p.item().map(|i| &i.0));
26 let current_bound = match (before, best_before) {
27 (Some(given), Some(best)) if given <= best => Some(given),
28 (_, Some(best)) => Some(best),
29 (given, None) => given,
30 };
31 // improved is true if have improved the `before` bound from the initial value
32
33 match stream.as_mut().poll_peek_before(cx, current_bound) {
34 Poll::Pending => {
35 has_pending = true;
36 skip_retry = true;
37 }
38 Poll::Ready(PollResult::Terminated) => continue,
39 Poll::Ready(PollResult::NoneBefore) => {
40 has_data = true;
41 }
42 Poll::Ready(PollResult::Item { ordering, .. }) => {
43 has_data = true;
44 match current_bound {
45 Some(max) if max < ordering => continue,
46 _ => {}
47 }
48 match (&mut retry, before, has_pending) {
49 (Some(retry), Some(initial_bound), true) if ordering < initial_bound => {
50 // We have just improved the initial bound, so the streams that
51 // previously returned Pending might be able to return NoneBefore in a
52 // retry. This is only useful if there are no later Pending returns, so
53 // those will set skip_retry.
54 **retry = Some(ordering.clone());
55 skip_retry = false;
56 }
57 (Some(retry), None, true) => {
58 **retry = Some(ordering.clone());
59 skip_retry = false;
60 }
61 _ => {}
62 }
63 best = Some(stream);
64 }
65 }
66 }
67 if skip_retry {
68 retry.map(|r| *r = None);
69 }
70 match best {
71 _ if has_pending => Poll::Pending,
72 None if has_data => Poll::Ready(PollResult::NoneBefore),
73 None => Poll::Ready(PollResult::Terminated),
74 // This is guaranteed to return PollResult::Item
75 Some(mut stream) => stream.as_mut().poll_next_before(cx, before),
76 }
77}
78
/// Merges every [`OrderedStream`] held in a collection into one ordered stream.
///
/// The result is comparable to applying [`join()`] repeatedly across all streams of the wrapped
/// collection. No effort is made to skip streams that are not ready, so this type is best suited
/// to collections holding a relatively small number of streams.
//
// Design note: unlike `FuturesUnordered` or `SelectAll`, the ordering guarantees offered here can
// easily force every stream in the collection to be consulted before any item is returned. For
// example, consider streams that each generate timestamps (locally) for their items and only
// return `NoneBefore` for past timestamps. If a single stream produces an item on each call to
// `JoinMultiple::poll_next_before`, that timestamp has to be checked against every other stream,
// and no amount of preparatory work or hints will help.
//
// Conversely, if every stream provides a position hint that matches its next item, the streams
// could be kept in a priority queue, reducing the cost of one poll from `n` to `log(n)`. Doing so
// requires keeping a snapshot of the hints (so S::Ordering: Clone) and significantly increases
// the worst-case workload, so it should be a distinct type.
#[derive(Clone, Debug, Default)]
pub struct JoinMultiple<C>(pub C);

// `Unpin` is declared for every `C`, with no bound on the collection type.
impl<C> Unpin for JoinMultiple<C> {}
99
100impl<C, S> OrderedStream for JoinMultiple<C>
101where
102 for<'a> &'a mut C: IntoIterator<Item = &'a mut Peekable<S>>,
103 S: OrderedStream + Unpin,
104 S::Ordering: Clone,
105{
106 type Ordering = S::Ordering;
107 type Data = S::Data;
108 fn poll_next_before(
109 mut self: Pin<&mut Self>,
110 cx: &mut Context<'_>,
111 before: Option<&S::Ordering>,
112 ) -> Poll<PollResult<S::Ordering, S::Data>> {
113 let mut retry = None;
114 let rv = poll_multiple_step(
115 self.as_mut().get_mut().0.into_iter().map(Pin::new),
116 cx,
117 before,
118 Some(&mut retry),
119 );
120 if rv.is_pending() && retry.is_some() {
121 poll_multiple_step(
122 self.get_mut().0.into_iter().map(Pin::new),
123 cx,
124 retry.as_ref(),
125 None,
126 )
127 } else {
128 rv
129 }
130 }
131}
132
133impl<C, S> FusedOrderedStream for JoinMultiple<C>
134where
135 for<'a> &'a mut C: IntoIterator<Item = &'a mut Peekable<S>>,
136 for<'a> &'a C: IntoIterator<Item = &'a Peekable<S>>,
137 S: OrderedStream + Unpin,
138 S::Ordering: Clone,
139{
140 fn is_terminated(&self) -> bool {
141 self.0.into_iter().all(|peekable| peekable.is_terminated())
142 }
143}
144
145pin_project_lite::pin_project! {
146 /// Join a collection of pinned [`OrderedStream`]s.
147 ///
148 /// This is identical to [`JoinMultiple`], but accepts [`OrderedStream`]s that are not [`Unpin`] by
149 /// requiring that the collection provide a pinned [`IntoIterator`] implementation.
150 ///
151 /// This is not a feature available in most `std` collections. If you wish to use them, you
152 /// should use `Box::pin` to make the stream [`Unpin`] before inserting it in the collection,
153 /// and then use [`JoinMultiple`] on the resulting collection.
154 #[derive(Debug,Default,Clone)]
155 pub struct JoinMultiplePin<C> {
156 #[pin]
157 pub streams: C,
158 }
159}
160
161impl<C> JoinMultiplePin<C> {
162 pub fn as_pin_mut(self: Pin<&mut Self>) -> Pin<&mut C> {
163 self.project().streams
164 }
165}
166
167impl<C, S> OrderedStream for JoinMultiplePin<C>
168where
169 for<'a> Pin<&'a mut C>: IntoIterator<Item = Pin<&'a mut Peekable<S>>>,
170 S: OrderedStream,
171 S::Ordering: Clone,
172{
173 type Ordering = S::Ordering;
174 type Data = S::Data;
175 fn poll_next_before(
176 mut self: Pin<&mut Self>,
177 cx: &mut Context<'_>,
178 before: Option<&S::Ordering>,
179 ) -> Poll<PollResult<S::Ordering, S::Data>> {
180 let mut retry: Option<::Ordering> = None;
181 let rv: Poll::Ordering, …>> = poll_multiple_step(self.as_mut().as_pin_mut(), cx, before, retry:Some(&mut retry));
182 if rv.is_pending() && retry.is_some() {
183 poll_multiple_step(self.as_pin_mut(), cx, before:retry.as_ref(), retry:None)
184 } else {
185 rv
186 }
187 }
188}
189
#[cfg(test)]
mod test {
    extern crate alloc;

    use crate::{FromStream, JoinMultiple, OrderedStream, OrderedStreamExt, PollResult};
    use alloc::{boxed::Box, rc::Rc, vec, vec::Vec};
    use core::{cell::Cell, pin::Pin, task::Context, task::Poll};
    use futures_core::Stream;
    use futures_util::{pin_mut, stream::iter};

    /// Payload used by the tests; the tests use `serial` as the ordering key.
    #[derive(PartialEq, Debug)]
    pub struct Message {
        serial: u32,
    }

    /// Two interleaved, always-ready sources must come out merged in serial order.
    #[test]
    fn join_mutiple() {
        pub struct RemoteLogSource {
            stream: Pin<Box<dyn Stream<Item = Message>>>,
        }

        futures_executor::block_on(async {
            let first = RemoteLogSource {
                stream: Box::pin(iter([
                    Message { serial: 1 },
                    Message { serial: 4 },
                    Message { serial: 5 },
                ])),
            };
            let second = RemoteLogSource {
                stream: Box::pin(iter([
                    Message { serial: 2 },
                    Message { serial: 3 },
                    Message { serial: 6 },
                ])),
            };
            let mut logs = [first, second];

            let mut streams = Vec::with_capacity(logs.len());
            for log in logs.iter_mut() {
                streams.push(FromStream::with_ordering(&mut log.stream, |m| m.serial).peekable());
            }

            let mut joined = JoinMultiple(streams);
            for expected in 1..=6u32 {
                let msg = joined.next().await.unwrap();
                assert_eq!(msg.serial, expected);
            }
        });
    }

    /// One source is driven by an external phase counter so that the join has to wait for it.
    #[test]
    fn join_one_slow() {
        /// Stream whose answers depend on the shared phase value:
        /// 0 = always pending, 1 = `NoneBefore` only when asked about serial 1,
        /// 2 = emit serial 4 once (moving to phase 3), anything else = terminated.
        pub struct DelayStream(Rc<Cell<u8>>);

        impl OrderedStream for DelayStream {
            type Ordering = u32;
            type Data = Message;
            fn poll_next_before(
                self: Pin<&mut Self>,
                _: &mut Context<'_>,
                before: Option<&Self::Ordering>,
            ) -> Poll<PollResult<Self::Ordering, Self::Data>> {
                let phase = &self.0;
                match phase.get() {
                    0 => Poll::Pending,
                    1 => {
                        if before == Some(&1) {
                            Poll::Ready(PollResult::NoneBefore)
                        } else {
                            Poll::Pending
                        }
                    }
                    2 => {
                        phase.set(3);
                        Poll::Ready(PollResult::Item {
                            data: Message { serial: 4 },
                            ordering: 4,
                        })
                    }
                    _ => Poll::Ready(PollResult::Terminated),
                }
            }
        }

        type BoxedStream = Pin<Box<dyn OrderedStream<Ordering = u32, Data = Message>>>;

        /// The poll result carrying a message whose ordering equals its serial.
        fn item(serial: u32) -> Poll<PollResult<u32, Message>> {
            Poll::Ready(PollResult::Item {
                data: Message { serial },
                ordering: serial,
            })
        }

        futures_executor::block_on(async {
            let ready = FromStream::with_ordering(
                iter([
                    Message { serial: 1 },
                    Message { serial: 3 },
                    Message { serial: 5 },
                ]),
                |m| m.serial,
            );
            let phase = Rc::new(Cell::new(0));
            let delayed = DelayStream(phase.clone());

            let ready: BoxedStream = Box::pin(ready);
            let delayed: BoxedStream = Box::pin(delayed);
            let join = JoinMultiple(vec![ready.peekable(), delayed.peekable()]);

            let waker = futures_util::task::noop_waker();
            let mut ctx = Context::from_waker(&waker);
            pin_mut!(join);
            let mut poll = || join.as_mut().poll_next_before(&mut ctx, None);

            // While the DelayStream knows nothing about its contents, the join must stay
            // Pending: the DelayStream could still produce a serial-0 message.
            assert_eq!(poll(), Poll::Pending);

            // In phase 1 the DelayStream answers NoneBefore for serial 1, releasing that item...
            phase.set(1);
            assert_eq!(poll(), item(1));
            // ...but it does not (yet) do the same for serial 3.
            assert_eq!(poll(), Poll::Pending);

            // In phase 2 the DelayStream yields serial 4 and then terminates, so the rest of the
            // items come out in order.
            phase.set(2);
            for serial in 3..=5u32 {
                assert_eq!(poll(), item(serial));
            }

            assert_eq!(poll(), Poll::Ready(PollResult::Terminated));
        });
    }
}
343