1 | use proc_macro2::TokenStream; |
2 | use quote::{format_ident, quote, ToTokens}; |
3 | use syn::{parse::Parser, punctuated::Punctuated, Expr, Index, Token}; |
4 | |
5 | /// The `stream_select!` macro. |
6 | pub(crate) fn stream_select(input: TokenStream) -> Result<TokenStream, syn::Error> { |
7 | let args = Punctuated::<Expr, Token![,]>::parse_terminated.parse2(input)?; |
8 | if args.len() < 2 { |
9 | return Ok(quote! { |
10 | compile_error!("stream select macro needs at least two arguments." ) |
11 | }); |
12 | } |
13 | let generic_idents = (0..args.len()).map(|i| format_ident!("_{}" , i)).collect::<Vec<_>>(); |
14 | let field_idents = (0..args.len()).map(|i| format_ident!("__{}" , i)).collect::<Vec<_>>(); |
15 | let field_idents_2 = (0..args.len()).map(|i| format_ident!("___{}" , i)).collect::<Vec<_>>(); |
16 | let field_indices = (0..args.len()).map(Index::from).collect::<Vec<_>>(); |
17 | let args = args.iter().map(|e| e.to_token_stream()); |
18 | |
19 | Ok(quote! { |
20 | { |
21 | #[derive(Debug)] |
22 | struct StreamSelect<#(#generic_idents),*> (#(Option<#generic_idents>),*); |
23 | |
24 | enum StreamEnum<#(#generic_idents),*> { |
25 | #( |
26 | #generic_idents(#generic_idents) |
27 | ),*, |
28 | None, |
29 | } |
30 | |
31 | impl<ITEM, #(#generic_idents),*> __futures_crate::stream::Stream for StreamEnum<#(#generic_idents),*> |
32 | where #(#generic_idents: __futures_crate::stream::Stream<Item=ITEM> + ::std::marker::Unpin,)* |
33 | { |
34 | type Item = ITEM; |
35 | |
36 | fn poll_next(mut self: ::std::pin::Pin<&mut Self>, cx: &mut __futures_crate::task::Context<'_>) -> __futures_crate::task::Poll<Option<Self::Item>> { |
37 | match self.get_mut() { |
38 | #( |
39 | Self::#generic_idents(#generic_idents) => ::std::pin::Pin::new(#generic_idents).poll_next(cx) |
40 | ),*, |
41 | Self::None => panic!("StreamEnum::None should never be polled!" ), |
42 | } |
43 | } |
44 | } |
45 | |
46 | impl<ITEM, #(#generic_idents),*> __futures_crate::stream::Stream for StreamSelect<#(#generic_idents),*> |
47 | where #(#generic_idents: __futures_crate::stream::Stream<Item=ITEM> + ::std::marker::Unpin,)* |
48 | { |
49 | type Item = ITEM; |
50 | |
51 | fn poll_next(mut self: ::std::pin::Pin<&mut Self>, cx: &mut __futures_crate::task::Context<'_>) -> __futures_crate::task::Poll<Option<Self::Item>> { |
52 | let Self(#(ref mut #field_idents),*) = self.get_mut(); |
53 | #( |
54 | let mut #field_idents_2 = false; |
55 | )* |
56 | let mut any_pending = false; |
57 | { |
58 | let mut stream_array = [#(#field_idents.as_mut().map(|f| StreamEnum::#generic_idents(f)).unwrap_or(StreamEnum::None)),*]; |
59 | __futures_crate::async_await::shuffle(&mut stream_array); |
60 | |
61 | for mut s in stream_array { |
62 | if let StreamEnum::None = s { |
63 | continue; |
64 | } else { |
65 | match __futures_crate::stream::Stream::poll_next(::std::pin::Pin::new(&mut s), cx) { |
66 | r @ __futures_crate::task::Poll::Ready(Some(_)) => { |
67 | return r; |
68 | }, |
69 | __futures_crate::task::Poll::Pending => { |
70 | any_pending = true; |
71 | }, |
72 | __futures_crate::task::Poll::Ready(None) => { |
73 | match s { |
74 | #( |
75 | StreamEnum::#generic_idents(_) => { #field_idents_2 = true; } |
76 | ),*, |
77 | StreamEnum::None => panic!("StreamEnum::None should never be polled!" ), |
78 | } |
79 | }, |
80 | } |
81 | } |
82 | } |
83 | } |
84 | #( |
85 | if #field_idents_2 { |
86 | *#field_idents = None; |
87 | } |
88 | )* |
89 | if any_pending { |
90 | __futures_crate::task::Poll::Pending |
91 | } else { |
92 | __futures_crate::task::Poll::Ready(None) |
93 | } |
94 | } |
95 | |
96 | fn size_hint(&self) -> (usize, Option<usize>) { |
97 | let mut s = (0, Some(0)); |
98 | #( |
99 | if let Some(new_hint) = self.#field_indices.as_ref().map(|s| s.size_hint()) { |
100 | s.0 += new_hint.0; |
101 | // We can change this out for `.zip` when the MSRV is 1.46.0 or higher. |
102 | s.1 = s.1.and_then(|a| new_hint.1.map(|b| a + b)); |
103 | } |
104 | )* |
105 | s |
106 | } |
107 | } |
108 | |
109 | StreamSelect(#(Some(#args)),*) |
110 | |
111 | } |
112 | }) |
113 | } |
114 | |