1use proc_macro2::TokenStream;
2use quote::{format_ident, quote, ToTokens};
3use syn::{parse::Parser, punctuated::Punctuated, Expr, Index, Token};
4
5/// The `stream_select!` macro.
6pub(crate) fn stream_select(input: TokenStream) -> Result<TokenStream, syn::Error> {
7 let args = Punctuated::<Expr, Token![,]>::parse_terminated.parse2(input)?;
8 if args.len() < 2 {
9 return Ok(quote! {
10 compile_error!("stream select macro needs at least two arguments.")
11 });
12 }
13 let generic_idents = (0..args.len()).map(|i| format_ident!("_{}", i)).collect::<Vec<_>>();
14 let field_idents = (0..args.len()).map(|i| format_ident!("__{}", i)).collect::<Vec<_>>();
15 let field_idents_2 = (0..args.len()).map(|i| format_ident!("___{}", i)).collect::<Vec<_>>();
16 let field_indices = (0..args.len()).map(Index::from).collect::<Vec<_>>();
17 let args = args.iter().map(|e| e.to_token_stream());
18
19 Ok(quote! {
20 {
21 #[derive(Debug)]
22 struct StreamSelect<#(#generic_idents),*> (#(Option<#generic_idents>),*);
23
24 enum StreamEnum<#(#generic_idents),*> {
25 #(
26 #generic_idents(#generic_idents)
27 ),*,
28 None,
29 }
30
31 impl<ITEM, #(#generic_idents),*> __futures_crate::stream::Stream for StreamEnum<#(#generic_idents),*>
32 where #(#generic_idents: __futures_crate::stream::Stream<Item=ITEM> + ::std::marker::Unpin,)*
33 {
34 type Item = ITEM;
35
36 fn poll_next(mut self: ::std::pin::Pin<&mut Self>, cx: &mut __futures_crate::task::Context<'_>) -> __futures_crate::task::Poll<Option<Self::Item>> {
37 match self.get_mut() {
38 #(
39 Self::#generic_idents(#generic_idents) => ::std::pin::Pin::new(#generic_idents).poll_next(cx)
40 ),*,
41 Self::None => panic!("StreamEnum::None should never be polled!"),
42 }
43 }
44 }
45
46 impl<ITEM, #(#generic_idents),*> __futures_crate::stream::Stream for StreamSelect<#(#generic_idents),*>
47 where #(#generic_idents: __futures_crate::stream::Stream<Item=ITEM> + ::std::marker::Unpin,)*
48 {
49 type Item = ITEM;
50
51 fn poll_next(mut self: ::std::pin::Pin<&mut Self>, cx: &mut __futures_crate::task::Context<'_>) -> __futures_crate::task::Poll<Option<Self::Item>> {
52 let Self(#(ref mut #field_idents),*) = self.get_mut();
53 #(
54 let mut #field_idents_2 = false;
55 )*
56 let mut any_pending = false;
57 {
58 let mut stream_array = [#(#field_idents.as_mut().map(|f| StreamEnum::#generic_idents(f)).unwrap_or(StreamEnum::None)),*];
59 __futures_crate::async_await::shuffle(&mut stream_array);
60
61 for mut s in stream_array {
62 if let StreamEnum::None = s {
63 continue;
64 } else {
65 match __futures_crate::stream::Stream::poll_next(::std::pin::Pin::new(&mut s), cx) {
66 r @ __futures_crate::task::Poll::Ready(Some(_)) => {
67 return r;
68 },
69 __futures_crate::task::Poll::Pending => {
70 any_pending = true;
71 },
72 __futures_crate::task::Poll::Ready(None) => {
73 match s {
74 #(
75 StreamEnum::#generic_idents(_) => { #field_idents_2 = true; }
76 ),*,
77 StreamEnum::None => panic!("StreamEnum::None should never be polled!"),
78 }
79 },
80 }
81 }
82 }
83 }
84 #(
85 if #field_idents_2 {
86 *#field_idents = None;
87 }
88 )*
89 if any_pending {
90 __futures_crate::task::Poll::Pending
91 } else {
92 __futures_crate::task::Poll::Ready(None)
93 }
94 }
95
96 fn size_hint(&self) -> (usize, Option<usize>) {
97 let mut s = (0, Some(0));
98 #(
99 if let Some(new_hint) = self.#field_indices.as_ref().map(|s| s.size_hint()) {
100 s.0 += new_hint.0;
101 // We can change this out for `.zip` when the MSRV is 1.46.0 or higher.
102 s.1 = s.1.and_then(|a| new_hint.1.map(|b| a + b));
103 }
104 )*
105 s
106 }
107 }
108
109 StreamSelect(#(Some(#args)),*)
110
111 }
112 })
113}
114