1 | use proc_macro2::{Span, TokenStream}; |
2 | use quote::quote; |
3 | use syn::{Data, DeriveInput, Fields, Ident}; |
4 | |
5 | use crate::helpers::{non_enum_error, HasStrumVariantProperties, HasTypeProperties}; |
6 | |
7 | pub fn enum_iter_inner(ast: &DeriveInput) -> syn::Result<TokenStream> { |
8 | let name = &ast.ident; |
9 | let gen = &ast.generics; |
10 | let (impl_generics, ty_generics, where_clause) = gen.split_for_impl(); |
11 | let vis = &ast.vis; |
12 | let type_properties = ast.get_type_properties()?; |
13 | let strum_module_path = type_properties.crate_module_path(); |
14 | |
15 | if gen.lifetimes().count() > 0 { |
16 | return Err(syn::Error::new( |
17 | Span::call_site(), |
18 | "This macro doesn't support enums with lifetimes. \ |
19 | The resulting enums would be unbounded." , |
20 | )); |
21 | } |
22 | |
23 | let phantom_data = if gen.type_params().count() > 0 { |
24 | let g = gen.type_params().map(|param| ¶m.ident); |
25 | quote! { < ( #(#g),* ) > } |
26 | } else { |
27 | quote! { < () > } |
28 | }; |
29 | |
30 | let variants = match &ast.data { |
31 | Data::Enum(v) => &v.variants, |
32 | _ => return Err(non_enum_error()), |
33 | }; |
34 | |
35 | let mut arms = Vec::new(); |
36 | let mut idx = 0usize; |
37 | for variant in variants { |
38 | if variant.get_variant_properties()?.disabled.is_some() { |
39 | continue; |
40 | } |
41 | |
42 | let ident = &variant.ident; |
43 | let params = match &variant.fields { |
44 | Fields::Unit => quote! {}, |
45 | Fields::Unnamed(fields) => { |
46 | let defaults = ::core::iter::repeat(quote!(::core::default::Default::default())) |
47 | .take(fields.unnamed.len()); |
48 | quote! { (#(#defaults),*) } |
49 | } |
50 | Fields::Named(fields) => { |
51 | let fields = fields |
52 | .named |
53 | .iter() |
54 | .map(|field| field.ident.as_ref().unwrap()); |
55 | quote! { {#(#fields: ::core::default::Default::default()),*} } |
56 | } |
57 | }; |
58 | |
59 | arms.push(quote! {#idx => ::core::option::Option::Some(#name::#ident #params)}); |
60 | idx += 1; |
61 | } |
62 | |
63 | let variant_count = arms.len(); |
64 | arms.push(quote! { _ => ::core::option::Option::None }); |
65 | let iter_name = syn::parse_str::<Ident>(&format!(" {}Iter" , name)).unwrap(); |
66 | |
67 | Ok(quote! { |
68 | #[doc = "An iterator over the variants of [Self]" ] |
69 | #[allow( |
70 | missing_copy_implementations, |
71 | missing_debug_implementations, |
72 | )] |
73 | #vis struct #iter_name #ty_generics { |
74 | idx: usize, |
75 | back_idx: usize, |
76 | marker: ::core::marker::PhantomData #phantom_data, |
77 | } |
78 | |
79 | impl #impl_generics #iter_name #ty_generics #where_clause { |
80 | fn get(&self, idx: usize) -> Option<#name #ty_generics> { |
81 | match idx { |
82 | #(#arms),* |
83 | } |
84 | } |
85 | } |
86 | |
87 | impl #impl_generics #strum_module_path::IntoEnumIterator for #name #ty_generics #where_clause { |
88 | type Iterator = #iter_name #ty_generics; |
89 | fn iter() -> #iter_name #ty_generics { |
90 | #iter_name { |
91 | idx: 0, |
92 | back_idx: 0, |
93 | marker: ::core::marker::PhantomData, |
94 | } |
95 | } |
96 | } |
97 | |
98 | impl #impl_generics Iterator for #iter_name #ty_generics #where_clause { |
99 | type Item = #name #ty_generics; |
100 | |
101 | fn next(&mut self) -> Option<<Self as Iterator>::Item> { |
102 | self.nth(0) |
103 | } |
104 | |
105 | fn size_hint(&self) -> (usize, Option<usize>) { |
106 | let t = if self.idx + self.back_idx >= #variant_count { 0 } else { #variant_count - self.idx - self.back_idx }; |
107 | (t, Some(t)) |
108 | } |
109 | |
110 | fn nth(&mut self, n: usize) -> Option<<Self as Iterator>::Item> { |
111 | let idx = self.idx + n + 1; |
112 | if idx + self.back_idx > #variant_count { |
113 | // We went past the end of the iterator. Freeze idx at #variant_count |
114 | // so that it doesn't overflow if the user calls this repeatedly. |
115 | // See PR #76 for context. |
116 | self.idx = #variant_count; |
117 | ::core::option::Option::None |
118 | } else { |
119 | self.idx = idx; |
120 | self.get(idx - 1) |
121 | } |
122 | } |
123 | } |
124 | |
125 | impl #impl_generics ExactSizeIterator for #iter_name #ty_generics #where_clause { |
126 | fn len(&self) -> usize { |
127 | self.size_hint().0 |
128 | } |
129 | } |
130 | |
131 | impl #impl_generics DoubleEndedIterator for #iter_name #ty_generics #where_clause { |
132 | fn next_back(&mut self) -> Option<<Self as Iterator>::Item> { |
133 | let back_idx = self.back_idx + 1; |
134 | |
135 | if self.idx + back_idx > #variant_count { |
136 | // We went past the end of the iterator. Freeze back_idx at #variant_count |
137 | // so that it doesn't overflow if the user calls this repeatedly. |
138 | // See PR #76 for context. |
139 | self.back_idx = #variant_count; |
140 | ::core::option::Option::None |
141 | } else { |
142 | self.back_idx = back_idx; |
143 | self.get(#variant_count - self.back_idx) |
144 | } |
145 | } |
146 | } |
147 | |
148 | impl #impl_generics Clone for #iter_name #ty_generics #where_clause { |
149 | fn clone(&self) -> #iter_name #ty_generics { |
150 | #iter_name { |
151 | idx: self.idx, |
152 | back_idx: self.back_idx, |
153 | marker: self.marker.clone(), |
154 | } |
155 | } |
156 | } |
157 | }) |
158 | } |
159 | |