1 | use crate::*; |
2 | use core::ops::DerefMut; |
3 | use core::pin::Pin; |
4 | use core::task::{Context, Poll}; |
5 | |
6 | fn poll_multiple_step<I, P, S>( |
7 | streams: I, |
8 | cx: &mut Context<'_>, |
9 | before: Option<&S::Ordering>, |
10 | mut retry: Option<&mut Option<S::Ordering>>, |
11 | ) -> Poll<PollResult<S::Ordering, S::Data>> |
12 | where |
13 | I: IntoIterator<Item = Pin<P>>, |
14 | P: DerefMut<Target = Peekable<S>>, |
15 | S: OrderedStream, |
16 | S::Ordering: Clone, |
17 | { |
18 | // The stream with the earliest item that is actually before the given point |
19 | let mut best: Option<Pin<P>> = None; |
20 | // true if we have a stream that has not terminated |
21 | let mut has_data = false; |
22 | let mut has_pending = false; |
23 | let mut skip_retry = false; |
24 | for mut stream in streams { |
25 | let best_before = best.as_ref().and_then(|p| p.item().map(|i| &i.0)); |
26 | let current_bound = match (before, best_before) { |
27 | (Some(given), Some(best)) if given <= best => Some(given), |
28 | (_, Some(best)) => Some(best), |
29 | (given, None) => given, |
30 | }; |
31 | // improved is true if have improved the `before` bound from the initial value |
32 | |
33 | match stream.as_mut().poll_peek_before(cx, current_bound) { |
34 | Poll::Pending => { |
35 | has_pending = true; |
36 | skip_retry = true; |
37 | } |
38 | Poll::Ready(PollResult::Terminated) => continue, |
39 | Poll::Ready(PollResult::NoneBefore) => { |
40 | has_data = true; |
41 | } |
42 | Poll::Ready(PollResult::Item { ordering, .. }) => { |
43 | has_data = true; |
44 | match current_bound { |
45 | Some(max) if max < ordering => continue, |
46 | _ => {} |
47 | } |
48 | match (&mut retry, before, has_pending) { |
49 | (Some(retry), Some(initial_bound), true) if ordering < initial_bound => { |
50 | // We have just improved the initial bound, so the streams that |
51 | // previously returned Pending might be able to return NoneBefore in a |
52 | // retry. This is only useful if there are no later Pending returns, so |
53 | // those will set skip_retry. |
54 | **retry = Some(ordering.clone()); |
55 | skip_retry = false; |
56 | } |
57 | (Some(retry), None, true) => { |
58 | **retry = Some(ordering.clone()); |
59 | skip_retry = false; |
60 | } |
61 | _ => {} |
62 | } |
63 | best = Some(stream); |
64 | } |
65 | } |
66 | } |
67 | if skip_retry { |
68 | retry.map(|r| *r = None); |
69 | } |
70 | match best { |
71 | _ if has_pending => Poll::Pending, |
72 | None if has_data => Poll::Ready(PollResult::NoneBefore), |
73 | None => Poll::Ready(PollResult::Terminated), |
74 | // This is guaranteed to return PollResult::Item |
75 | Some(mut stream) => stream.as_mut().poll_next_before(cx, before), |
76 | } |
77 | } |
78 | |
/// Join a collection of [`OrderedStream`]s.
///
/// The result behaves like repeatedly applying [`join()`] to every stream in the collection.
/// Streams that are not ready are still polled on each call, so this type works best when the
/// collection holds a relatively small number of streams.
//
// Design notes:
//
// Unlike `FuturesUnordered` or `SelectAll`, the ordering guarantees provided here can easily
// force every stream in the collection to be consulted before any item is returned. For example,
// consider streams that each timestamp their items locally and only report `NoneBefore` for
// timestamps already in the past. If only one stream has an item on a given call to
// `JoinMultiple::poll_next_before`, that timestamp must be checked against every other stream,
// and no amount of preparation or hinting can avoid this.
//
// If, on the other hand, every stream supplied a position hint equal to its next item, the
// streams could be kept in a priority queue, reducing the cost of one poll from `n` to `log(n)`.
// That would require keeping a snapshot of the hints (so `S::Ordering: Clone`) and would raise
// the worst-case workload considerably, so it belongs in a separate type.
#[derive(Clone, Debug, Default)]
pub struct JoinMultiple<C>(pub C);

