1 | use proc_macro::TokenStream; |
2 | use proc_macro2::{Group, TokenStream as TokenStream2, TokenTree}; |
3 | use quote::quote; |
4 | use syn::parse::{Parse, ParseStream, Parser, Result}; |
5 | use syn::visit_mut::VisitMut; |
6 | |
7 | struct Scrub<'a> { |
8 | /// Whether the stream is a try stream. |
9 | is_try: bool, |
10 | /// The unit expression, `()`. |
11 | unit: Box<syn::Expr>, |
12 | has_yielded: bool, |
13 | crate_path: &'a TokenStream2, |
14 | } |
15 | |
16 | fn parse_input(input: TokenStream) -> syn::Result<(TokenStream2, Vec<syn::Stmt>)> { |
17 | let mut input = TokenStream2::from(input).into_iter(); |
18 | let crate_path = match input.next().unwrap() { |
19 | TokenTree::Group(group) => group.stream(), |
20 | _ => panic!(), |
21 | }; |
22 | let stmts = syn::Block::parse_within.parse2(replace_for_await(input))?; |
23 | Ok((crate_path, stmts)) |
24 | } |
25 | |
26 | impl<'a> Scrub<'a> { |
27 | fn new(is_try: bool, crate_path: &'a TokenStream2) -> Self { |
28 | Self { |
29 | is_try, |
30 | unit: syn::parse_quote!(()), |
31 | has_yielded: false, |
32 | crate_path, |
33 | } |
34 | } |
35 | } |
36 | |
37 | struct Partial<T>(T, TokenStream2); |
38 | |
39 | impl<T: Parse> Parse for Partial<T> { |
40 | fn parse(input: ParseStream) -> Result<Self> { |
41 | Ok(Partial(input.parse()?, input.parse()?)) |
42 | } |
43 | } |
44 | |
/// Rewrites `yield` expressions inside a raw, unparsed token stream (for example the arguments
/// of a macro invocation) and appends the result to `out`.
///
/// * Each `yield` plus the tokens after it are parsed as a `syn::ExprYield`. The expression is
///   run through `visitor.visit_expr_mut` and emitted in rewritten form, and scanning resumes
///   after the parsed expression.
/// * Delimited groups are processed recursively through `Scrub::visit_token_stream`.
/// * `stream` / `try_stream` followed by `!` (and a group, if present) are copied through
///   verbatim, so `yield`s inside nested stream macros are not rewritten here.
/// * Every other token is copied as-is.
///
/// `modified` is set to `true` if anything was rewritten or a parse error was emitted. It is
/// never reset to `false`.
fn visit_token_stream_impl(
    visitor: &mut Scrub<'_>,
    tokens: TokenStream2,
    modified: &mut bool,
    out: &mut TokenStream2,
) {
    use quote::ToTokens;
    use quote::TokenStreamExt;

    let mut tokens = tokens.into_iter().peekable();
    while let Some(tt) = tokens.next() {
        match tt {
            TokenTree::Ident(i) if i == "yield" => {
                // Re-attach the `yield` keyword to the unread tail and parse a yield expression
                // from the front; `Partial` hands back whatever follows it as `rest`.
                let stream = std::iter::once(TokenTree::Ident(i)).chain(tokens).collect();
                match syn::parse2(stream) {
                    Ok(Partial(yield_expr, rest)) => {
                        let mut expr = syn::Expr::Yield(yield_expr);
                        visitor.visit_expr_mut(&mut expr);
                        expr.to_tokens(out);
                        *modified = true;
                        // Continue scanning with the tokens that follow the yield expression.
                        tokens = rest.into_iter().peekable();
                    }
                    Err(e) => {
                        // Emit the parse error in place and stop. The remaining tokens were
                        // moved into `stream` above and are not copied to `out`.
                        out.append_all(&mut e.to_compile_error().into_iter());
                        *modified = true;
                        return;
                    }
                }
            }
            TokenTree::Ident(i) if i == "stream" || i == "try_stream" => {
                // Copy a nested `stream!` / `try_stream!` invocation through untouched.
                out.append(TokenTree::Ident(i));
                match tokens.peek() {
                    Some(TokenTree::Punct(p)) if p.as_char() == '!' => {
                        out.extend(tokens.next()); // !
                        if let Some(TokenTree::Group(_)) = tokens.peek() {
                            out.extend(tokens.next()); // { .. } or [ .. ] or ( .. )
                        }
                    }
                    _ => {}
                }
            }
            TokenTree::Group(group) => {
                // Recurse into the group and rebuild it with the original delimiter and span.
                let mut content = group.stream();
                *modified |= visitor.visit_token_stream(&mut content);
                let mut new = Group::new(group.delimiter(), content);
                new.set_span(group.span());
                out.append(new);
            }
            other => out.append(other),
        }
    }
}
97 | |
98 | impl Scrub<'_> { |
99 | fn visit_token_stream(&mut self, tokens: &mut TokenStream2) -> bool { |
100 | let (mut out, mut modified) = (TokenStream2::new(), false); |
101 | visit_token_stream_impl(self, tokens.clone(), &mut modified, &mut out); |
102 | |
103 | if modified { |
104 | *tokens = out; |
105 | } |
106 | |
107 | modified |
108 | } |
109 | } |
110 | |
/// Syntax-tree rewrites applied to each statement of a `stream!` / `try_stream!` body.
///
/// * `yield expr` becomes `__yield_tx.send(expr).await`, wrapped in `Ok(..)` when `is_try`.
///   A bare `yield` sends `()`.
/// * `expr?` becomes a `match` that sends `Err(e.into())` and then `return`s.
/// * A `for` loop whose only attribute is the `#[await_]` marker becomes a `loop` that pins the
///   iterable and pulls items with `#crate_path::__private::next(..).await`.
/// * Closures, `async` blocks, nested `stream!` / `try_stream!` invocations and nested items
///   are not rewritten. Item-position macro invocations are the exception: they are still
///   scanned.
impl VisitMut for Scrub<'_> {
    fn visit_expr_mut(&mut self, i: &mut syn::Expr) {
        match i {
            syn::Expr::Yield(yield_expr) => {
                self.has_yielded = true;

                // Rewrite the yielded expression itself first (e.g. a `?` inside it).
                syn::visit_mut::visit_expr_yield_mut(self, yield_expr);

                // A bare `yield` sends the unit value.
                let value_expr = yield_expr.expr.as_ref().unwrap_or(&self.unit);

                *i = if self.is_try {
                    syn::parse_quote! { __yield_tx.send(::core::result::Result::Ok(#value_expr)).await }
                } else {
                    syn::parse_quote! { __yield_tx.send(#value_expr).await }
                };
            }
            syn::Expr::Try(try_expr) => {
                syn::visit_mut::visit_expr_try_mut(self, try_expr);
                let e = &try_expr.expr;

                // `expr?` forwards the error as a stream item and then ends the body.
                // This rewrite is applied regardless of `is_try`.
                *i = syn::parse_quote! {
                    match #e {
                        ::core::result::Result::Ok(v) => v,
                        ::core::result::Result::Err(e) => {
                            __yield_tx.send(::core::result::Result::Err(e.into())).await;
                            return;
                        }
                    }
                };
            }
            syn::Expr::Closure(_) | syn::Expr::Async(_) => {
                // Don't transform inner closures or async blocks.
            }
            syn::Expr::ForLoop(expr) => {
                // Rewrite the loop's parts (including its body) first.
                syn::visit_mut::visit_expr_for_loop_mut(self, expr);
                // Only loops whose sole attribute is the `#[await_]` marker are `for await`
                // loops. Every other `for` loop is left as is.
                // TODO: Should we allow other attributes?
                if expr.attrs.len() != 1 || !expr.attrs[0].meta.path().is_ident(AWAIT_ATTR_NAME) {
                    return;
                }
                let syn::ExprForLoop {
                    attrs,
                    label,
                    pat,
                    expr,
                    body,
                    ..
                } = expr;

                // NOTE(review): `attrs` is not used by the replacement below, so popping the
                // marker has no effect on the generated code.
                attrs.pop().unwrap();

                let crate_path = self.crate_path;
                *i = syn::parse_quote! {{
                    let mut __pinned = #expr;
                    // Generated-code safety: the pinned reference immediately shadows
                    // `__pinned`, so the iterable can no longer be named, moved, or used
                    // unpinned.
                    let mut __pinned = unsafe {
                        ::core::pin::Pin::new_unchecked(&mut __pinned)
                    };
                    #label
                    loop {
                        let #pat = match #crate_path::__private::next(&mut __pinned).await {
                            ::core::option::Option::Some(e) => e,
                            ::core::option::Option::None => break,
                        };
                        #body
                    }
                }}
            }
            _ => syn::visit_mut::visit_expr_mut(self, i),
        }
    }

    /// Scans the raw tokens of a macro invocation for `yield`.
    ///
    /// Nested `stream!` / `try_stream!` invocations are skipped. They are matched on the last
    /// path segment, so qualified paths are caught as well.
    fn visit_macro_mut(&mut self, mac: &mut syn::Macro) {
        let mac_ident = mac.path.segments.last().map(|p| &p.ident);
        if mac_ident.map_or(false, |i| i == "stream" || i == "try_stream") {
            return;
        }

        self.visit_token_stream(&mut mac.tokens);
    }

    /// Does not descend into nested items, except that item-position macro invocations are
    /// still scanned.
    fn visit_item_mut(&mut self, i: &mut syn::Item) {
        // Recurse into macros but otherwise don't transform inner items.
        if let syn::Item::Macro(i) = i {
            self.visit_macro_mut(&mut i.mac);
        }
    }
}
200 | |
201 | /// The first token tree in the stream must be a group containing the path to the `async-stream` |
202 | /// crate. |
203 | #[proc_macro ] |
204 | #[doc (hidden)] |
205 | pub fn stream_inner(input: TokenStream) -> TokenStream { |
206 | let (crate_path, mut stmts) = match parse_input(input) { |
207 | Ok(x) => x, |
208 | Err(e) => return e.to_compile_error().into(), |
209 | }; |
210 | |
211 | let mut scrub = Scrub::new(false, &crate_path); |
212 | |
213 | for stmt in &mut stmts { |
214 | scrub.visit_stmt_mut(stmt); |
215 | } |
216 | |
217 | let dummy_yield = if scrub.has_yielded { |
218 | None |
219 | } else { |
220 | Some(quote!(if false { |
221 | __yield_tx.send(()).await; |
222 | })) |
223 | }; |
224 | |
225 | quote!({ |
226 | let (mut __yield_tx, __yield_rx) = unsafe { #crate_path::__private::yielder::pair() }; |
227 | #crate_path::__private::AsyncStream::new(__yield_rx, async move { |
228 | #dummy_yield |
229 | #(#stmts)* |
230 | }) |
231 | }) |
232 | .into() |
233 | } |
234 | |
235 | /// The first token tree in the stream must be a group containing the path to the `async-stream` |
236 | /// crate. |
237 | #[proc_macro ] |
238 | #[doc (hidden)] |
239 | pub fn try_stream_inner(input: TokenStream) -> TokenStream { |
240 | let (crate_path, mut stmts) = match parse_input(input) { |
241 | Ok(x) => x, |
242 | Err(e) => return e.to_compile_error().into(), |
243 | }; |
244 | |
245 | let mut scrub = Scrub::new(true, &crate_path); |
246 | |
247 | for stmt in &mut stmts { |
248 | scrub.visit_stmt_mut(stmt); |
249 | } |
250 | |
251 | let dummy_yield = if scrub.has_yielded { |
252 | None |
253 | } else { |
254 | Some(quote!(if false { |
255 | __yield_tx.send(()).await; |
256 | })) |
257 | }; |
258 | |
259 | quote!({ |
260 | let (mut __yield_tx, __yield_rx) = unsafe { #crate_path::__private::yielder::pair() }; |
261 | #crate_path::__private::AsyncStream::new(__yield_rx, async move { |
262 | #dummy_yield |
263 | #(#stmts)* |
264 | }) |
265 | }) |
266 | .into() |
267 | } |
268 | |
/// Name of the marker attribute that stands in for `await` in `for await` loops.
///
/// syn 2.0 won't parse `#[await] for x in xs {}` because `await` is a keyword, so `await_` is
/// used instead. `replace_for_await` inserts `#[await_]`, and the `Scrub` visitor looks for
/// this name.
const AWAIT_ATTR_NAME: &str = "await_";
272 | |
273 | /// Replace `for await` with `#[await] for`, which will be later transformed into a `next` loop. |
274 | fn replace_for_await(input: impl IntoIterator<Item = TokenTree>) -> TokenStream2 { |
275 | let mut input = input.into_iter().peekable(); |
276 | let mut tokens = Vec::new(); |
277 | |
278 | while let Some(token) = input.next() { |
279 | match token { |
280 | TokenTree::Ident(ident) => { |
281 | match input.peek() { |
282 | Some(TokenTree::Ident(next)) if ident == "for" && next == "await" => { |
283 | let next_span = next.span(); |
284 | let next = syn::Ident::new(AWAIT_ATTR_NAME, next_span); |
285 | tokens.extend(quote!(#[#next])); |
286 | let _ = input.next(); |
287 | } |
288 | _ => {} |
289 | } |
290 | tokens.push(ident.into()); |
291 | } |
292 | TokenTree::Group(group) => { |
293 | let stream = replace_for_await(group.stream()); |
294 | let mut new_group = Group::new(group.delimiter(), stream); |
295 | new_group.set_span(group.span()); |
296 | tokens.push(new_group.into()); |
297 | } |
298 | _ => tokens.push(token), |
299 | } |
300 | } |
301 | |
302 | tokens.into_iter().collect() |
303 | } |
304 | |