| 1 | use proc_macro2::{Span, TokenStream}; |
| 2 | use quote::quote; |
| 3 | use syn::{Data, DeriveInput, Fields, Ident}; |
| 4 | |
| 5 | use crate::helpers::{non_enum_error, HasStrumVariantProperties, HasTypeProperties}; |
| 6 | |
| 7 | pub fn enum_iter_inner(ast: &DeriveInput) -> syn::Result<TokenStream> { |
| 8 | let name = &ast.ident; |
| 9 | let gen = &ast.generics; |
| 10 | let (impl_generics, ty_generics, where_clause) = gen.split_for_impl(); |
| 11 | let vis = &ast.vis; |
| 12 | let type_properties = ast.get_type_properties()?; |
| 13 | let strum_module_path = type_properties.crate_module_path(); |
| 14 | let doc_comment = format!("An iterator over the variants of [ {}]" , name); |
| 15 | |
| 16 | if gen.lifetimes().count() > 0 { |
| 17 | return Err(syn::Error::new( |
| 18 | Span::call_site(), |
| 19 | "This macro doesn't support enums with lifetimes. \ |
| 20 | The resulting enums would be unbounded." , |
| 21 | )); |
| 22 | } |
| 23 | |
| 24 | let phantom_data = if gen.type_params().count() > 0 { |
| 25 | let g = gen.type_params().map(|param| ¶m.ident); |
| 26 | quote! { < fn() -> ( #(#g),* ) > } |
| 27 | } else { |
| 28 | quote! { < fn() -> () > } |
| 29 | }; |
| 30 | |
| 31 | let variants = match &ast.data { |
| 32 | Data::Enum(v) => &v.variants, |
| 33 | _ => return Err(non_enum_error()), |
| 34 | }; |
| 35 | |
| 36 | let mut arms = Vec::new(); |
| 37 | let mut idx = 0usize; |
| 38 | for variant in variants { |
| 39 | if variant.get_variant_properties()?.disabled.is_some() { |
| 40 | continue; |
| 41 | } |
| 42 | |
| 43 | let ident = &variant.ident; |
| 44 | let params = match &variant.fields { |
| 45 | Fields::Unit => quote! {}, |
| 46 | Fields::Unnamed(fields) => { |
| 47 | let defaults = ::core::iter::repeat(quote!(::core::default::Default::default())) |
| 48 | .take(fields.unnamed.len()); |
| 49 | quote! { (#(#defaults),*) } |
| 50 | } |
| 51 | Fields::Named(fields) => { |
| 52 | let fields = fields |
| 53 | .named |
| 54 | .iter() |
| 55 | .map(|field| field.ident.as_ref().unwrap()); |
| 56 | quote! { {#(#fields: ::core::default::Default::default()),*} } |
| 57 | } |
| 58 | }; |
| 59 | |
| 60 | arms.push(quote! {#idx => ::core::option::Option::Some(#name::#ident #params)}); |
| 61 | idx += 1; |
| 62 | } |
| 63 | |
| 64 | let variant_count = arms.len(); |
| 65 | arms.push(quote! { _ => ::core::option::Option::None }); |
| 66 | let iter_name = syn::parse_str::<Ident>(&format!(" {}Iter" , name)).unwrap(); |
| 67 | |
| 68 | // Create a string literal "MyEnumIter" to use in the debug impl. |
| 69 | let iter_name_debug_struct = |
| 70 | syn::parse_str::<syn::LitStr>(&format!(" \"{}\"" , iter_name)).unwrap(); |
| 71 | |
| 72 | Ok(quote! { |
| 73 | #[doc = #doc_comment] |
| 74 | #[allow( |
| 75 | missing_copy_implementations, |
| 76 | )] |
| 77 | #vis struct #iter_name #impl_generics { |
| 78 | idx: usize, |
| 79 | back_idx: usize, |
| 80 | marker: ::core::marker::PhantomData #phantom_data, |
| 81 | } |
| 82 | |
| 83 | impl #impl_generics ::core::fmt::Debug for #iter_name #ty_generics #where_clause { |
| 84 | fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result { |
| 85 | // We don't know if the variants implement debug themselves so the only thing we |
| 86 | // can really show is how many elements are left. |
| 87 | f.debug_struct(#iter_name_debug_struct) |
| 88 | .field("len" , &self.len()) |
| 89 | .finish() |
| 90 | } |
| 91 | } |
| 92 | |
| 93 | impl #impl_generics #iter_name #ty_generics #where_clause { |
| 94 | fn get(&self, idx: usize) -> ::core::option::Option<#name #ty_generics> { |
| 95 | match idx { |
| 96 | #(#arms),* |
| 97 | } |
| 98 | } |
| 99 | } |
| 100 | |
| 101 | impl #impl_generics #strum_module_path::IntoEnumIterator for #name #ty_generics #where_clause { |
| 102 | type Iterator = #iter_name #ty_generics; |
| 103 | |
| 104 | #[inline] |
| 105 | fn iter() -> #iter_name #ty_generics { |
| 106 | #iter_name { |
| 107 | idx: 0, |
| 108 | back_idx: 0, |
| 109 | marker: ::core::marker::PhantomData, |
| 110 | } |
| 111 | } |
| 112 | } |
| 113 | |
| 114 | impl #impl_generics Iterator for #iter_name #ty_generics #where_clause { |
| 115 | type Item = #name #ty_generics; |
| 116 | |
| 117 | #[inline] |
| 118 | fn next(&mut self) -> ::core::option::Option<<Self as Iterator>::Item> { |
| 119 | self.nth(0) |
| 120 | } |
| 121 | |
| 122 | #[inline] |
| 123 | fn size_hint(&self) -> (usize, ::core::option::Option<usize>) { |
| 124 | let t = if self.idx + self.back_idx >= #variant_count { 0 } else { #variant_count - self.idx - self.back_idx }; |
| 125 | (t, Some(t)) |
| 126 | } |
| 127 | |
| 128 | #[inline] |
| 129 | fn nth(&mut self, n: usize) -> ::core::option::Option<<Self as Iterator>::Item> { |
| 130 | let idx = self.idx + n + 1; |
| 131 | if idx + self.back_idx > #variant_count { |
| 132 | // We went past the end of the iterator. Freeze idx at #variant_count |
| 133 | // so that it doesn't overflow if the user calls this repeatedly. |
| 134 | // See PR #76 for context. |
| 135 | self.idx = #variant_count; |
| 136 | ::core::option::Option::None |
| 137 | } else { |
| 138 | self.idx = idx; |
| 139 | #iter_name::get(self, idx - 1) |
| 140 | } |
| 141 | } |
| 142 | } |
| 143 | |
| 144 | impl #impl_generics ExactSizeIterator for #iter_name #ty_generics #where_clause { |
| 145 | #[inline] |
| 146 | fn len(&self) -> usize { |
| 147 | self.size_hint().0 |
| 148 | } |
| 149 | } |
| 150 | |
| 151 | impl #impl_generics DoubleEndedIterator for #iter_name #ty_generics #where_clause { |
| 152 | #[inline] |
| 153 | fn next_back(&mut self) -> ::core::option::Option<<Self as Iterator>::Item> { |
| 154 | let back_idx = self.back_idx + 1; |
| 155 | |
| 156 | if self.idx + back_idx > #variant_count { |
| 157 | // We went past the end of the iterator. Freeze back_idx at #variant_count |
| 158 | // so that it doesn't overflow if the user calls this repeatedly. |
| 159 | // See PR #76 for context. |
| 160 | self.back_idx = #variant_count; |
| 161 | ::core::option::Option::None |
| 162 | } else { |
| 163 | self.back_idx = back_idx; |
| 164 | #iter_name::get(self, #variant_count - self.back_idx) |
| 165 | } |
| 166 | } |
| 167 | } |
| 168 | |
| 169 | impl #impl_generics ::core::iter::FusedIterator for #iter_name #ty_generics #where_clause { } |
| 170 | |
| 171 | impl #impl_generics Clone for #iter_name #ty_generics #where_clause { |
| 172 | #[inline] |
| 173 | fn clone(&self) -> #iter_name #ty_generics { |
| 174 | #iter_name { |
| 175 | idx: self.idx, |
| 176 | back_idx: self.back_idx, |
| 177 | marker: self.marker.clone(), |
| 178 | } |
| 179 | } |
| 180 | } |
| 181 | }) |
| 182 | } |
| 183 | |