// No `Pin<&mut C>` is ever created for this type. Polling takes `&mut C` via `Pin::get_mut`
// and pins each stream individually with `Pin::new`, which requires the streams to be `Unpin`.
// That is why `Unpin` is implemented regardless of `C`.
impl<C> Unpin for JoinMultiple<C> {}
99 | |
100 | impl<C, S> OrderedStream for JoinMultiple<C> |
101 | where |
102 | for<'a> &'a mut C: IntoIterator<Item = &'a mut Peekable<S>>, |
103 | S: OrderedStream + Unpin, |
104 | S::Ordering: Clone, |
105 | { |
106 | type Ordering = S::Ordering; |
107 | type Data = S::Data; |
108 | fn poll_next_before( |
109 | mut self: Pin<&mut Self>, |
110 | cx: &mut Context<'_>, |
111 | before: Option<&S::Ordering>, |
112 | ) -> Poll<PollResult<S::Ordering, S::Data>> { |
113 | let mut retry = None; |
114 | let rv = poll_multiple_step( |
115 | self.as_mut().get_mut().0.into_iter().map(Pin::new), |
116 | cx, |
117 | before, |
118 | Some(&mut retry), |
119 | ); |
120 | if rv.is_pending() && retry.is_some() { |
121 | poll_multiple_step( |
122 | self.get_mut().0.into_iter().map(Pin::new), |
123 | cx, |
124 | retry.as_ref(), |
125 | None, |
126 | ) |
127 | } else { |
128 | rv |
129 | } |
130 | } |
131 | } |
132 | |
133 | impl<C, S> FusedOrderedStream for JoinMultiple<C> |
134 | where |
135 | for<'a> &'a mut C: IntoIterator<Item = &'a mut Peekable<S>>, |
136 | for<'a> &'a C: IntoIterator<Item = &'a Peekable<S>>, |
137 | S: OrderedStream + Unpin, |
138 | S::Ordering: Clone, |
139 | { |
140 | fn is_terminated(&self) -> bool { |
141 | self.0.into_iter().all(|peekable| peekable.is_terminated()) |
142 | } |
143 | } |
144 | |
145 | pin_project_lite::pin_project! { |
146 | /// Join a collection of pinned [`OrderedStream`]s. |
147 | /// |
148 | /// This is identical to [`JoinMultiple`], but accepts [`OrderedStream`]s that are not [`Unpin`] by |
149 | /// requiring that the collection provide a pinned [`IntoIterator`] implementation. |
150 | /// |
151 | /// This is not a feature available in most `std` collections. If you wish to use them, you |
152 | /// should use `Box::pin` to make the stream [`Unpin`] before inserting it in the collection, |
153 | /// and then use [`JoinMultiple`] on the resulting collection. |
154 | #[derive(Debug,Default,Clone)] |
155 | pub struct JoinMultiplePin<C> { |
156 | #[pin] |
157 | pub streams: C, |
158 | } |
159 | } |
160 | |
161 | impl<C> JoinMultiplePin<C> { |
162 | pub fn as_pin_mut(self: Pin<&mut Self>) -> Pin<&mut C> { |
163 | self.project().streams |
164 | } |
165 | } |
166 | |
167 | impl<C, S> OrderedStream for JoinMultiplePin<C> |
168 | where |
169 | for<'a> Pin<&'a mut C>: IntoIterator<Item = Pin<&'a mut Peekable<S>>>, |
170 | S: OrderedStream, |
171 | S::Ordering: Clone, |
172 | { |
173 | type Ordering = S::Ordering; |
174 | type Data = S::Data; |
175 | fn poll_next_before( |
176 | mut self: Pin<&mut Self>, |
177 | cx: &mut Context<'_>, |
178 | before: Option<&S::Ordering>, |
179 | ) -> Poll<PollResult<S::Ordering, S::Data>> { |
180 | let mut retry: Option<::Ordering> = None; |
181 | let rv: Poll::Ordering, …>> = poll_multiple_step(self.as_mut().as_pin_mut(), cx, before, retry:Some(&mut retry)); |
182 | if rv.is_pending() && retry.is_some() { |
183 | poll_multiple_step(self.as_pin_mut(), cx, before:retry.as_ref(), retry:None) |
184 | } else { |
185 | rv |
186 | } |
187 | } |
188 | } |
189 | |
#[cfg(test)]
mod test {
    extern crate alloc;

    use crate::{FromStream, JoinMultiple, OrderedStream, OrderedStreamExt, PollResult};
    use alloc::{boxed::Box, rc::Rc, vec, vec::Vec};
    use core::{cell::Cell, pin::Pin, task::Context, task::Poll};
    use futures_core::Stream;
    use futures_util::{pin_mut, stream::iter};

    #[derive(Debug, PartialEq)]
    pub struct Message {
        serial: u32,
    }

    /// An [`OrderedStream`] driven by a shared state cell:
    ///
    /// - `0`: always `Pending`;
    /// - `1`: `NoneBefore` when polled with `before == Some(&1)`, otherwise `Pending`;
    /// - `2`: yields one item with serial and ordering 4, then moves the state to `3`;
    /// - anything else: `Terminated`.
    pub struct DelayStream(Rc<Cell<u8>>);

    impl OrderedStream for DelayStream {
        type Ordering = u32;
        type Data = Message;
        fn poll_next_before(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
            before: Option<&Self::Ordering>,
        ) -> Poll<PollResult<Self::Ordering, Self::Data>> {
            match self.0.get() {
                0 => Poll::Pending,
                1 if matches!(before, Some(&1)) => Poll::Ready(PollResult::NoneBefore),
                1 => Poll::Pending,
                2 => {
                    self.0.set(3);
                    Poll::Ready(PollResult::Item {
                        data: Message { serial: 4 },
                        ordering: 4,
                    })
                }
                _ => Poll::Ready(PollResult::Terminated),
            }
        }
    }

