1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use syn::{Data, DeriveInput, Fields, Ident};
4
5use crate::helpers::{non_enum_error, HasStrumVariantProperties, HasTypeProperties};
6
7pub fn enum_iter_inner(ast: &DeriveInput) -> syn::Result<TokenStream> {
8 let name = &ast.ident;
9 let gen = &ast.generics;
10 let (impl_generics, ty_generics, where_clause) = gen.split_for_impl();
11 let vis = &ast.vis;
12 let type_properties = ast.get_type_properties()?;
13 let strum_module_path = type_properties.crate_module_path();
14
15 if gen.lifetimes().count() > 0 {
16 return Err(syn::Error::new(
17 Span::call_site(),
18 "This macro doesn't support enums with lifetimes. \
19 The resulting enums would be unbounded.",
20 ));
21 }
22
23 let phantom_data = if gen.type_params().count() > 0 {
24 let g = gen.type_params().map(|param| &param.ident);
25 quote! { < ( #(#g),* ) > }
26 } else {
27 quote! { < () > }
28 };
29
30 let variants = match &ast.data {
31 Data::Enum(v) => &v.variants,
32 _ => return Err(non_enum_error()),
33 };
34
35 let mut arms = Vec::new();
36 let mut idx = 0usize;
37 for variant in variants {
38 if variant.get_variant_properties()?.disabled.is_some() {
39 continue;
40 }
41
42 let ident = &variant.ident;
43 let params = match &variant.fields {
44 Fields::Unit => quote! {},
45 Fields::Unnamed(fields) => {
46 let defaults = ::core::iter::repeat(quote!(::core::default::Default::default()))
47 .take(fields.unnamed.len());
48 quote! { (#(#defaults),*) }
49 }
50 Fields::Named(fields) => {
51 let fields = fields
52 .named
53 .iter()
54 .map(|field| field.ident.as_ref().unwrap());
55 quote! { {#(#fields: ::core::default::Default::default()),*} }
56 }
57 };
58
59 arms.push(quote! {#idx => ::core::option::Option::Some(#name::#ident #params)});
60 idx += 1;
61 }
62
63 let variant_count = arms.len();
64 arms.push(quote! { _ => ::core::option::Option::None });
65 let iter_name = syn::parse_str::<Ident>(&format!("{}Iter", name)).unwrap();
66
67 Ok(quote! {
68 #[doc = "An iterator over the variants of [Self]"]
69 #[allow(
70 missing_copy_implementations,
71 missing_debug_implementations,
72 )]
73 #vis struct #iter_name #ty_generics {
74 idx: usize,
75 back_idx: usize,
76 marker: ::core::marker::PhantomData #phantom_data,
77 }
78
79 impl #impl_generics #iter_name #ty_generics #where_clause {
80 fn get(&self, idx: usize) -> Option<#name #ty_generics> {
81 match idx {
82 #(#arms),*
83 }
84 }
85 }
86
87 impl #impl_generics #strum_module_path::IntoEnumIterator for #name #ty_generics #where_clause {
88 type Iterator = #iter_name #ty_generics;
89 fn iter() -> #iter_name #ty_generics {
90 #iter_name {
91 idx: 0,
92 back_idx: 0,
93 marker: ::core::marker::PhantomData,
94 }
95 }
96 }
97
98 impl #impl_generics Iterator for #iter_name #ty_generics #where_clause {
99 type Item = #name #ty_generics;
100
101 fn next(&mut self) -> Option<<Self as Iterator>::Item> {
102 self.nth(0)
103 }
104
105 fn size_hint(&self) -> (usize, Option<usize>) {
106 let t = if self.idx + self.back_idx >= #variant_count { 0 } else { #variant_count - self.idx - self.back_idx };
107 (t, Some(t))
108 }
109
110 fn nth(&mut self, n: usize) -> Option<<Self as Iterator>::Item> {
111 let idx = self.idx + n + 1;
112 if idx + self.back_idx > #variant_count {
113 // We went past the end of the iterator. Freeze idx at #variant_count
114 // so that it doesn't overflow if the user calls this repeatedly.
115 // See PR #76 for context.
116 self.idx = #variant_count;
117 ::core::option::Option::None
118 } else {
119 self.idx = idx;
120 self.get(idx - 1)
121 }
122 }
123 }
124
125 impl #impl_generics ExactSizeIterator for #iter_name #ty_generics #where_clause {
126 fn len(&self) -> usize {
127 self.size_hint().0
128 }
129 }
130
131 impl #impl_generics DoubleEndedIterator for #iter_name #ty_generics #where_clause {
132 fn next_back(&mut self) -> Option<<Self as Iterator>::Item> {
133 let back_idx = self.back_idx + 1;
134
135 if self.idx + back_idx > #variant_count {
136 // We went past the end of the iterator. Freeze back_idx at #variant_count
137 // so that it doesn't overflow if the user calls this repeatedly.
138 // See PR #76 for context.
139 self.back_idx = #variant_count;
140 ::core::option::Option::None
141 } else {
142 self.back_idx = back_idx;
143 self.get(#variant_count - self.back_idx)
144 }
145 }
146 }
147
148 impl #impl_generics Clone for #iter_name #ty_generics #where_clause {
149 fn clone(&self) -> #iter_name #ty_generics {
150 #iter_name {
151 idx: self.idx,
152 back_idx: self.back_idx,
153 marker: self.marker.clone(),
154 }
155 }
156 }
157 })
158}
159