1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use syn::{Data, DeriveInput, Fields, Ident};
4
5use crate::helpers::{non_enum_error, HasStrumVariantProperties, HasTypeProperties};
6
7pub fn enum_iter_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
8 let name = &ast.ident;
9 let gen = &ast.generics;
10 let (impl_generics, ty_generics, where_clause) = gen.split_for_impl();
11 let vis = &ast.vis;
12 let type_properties = ast.get_type_properties()?;
13 let strum_module_path = type_properties.crate_module_path();
14 let doc_comment = format!("An iterator over the variants of [{}]", name);
15
16 if gen.lifetimes().count() > 0 {
17 return Err(syn::Error::new(
18 Span::call_site(),
19 "This macro doesn't support enums with lifetimes. \
20 The resulting enums would be unbounded.",
21 ));
22 }
23
24 let phantom_data = if gen.type_params().count() > 0 {
25 let g = gen.type_params().map(|param| &param.ident);
26 quote! { < ( #(#g),* ) > }
27 } else {
28 quote! { < () > }
29 };
30
31 let variants = match &ast.data {
32 Data::Enum(v) => &v.variants,
33 _ => return Err(non_enum_error()),
34 };
35
36 let mut arms = Vec::new();
37 let mut idx = 0usize;
38 for variant in variants {
39 if variant.get_variant_properties()?.disabled.is_some() {
40 continue;
41 }
42
43 let ident = &variant.ident;
44 let params = match &variant.fields {
45 Fields::Unit => quote! {},
46 Fields::Unnamed(fields) => {
47 let defaults = ::core::iter::repeat(quote!(::core::default::Default::default()))
48 .take(fields.unnamed.len());
49 quote! { (#(#defaults),*) }
50 }
51 Fields::Named(fields) => {
52 let fields = fields
53 .named
54 .iter()
55 .map(|field| field.ident.as_ref().unwrap());
56 quote! { {#(#fields: ::core::default::Default::default()),*} }
57 }
58 };
59
60 arms.push(quote! {#idx => ::core::option::Option::Some(#name::#ident #params)});
61 idx += 1;
62 }
63
64 let variant_count = arms.len();
65 arms.push(quote! { _ => ::core::option::Option::None });
66 let iter_name = syn::parse_str::<Ident>(&format!("{}Iter", name)).unwrap();
67
68 // Create a string literal "MyEnumIter" to use in the debug impl.
69 let iter_name_debug_struct =
70 syn::parse_str::<syn::LitStr>(&format!("\"{}\"", iter_name)).unwrap();
71
72 Ok(quote! {
73 #[doc = #doc_comment]
74 #[allow(
75 missing_copy_implementations,
76 )]
77 #vis struct #iter_name #impl_generics {
78 idx: usize,
79 back_idx: usize,
80 marker: ::core::marker::PhantomData #phantom_data,
81 }
82
83 impl #impl_generics ::core::fmt::Debug for #iter_name #ty_generics #where_clause {
84 fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
85 // We don't know if the variants implement debug themselves so the only thing we
86 // can really show is how many elements are left.
87 f.debug_struct(#iter_name_debug_struct)
88 .field("len", &self.len())
89 .finish()
90 }
91 }
92
93 impl #impl_generics #iter_name #ty_generics #where_clause {
94 fn get(&self, idx: usize) -> ::core::option::Option<#name #ty_generics> {
95 match idx {
96 #(#arms),*
97 }
98 }
99 }
100
101 impl #impl_generics #strum_module_path::IntoEnumIterator for #name #ty_generics #where_clause {
102 type Iterator = #iter_name #ty_generics;
103 fn iter() -> #iter_name #ty_generics {
104 #iter_name {
105 idx: 0,
106 back_idx: 0,
107 marker: ::core::marker::PhantomData,
108 }
109 }
110 }
111
112 impl #impl_generics Iterator for #iter_name #ty_generics #where_clause {
113 type Item = #name #ty_generics;
114
115 fn next(&mut self) -> ::core::option::Option<<Self as Iterator>::Item> {
116 self.nth(0)
117 }
118
119 fn size_hint(&self) -> (usize, ::core::option::Option<usize>) {
120 let t = if self.idx + self.back_idx >= #variant_count { 0 } else { #variant_count - self.idx - self.back_idx };
121 (t, Some(t))
122 }
123
124 fn nth(&mut self, n: usize) -> ::core::option::Option<<Self as Iterator>::Item> {
125 let idx = self.idx + n + 1;
126 if idx + self.back_idx > #variant_count {
127 // We went past the end of the iterator. Freeze idx at #variant_count
128 // so that it doesn't overflow if the user calls this repeatedly.
129 // See PR #76 for context.
130 self.idx = #variant_count;
131 ::core::option::Option::None
132 } else {
133 self.idx = idx;
134 self.get(idx - 1)
135 }
136 }
137 }
138
139 impl #impl_generics ExactSizeIterator for #iter_name #ty_generics #where_clause {
140 fn len(&self) -> usize {
141 self.size_hint().0
142 }
143 }
144
145 impl #impl_generics DoubleEndedIterator for #iter_name #ty_generics #where_clause {
146 fn next_back(&mut self) -> ::core::option::Option<<Self as Iterator>::Item> {
147 let back_idx = self.back_idx + 1;
148
149 if self.idx + back_idx > #variant_count {
150 // We went past the end of the iterator. Freeze back_idx at #variant_count
151 // so that it doesn't overflow if the user calls this repeatedly.
152 // See PR #76 for context.
153 self.back_idx = #variant_count;
154 ::core::option::Option::None
155 } else {
156 self.back_idx = back_idx;
157 self.get(#variant_count - self.back_idx)
158 }
159 }
160 }
161
162 impl #impl_generics ::core::iter::FusedIterator for #iter_name #ty_generics #where_clause { }
163
164 impl #impl_generics Clone for #iter_name #ty_generics #where_clause {
165 fn clone(&self) -> #iter_name #ty_generics {
166 #iter_name {
167 idx: self.idx,
168 back_idx: self.back_idx,
169 marker: self.marker.clone(),
170 }
171 }
172 }
173 })
174}
175