1use proc_macro::TokenStream;
2use proc_macro2::{Group, TokenStream as TokenStream2, TokenTree};
3use quote::quote;
4use syn::parse::{Parse, ParseStream, Parser, Result};
5use syn::visit_mut::VisitMut;
6
7struct Scrub<'a> {
8 /// Whether the stream is a try stream.
9 is_try: bool,
10 /// The unit expression, `()`.
11 unit: Box<syn::Expr>,
12 has_yielded: bool,
13 crate_path: &'a TokenStream2,
14}
15
16fn parse_input(input: TokenStream) -> syn::Result<(TokenStream2, Vec<syn::Stmt>)> {
17 let mut input = TokenStream2::from(input).into_iter();
18 let crate_path = match input.next().unwrap() {
19 TokenTree::Group(group) => group.stream(),
20 _ => panic!(),
21 };
22 let stmts = syn::Block::parse_within.parse2(replace_for_await(input))?;
23 Ok((crate_path, stmts))
24}
25
26impl<'a> Scrub<'a> {
27 fn new(is_try: bool, crate_path: &'a TokenStream2) -> Self {
28 Self {
29 is_try,
30 unit: syn::parse_quote!(()),
31 has_yielded: false,
32 crate_path,
33 }
34 }
35}
36
37struct Partial<T>(T, TokenStream2);
38
39impl<T: Parse> Parse for Partial<T> {
40 fn parse(input: ParseStream) -> Result<Self> {
41 Ok(Partial(input.parse()?, input.parse()?))
42 }
43}
44
45fn visit_token_stream_impl(
46 visitor: &mut Scrub<'_>,
47 tokens: TokenStream2,
48 modified: &mut bool,
49 out: &mut TokenStream2,
50) {
51 use quote::ToTokens;
52 use quote::TokenStreamExt;
53
54 let mut tokens = tokens.into_iter().peekable();
55 while let Some(tt) = tokens.next() {
56 match tt {
57 TokenTree::Ident(i) if i == "yield" => {
58 let stream = std::iter::once(TokenTree::Ident(i)).chain(tokens).collect();
59 match syn::parse2(stream) {
60 Ok(Partial(yield_expr, rest)) => {
61 let mut expr = syn::Expr::Yield(yield_expr);
62 visitor.visit_expr_mut(&mut expr);
63 expr.to_tokens(out);
64 *modified = true;
65 tokens = rest.into_iter().peekable();
66 }
67 Err(e) => {
68 out.append_all(&mut e.to_compile_error().into_iter());
69 *modified = true;
70 return;
71 }
72 }
73 }
74 TokenTree::Ident(i) if i == "stream" || i == "try_stream" => {
75 out.append(TokenTree::Ident(i));
76 match tokens.peek() {
77 Some(TokenTree::Punct(p)) if p.as_char() == '!' => {
78 out.extend(tokens.next()); // !
79 if let Some(TokenTree::Group(_)) = tokens.peek() {
80 out.extend(tokens.next()); // { .. } or [ .. ] or ( .. )
81 }
82 }
83 _ => {}
84 }
85 }
86 TokenTree::Group(group) => {
87 let mut content = group.stream();
88 *modified |= visitor.visit_token_stream(&mut content);
89 let mut new = Group::new(group.delimiter(), content);
90 new.set_span(group.span());
91 out.append(new);
92 }
93 other => out.append(other),
94 }
95 }
96}
97
98impl Scrub<'_> {
99 fn visit_token_stream(&mut self, tokens: &mut TokenStream2) -> bool {
100 let (mut out, mut modified) = (TokenStream2::new(), false);
101 visit_token_stream_impl(self, tokens.clone(), &mut modified, &mut out);
102
103 if modified {
104 *tokens = out;
105 }
106
107 modified
108 }
109}
110
111impl VisitMut for Scrub<'_> {
112 fn visit_expr_mut(&mut self, i: &mut syn::Expr) {
113 match i {
114 syn::Expr::Yield(yield_expr) => {
115 self.has_yielded = true;
116
117 syn::visit_mut::visit_expr_yield_mut(self, yield_expr);
118
119 let value_expr = yield_expr.expr.as_ref().unwrap_or(&self.unit);
120
121 // let ident = &self.yielder;
122
123 *i = if self.is_try {
124 syn::parse_quote! { __yield_tx.send(::core::result::Result::Ok(#value_expr)).await }
125 } else {
126 syn::parse_quote! { __yield_tx.send(#value_expr).await }
127 };
128 }
129 syn::Expr::Try(try_expr) => {
130 syn::visit_mut::visit_expr_try_mut(self, try_expr);
131 // let ident = &self.yielder;
132 let e = &try_expr.expr;
133
134 *i = syn::parse_quote! {
135 match #e {
136 ::core::result::Result::Ok(v) => v,
137 ::core::result::Result::Err(e) => {
138 __yield_tx.send(::core::result::Result::Err(e.into())).await;
139 return;
140 }
141 }
142 };
143 }
144 syn::Expr::Closure(_) | syn::Expr::Async(_) => {
145 // Don't transform inner closures or async blocks.
146 }
147 syn::Expr::ForLoop(expr) => {
148 syn::visit_mut::visit_expr_for_loop_mut(self, expr);
149 // TODO: Should we allow other attributes?
150 if expr.attrs.len() != 1 || !expr.attrs[0].meta.path().is_ident(AWAIT_ATTR_NAME) {
151 return;
152 }
153 let syn::ExprForLoop {
154 attrs,
155 label,
156 pat,
157 expr,
158 body,
159 ..
160 } = expr;
161
162 attrs.pop().unwrap();
163
164 let crate_path = self.crate_path;
165 *i = syn::parse_quote! {{
166 let mut __pinned = #expr;
167 let mut __pinned = unsafe {
168 ::core::pin::Pin::new_unchecked(&mut __pinned)
169 };
170 #label
171 loop {
172 let #pat = match #crate_path::__private::next(&mut __pinned).await {
173 ::core::option::Option::Some(e) => e,
174 ::core::option::Option::None => break,
175 };
176 #body
177 }
178 }}
179 }
180 _ => syn::visit_mut::visit_expr_mut(self, i),
181 }
182 }
183
184 fn visit_macro_mut(&mut self, mac: &mut syn::Macro) {
185 let mac_ident = mac.path.segments.last().map(|p| &p.ident);
186 if mac_ident.map_or(false, |i| i == "stream" || i == "try_stream") {
187 return;
188 }
189
190 self.visit_token_stream(&mut mac.tokens);
191 }
192
193 fn visit_item_mut(&mut self, i: &mut syn::Item) {
194 // Recurse into macros but otherwise don't transform inner items.
195 if let syn::Item::Macro(i) = i {
196 self.visit_macro_mut(&mut i.mac);
197 }
198 }
199}
200
201/// The first token tree in the stream must be a group containing the path to the `async-stream`
202/// crate.
203#[proc_macro]
204#[doc(hidden)]
205pub fn stream_inner(input: TokenStream) -> TokenStream {
206 let (crate_path, mut stmts) = match parse_input(input) {
207 Ok(x) => x,
208 Err(e) => return e.to_compile_error().into(),
209 };
210
211 let mut scrub = Scrub::new(false, &crate_path);
212
213 for stmt in &mut stmts {
214 scrub.visit_stmt_mut(stmt);
215 }
216
217 let dummy_yield = if scrub.has_yielded {
218 None
219 } else {
220 Some(quote!(if false {
221 __yield_tx.send(()).await;
222 }))
223 };
224
225 quote!({
226 let (mut __yield_tx, __yield_rx) = unsafe { #crate_path::__private::yielder::pair() };
227 #crate_path::__private::AsyncStream::new(__yield_rx, async move {
228 #dummy_yield
229 #(#stmts)*
230 })
231 })
232 .into()
233}
234
235/// The first token tree in the stream must be a group containing the path to the `async-stream`
236/// crate.
237#[proc_macro]
238#[doc(hidden)]
239pub fn try_stream_inner(input: TokenStream) -> TokenStream {
240 let (crate_path, mut stmts) = match parse_input(input) {
241 Ok(x) => x,
242 Err(e) => return e.to_compile_error().into(),
243 };
244
245 let mut scrub = Scrub::new(true, &crate_path);
246
247 for stmt in &mut stmts {
248 scrub.visit_stmt_mut(stmt);
249 }
250
251 let dummy_yield = if scrub.has_yielded {
252 None
253 } else {
254 Some(quote!(if false {
255 __yield_tx.send(()).await;
256 }))
257 };
258
259 quote!({
260 let (mut __yield_tx, __yield_rx) = unsafe { #crate_path::__private::yielder::pair() };
261 #crate_path::__private::AsyncStream::new(__yield_rx, async move {
262 #dummy_yield
263 #(#stmts)*
264 })
265 })
266 .into()
267}
268
/// Name of the marker attribute that `replace_for_await` puts on `for await` loops.
///
/// syn 2.0 won't parse `#[await] for x in xs {}` because `await` is a keyword,
/// so the marker is spelled `await_` instead.
const AWAIT_ATTR_NAME: &str = "await_";
272
273/// Replace `for await` with `#[await] for`, which will be later transformed into a `next` loop.
274fn replace_for_await(input: impl IntoIterator<Item = TokenTree>) -> TokenStream2 {
275 let mut input = input.into_iter().peekable();
276 let mut tokens = Vec::new();
277
278 while let Some(token) = input.next() {
279 match token {
280 TokenTree::Ident(ident) => {
281 match input.peek() {
282 Some(TokenTree::Ident(next)) if ident == "for" && next == "await" => {
283 let next_span = next.span();
284 let next = syn::Ident::new(AWAIT_ATTR_NAME, next_span);
285 tokens.extend(quote!(#[#next]));
286 let _ = input.next();
287 }
288 _ => {}
289 }
290 tokens.push(ident.into());
291 }
292 TokenTree::Group(group) => {
293 let stream = replace_for_await(group.stream());
294 let mut new_group = Group::new(group.delimiter(), stream);
295 new_group.set_span(group.span());
296 tokens.push(new_group.into());
297 }
298 _ => tokens.push(token),
299 }
300 }
301
302 tokens.into_iter().collect()
303}
304