| 1 | use proc_macro::TokenStream; |
| 2 | use proc_macro2::{Group, TokenStream as TokenStream2, TokenTree}; |
| 3 | use quote::quote; |
| 4 | use syn::parse::{Parse, ParseStream, Parser, Result}; |
| 5 | use syn::visit_mut::VisitMut; |
| 6 | |
| 7 | struct Scrub<'a> { |
| 8 | /// Whether the stream is a try stream. |
| 9 | is_try: bool, |
| 10 | /// The unit expression, `()`. |
| 11 | unit: Box<syn::Expr>, |
| 12 | has_yielded: bool, |
| 13 | crate_path: &'a TokenStream2, |
| 14 | } |
| 15 | |
| 16 | fn parse_input(input: TokenStream) -> syn::Result<(TokenStream2, Vec<syn::Stmt>)> { |
| 17 | let mut input = TokenStream2::from(input).into_iter(); |
| 18 | let crate_path = match input.next().unwrap() { |
| 19 | TokenTree::Group(group) => group.stream(), |
| 20 | _ => panic!(), |
| 21 | }; |
| 22 | let stmts = syn::Block::parse_within.parse2(replace_for_await(input))?; |
| 23 | Ok((crate_path, stmts)) |
| 24 | } |
| 25 | |
| 26 | impl<'a> Scrub<'a> { |
| 27 | fn new(is_try: bool, crate_path: &'a TokenStream2) -> Self { |
| 28 | Self { |
| 29 | is_try, |
| 30 | unit: syn::parse_quote!(()), |
| 31 | has_yielded: false, |
| 32 | crate_path, |
| 33 | } |
| 34 | } |
| 35 | } |
| 36 | |
| 37 | struct Partial<T>(T, TokenStream2); |
| 38 | |
| 39 | impl<T: Parse> Parse for Partial<T> { |
| 40 | fn parse(input: ParseStream) -> Result<Self> { |
| 41 | Ok(Partial(input.parse()?, input.parse()?)) |
| 42 | } |
| 43 | } |
| 44 | |
| 45 | fn visit_token_stream_impl( |
| 46 | visitor: &mut Scrub<'_>, |
| 47 | tokens: TokenStream2, |
| 48 | modified: &mut bool, |
| 49 | out: &mut TokenStream2, |
| 50 | ) { |
| 51 | use quote::ToTokens; |
| 52 | use quote::TokenStreamExt; |
| 53 | |
| 54 | let mut tokens = tokens.into_iter().peekable(); |
| 55 | while let Some(tt) = tokens.next() { |
| 56 | match tt { |
| 57 | TokenTree::Ident(i) if i == "yield" => { |
| 58 | let stream = std::iter::once(TokenTree::Ident(i)).chain(tokens).collect(); |
| 59 | match syn::parse2(stream) { |
| 60 | Ok(Partial(yield_expr, rest)) => { |
| 61 | let mut expr = syn::Expr::Yield(yield_expr); |
| 62 | visitor.visit_expr_mut(&mut expr); |
| 63 | expr.to_tokens(out); |
| 64 | *modified = true; |
| 65 | tokens = rest.into_iter().peekable(); |
| 66 | } |
| 67 | Err(e) => { |
| 68 | out.append_all(&mut e.to_compile_error().into_iter()); |
| 69 | *modified = true; |
| 70 | return; |
| 71 | } |
| 72 | } |
| 73 | } |
| 74 | TokenTree::Ident(i) if i == "stream" || i == "try_stream" => { |
| 75 | out.append(TokenTree::Ident(i)); |
| 76 | match tokens.peek() { |
| 77 | Some(TokenTree::Punct(p)) if p.as_char() == '!' => { |
| 78 | out.extend(tokens.next()); // ! |
| 79 | if let Some(TokenTree::Group(_)) = tokens.peek() { |
| 80 | out.extend(tokens.next()); // { .. } or [ .. ] or ( .. ) |
| 81 | } |
| 82 | } |
| 83 | _ => {} |
| 84 | } |
| 85 | } |
| 86 | TokenTree::Group(group) => { |
| 87 | let mut content = group.stream(); |
| 88 | *modified |= visitor.visit_token_stream(&mut content); |
| 89 | let mut new = Group::new(group.delimiter(), content); |
| 90 | new.set_span(group.span()); |
| 91 | out.append(new); |
| 92 | } |
| 93 | other => out.append(other), |
| 94 | } |
| 95 | } |
| 96 | } |
| 97 | |
| 98 | impl Scrub<'_> { |
| 99 | fn visit_token_stream(&mut self, tokens: &mut TokenStream2) -> bool { |
| 100 | let (mut out, mut modified) = (TokenStream2::new(), false); |
| 101 | visit_token_stream_impl(self, tokens.clone(), &mut modified, &mut out); |
| 102 | |
| 103 | if modified { |
| 104 | *tokens = out; |
| 105 | } |
| 106 | |
| 107 | modified |
| 108 | } |
| 109 | } |
| 110 | |
| 111 | impl VisitMut for Scrub<'_> { |
| 112 | fn visit_expr_mut(&mut self, i: &mut syn::Expr) { |
| 113 | match i { |
| 114 | syn::Expr::Yield(yield_expr) => { |
| 115 | self.has_yielded = true; |
| 116 | |
| 117 | syn::visit_mut::visit_expr_yield_mut(self, yield_expr); |
| 118 | |
| 119 | let value_expr = yield_expr.expr.as_ref().unwrap_or(&self.unit); |
| 120 | |
| 121 | // let ident = &self.yielder; |
| 122 | |
| 123 | *i = if self.is_try { |
| 124 | syn::parse_quote! { __yield_tx.send(::core::result::Result::Ok(#value_expr)).await } |
| 125 | } else { |
| 126 | syn::parse_quote! { __yield_tx.send(#value_expr).await } |
| 127 | }; |
| 128 | } |
| 129 | syn::Expr::Try(try_expr) => { |
| 130 | syn::visit_mut::visit_expr_try_mut(self, try_expr); |
| 131 | // let ident = &self.yielder; |
| 132 | let e = &try_expr.expr; |
| 133 | |
| 134 | *i = syn::parse_quote! { |
| 135 | match #e { |
| 136 | ::core::result::Result::Ok(v) => v, |
| 137 | ::core::result::Result::Err(e) => { |
| 138 | __yield_tx.send(::core::result::Result::Err(e.into())).await; |
| 139 | return; |
| 140 | } |
| 141 | } |
| 142 | }; |
| 143 | } |
| 144 | syn::Expr::Closure(_) | syn::Expr::Async(_) => { |
| 145 | // Don't transform inner closures or async blocks. |
| 146 | } |
| 147 | syn::Expr::ForLoop(expr) => { |
| 148 | syn::visit_mut::visit_expr_for_loop_mut(self, expr); |
| 149 | // TODO: Should we allow other attributes? |
| 150 | if expr.attrs.len() != 1 || !expr.attrs[0].meta.path().is_ident(AWAIT_ATTR_NAME) { |
| 151 | return; |
| 152 | } |
| 153 | let syn::ExprForLoop { |
| 154 | attrs, |
| 155 | label, |
| 156 | pat, |
| 157 | expr, |
| 158 | body, |
| 159 | .. |
| 160 | } = expr; |
| 161 | |
| 162 | attrs.pop().unwrap(); |
| 163 | |
| 164 | let crate_path = self.crate_path; |
| 165 | *i = syn::parse_quote! {{ |
| 166 | let mut __pinned = #expr; |
| 167 | let mut __pinned = unsafe { |
| 168 | ::core::pin::Pin::new_unchecked(&mut __pinned) |
| 169 | }; |
| 170 | #label |
| 171 | loop { |
| 172 | let #pat = match #crate_path::__private::next(&mut __pinned).await { |
| 173 | ::core::option::Option::Some(e) => e, |
| 174 | ::core::option::Option::None => break, |
| 175 | }; |
| 176 | #body |
| 177 | } |
| 178 | }} |
| 179 | } |
| 180 | _ => syn::visit_mut::visit_expr_mut(self, i), |
| 181 | } |
| 182 | } |
| 183 | |
| 184 | fn visit_macro_mut(&mut self, mac: &mut syn::Macro) { |
| 185 | let mac_ident = mac.path.segments.last().map(|p| &p.ident); |
| 186 | if mac_ident.map_or(false, |i| i == "stream" || i == "try_stream" ) { |
| 187 | return; |
| 188 | } |
| 189 | |
| 190 | self.visit_token_stream(&mut mac.tokens); |
| 191 | } |
| 192 | |
| 193 | fn visit_item_mut(&mut self, i: &mut syn::Item) { |
| 194 | // Recurse into macros but otherwise don't transform inner items. |
| 195 | if let syn::Item::Macro(i) = i { |
| 196 | self.visit_macro_mut(&mut i.mac); |
| 197 | } |
| 198 | } |
| 199 | } |
| 200 | |
| 201 | /// The first token tree in the stream must be a group containing the path to the `async-stream` |
| 202 | /// crate. |
| 203 | #[proc_macro ] |
| 204 | #[doc (hidden)] |
| 205 | pub fn stream_inner(input: TokenStream) -> TokenStream { |
| 206 | let (crate_path, mut stmts) = match parse_input(input) { |
| 207 | Ok(x) => x, |
| 208 | Err(e) => return e.to_compile_error().into(), |
| 209 | }; |
| 210 | |
| 211 | let mut scrub = Scrub::new(false, &crate_path); |
| 212 | |
| 213 | for stmt in &mut stmts { |
| 214 | scrub.visit_stmt_mut(stmt); |
| 215 | } |
| 216 | |
| 217 | let dummy_yield = if scrub.has_yielded { |
| 218 | None |
| 219 | } else { |
| 220 | Some(quote!(if false { |
| 221 | __yield_tx.send(()).await; |
| 222 | })) |
| 223 | }; |
| 224 | |
| 225 | quote!({ |
| 226 | let (mut __yield_tx, __yield_rx) = unsafe { #crate_path::__private::yielder::pair() }; |
| 227 | #crate_path::__private::AsyncStream::new(__yield_rx, async move { |
| 228 | #dummy_yield |
| 229 | #(#stmts)* |
| 230 | }) |
| 231 | }) |
| 232 | .into() |
| 233 | } |
| 234 | |
| 235 | /// The first token tree in the stream must be a group containing the path to the `async-stream` |
| 236 | /// crate. |
| 237 | #[proc_macro ] |
| 238 | #[doc (hidden)] |
| 239 | pub fn try_stream_inner(input: TokenStream) -> TokenStream { |
| 240 | let (crate_path, mut stmts) = match parse_input(input) { |
| 241 | Ok(x) => x, |
| 242 | Err(e) => return e.to_compile_error().into(), |
| 243 | }; |
| 244 | |
| 245 | let mut scrub = Scrub::new(true, &crate_path); |
| 246 | |
| 247 | for stmt in &mut stmts { |
| 248 | scrub.visit_stmt_mut(stmt); |
| 249 | } |
| 250 | |
| 251 | let dummy_yield = if scrub.has_yielded { |
| 252 | None |
| 253 | } else { |
| 254 | Some(quote!(if false { |
| 255 | __yield_tx.send(()).await; |
| 256 | })) |
| 257 | }; |
| 258 | |
| 259 | quote!({ |
| 260 | let (mut __yield_tx, __yield_rx) = unsafe { #crate_path::__private::yielder::pair() }; |
| 261 | #crate_path::__private::AsyncStream::new(__yield_rx, async move { |
| 262 | #dummy_yield |
| 263 | #(#stmts)* |
| 264 | }) |
| 265 | }) |
| 266 | .into() |
| 267 | } |
| 268 | |
/// Name of the marker attribute that `replace_for_await` places on `for await` loops.
///
/// syn 2.0 won't parse `#[await] for x in xs {}` because `await` is a keyword, so `await_`
/// is used instead.
const AWAIT_ATTR_NAME: &str = "await_";
| 272 | |
| 273 | /// Replace `for await` with `#[await] for`, which will be later transformed into a `next` loop. |
| 274 | fn replace_for_await(input: impl IntoIterator<Item = TokenTree>) -> TokenStream2 { |
| 275 | let mut input = input.into_iter().peekable(); |
| 276 | let mut tokens = Vec::new(); |
| 277 | |
| 278 | while let Some(token) = input.next() { |
| 279 | match token { |
| 280 | TokenTree::Ident(ident) => { |
| 281 | match input.peek() { |
| 282 | Some(TokenTree::Ident(next)) if ident == "for" && next == "await" => { |
| 283 | let next_span = next.span(); |
| 284 | let next = syn::Ident::new(AWAIT_ATTR_NAME, next_span); |
| 285 | tokens.extend(quote!(#[#next])); |
| 286 | let _ = input.next(); |
| 287 | } |
| 288 | _ => {} |
| 289 | } |
| 290 | tokens.push(ident.into()); |
| 291 | } |
| 292 | TokenTree::Group(group) => { |
| 293 | let stream = replace_for_await(group.stream()); |
| 294 | let mut new_group = Group::new(group.delimiter(), stream); |
| 295 | new_group.set_span(group.span()); |
| 296 | tokens.push(new_group.into()); |
| 297 | } |
| 298 | _ => tokens.push(token), |
| 299 | } |
| 300 | } |
| 301 | |
| 302 | tokens.into_iter().collect() |
| 303 | } |
| 304 | |