    #[test]
    fn join_multiple() {
        futures_executor::block_on(async {
            pub struct RemoteLogSource {
                stream: Pin<Box<dyn Stream<Item = Message>>>,
            }

            let mut logs = [
                RemoteLogSource {
                    stream: Box::pin(iter([
                        Message { serial: 1 },
                        Message { serial: 4 },
                        Message { serial: 5 },
                    ])),
                },
                RemoteLogSource {
                    stream: Box::pin(iter([
                        Message { serial: 2 },
                        Message { serial: 3 },
                        Message { serial: 6 },
                    ])),
                },
            ];
            let streams: Vec<_> = logs
                .iter_mut()
                .map(|s| FromStream::with_ordering(&mut s.stream, |m| m.serial).peekable())
                .collect();
            let mut joined = JoinMultiple(streams);
            for expected in 1..=6 {
                let msg = joined.next().await.unwrap();
                assert_eq!(msg.serial, expected);
            }
            // Both inputs are exhausted, so the join must report the end of the stream.
            assert!(joined.next().await.is_none());
        });
    }

    #[test]
    fn join_one_slow() {
        let stream1 = iter([
            Message { serial: 1 },
            Message { serial: 3 },
            Message { serial: 5 },
        ]);

        let stream1 = FromStream::with_ordering(stream1, |m| m.serial);
        let go = Rc::new(Cell::new(0));
        let stream2 = DelayStream(go.clone());

        let stream1: Pin<Box<dyn OrderedStream<Ordering = u32, Data = Message>>> =
            Box::pin(stream1);
        let stream2: Pin<Box<dyn OrderedStream<Ordering = u32, Data = Message>>> =
            Box::pin(stream2);
        let streams = vec![stream1.peekable(), stream2.peekable()];
        let join = JoinMultiple(streams);
        let waker = futures_util::task::noop_waker();
        let mut ctx = Context::from_waker(&waker);

        pin_mut!(join);

        // When the DelayStream has no information about what it contains, join returns Pending
        // (since there could be a serial-0 message output of DelayStream)
        assert_eq!(
            join.as_mut().poll_next_before(&mut ctx, None),
            Poll::Pending
        );

        go.set(1);
        // Now the DelayStream will return NoneBefore on serial 1
        assert_eq!(
            join.as_mut().poll_next_before(&mut ctx, None),
            Poll::Ready(PollResult::Item {
                data: Message { serial: 1 },
                ordering: 1,
            })
        );
        // however, it does not (yet) do so for serial 3
        assert_eq!(
            join.as_mut().poll_next_before(&mut ctx, None),
            Poll::Pending
        );

        go.set(2);
        assert_eq!(
            join.as_mut().poll_next_before(&mut ctx, None),
            Poll::Ready(PollResult::Item {
                data: Message { serial: 3 },
                ordering: 3,
            })
        );
        assert_eq!(
            join.as_mut().poll_next_before(&mut ctx, None),
            Poll::Ready(PollResult::Item {
                data: Message { serial: 4 },
                ordering: 4,
            })
        );
        assert_eq!(
            join.as_mut().poll_next_before(&mut ctx, None),
            Poll::Ready(PollResult::Item {
                data: Message { serial: 5 },
                ordering: 5,
            })
        );

        assert_eq!(
            join.as_mut().poll_next_before(&mut ctx, None),
            Poll::Ready(PollResult::Terminated)
        );
    }

    #[test]
    fn join_slow_first_retries() {
        // The slow stream is polled first, so its Pending can only be resolved after the fast
        // stream's item tightens the bound. This exercises the retry pass.
        let fast = iter([Message { serial: 1 }, Message { serial: 3 }]);
        let fast = FromStream::with_ordering(fast, |m| m.serial);
        let go = Rc::new(Cell::new(1));
        let slow = DelayStream(go.clone());

        let fast: Pin<Box<dyn OrderedStream<Ordering = u32, Data = Message>>> = Box::pin(fast);
        let slow: Pin<Box<dyn OrderedStream<Ordering = u32, Data = Message>>> = Box::pin(slow);
        let join = JoinMultiple(vec![slow.peekable(), fast.peekable()]);
        let waker = futures_util::task::noop_waker();
        let mut ctx = Context::from_waker(&waker);

        pin_mut!(join);

        // First pass: slow is Pending for an unbounded poll, fast offers serial 1.
        // Retry pass with bound 1: slow answers NoneBefore, so serial 1 is released.
        assert_eq!(
            join.as_mut().poll_next_before(&mut ctx, None),
            Poll::Ready(PollResult::Item {
                data: Message { serial: 1 },
                ordering: 1,
            })
        );
        // The slow stream cannot rule out items before serial 3, even on retry.
        assert_eq!(
            join.as_mut().poll_next_before(&mut ctx, None),
            Poll::Pending
        );
    }
}
343 | |