1use crate::ast::{Enum, Field, Input, Struct, Variant};
2use crate::attr::Attrs;
3use quote::ToTokens;
4use std::collections::BTreeSet as Set;
5use syn::{Error, GenericArgument, Member, PathArguments, Result, Type};
6
7impl Input<'_> {
8 pub(crate) fn validate(&self) -> Result<()> {
9 match self {
10 Input::Struct(input: &Struct<'_>) => input.validate(),
11 Input::Enum(input: &Enum<'_>) => input.validate(),
12 }
13 }
14}
15
16impl Struct<'_> {
17 fn validate(&self) -> Result<()> {
18 check_non_field_attrs(&self.attrs)?;
19 if let Some(transparent: Transparent<'_>) = self.attrs.transparent {
20 if self.fields.len() != 1 {
21 return Err(Error::new_spanned(
22 tokens:transparent.original,
23 message:"#[error(transparent)] requires exactly one field",
24 ));
25 }
26 if let Some(source: &Attribute) = self.fields.iter().find_map(|f: &Field<'_>| f.attrs.source) {
27 return Err(Error::new_spanned(
28 tokens:source,
29 message:"transparent error struct can't contain #[source]",
30 ));
31 }
32 }
33 check_field_attrs(&self.fields)?;
34 for field: &Field<'_> in &self.fields {
35 field.validate()?;
36 }
37 Ok(())
38 }
39}
40
41impl Enum<'_> {
42 fn validate(&self) -> Result<()> {
43 check_non_field_attrs(&self.attrs)?;
44 let has_display = self.has_display();
45 for variant in &self.variants {
46 variant.validate()?;
47 if has_display && variant.attrs.display.is_none() && variant.attrs.transparent.is_none()
48 {
49 return Err(Error::new_spanned(
50 variant.original,
51 "missing #[error(\"...\")] display attribute",
52 ));
53 }
54 }
55 let mut from_types = Set::new();
56 for variant in &self.variants {
57 if let Some(from_field) = variant.from_field() {
58 let repr = from_field.ty.to_token_stream().to_string();
59 if !from_types.insert(repr) {
60 return Err(Error::new_spanned(
61 from_field.original,
62 "cannot derive From because another variant has the same source type",
63 ));
64 }
65 }
66 }
67 Ok(())
68 }
69}
70
71impl Variant<'_> {
72 fn validate(&self) -> Result<()> {
73 check_non_field_attrs(&self.attrs)?;
74 if self.attrs.transparent.is_some() {
75 if self.fields.len() != 1 {
76 return Err(Error::new_spanned(
77 self.original,
78 message:"#[error(transparent)] requires exactly one field",
79 ));
80 }
81 if let Some(source: &Attribute) = self.fields.iter().find_map(|f: &Field<'_>| f.attrs.source) {
82 return Err(Error::new_spanned(
83 tokens:source,
84 message:"transparent variant can't contain #[source]",
85 ));
86 }
87 }
88 check_field_attrs(&self.fields)?;
89 for field: &Field<'_> in &self.fields {
90 field.validate()?;
91 }
92 Ok(())
93 }
94}
95
96impl Field<'_> {
97 fn validate(&self) -> Result<()> {
98 if let Some(display: &Display<'_>) = &self.attrs.display {
99 return Err(Error::new_spanned(
100 tokens:display.original,
101 message:"not expected here; the #[error(...)] attribute belongs on top of a struct or an enum variant",
102 ));
103 }
104 Ok(())
105 }
106}
107
108fn check_non_field_attrs(attrs: &Attrs) -> Result<()> {
109 if let Some(from) = &attrs.from {
110 return Err(Error::new_spanned(
111 from,
112 "not expected here; the #[from] attribute belongs on a specific field",
113 ));
114 }
115 if let Some(source) = &attrs.source {
116 return Err(Error::new_spanned(
117 source,
118 "not expected here; the #[source] attribute belongs on a specific field",
119 ));
120 }
121 if let Some(backtrace) = &attrs.backtrace {
122 return Err(Error::new_spanned(
123 backtrace,
124 "not expected here; the #[backtrace] attribute belongs on a specific field",
125 ));
126 }
127 if let Some(display) = &attrs.display {
128 if attrs.transparent.is_some() {
129 return Err(Error::new_spanned(
130 display.original,
131 "cannot have both #[error(transparent)] and a display attribute",
132 ));
133 }
134 }
135 Ok(())
136}
137
138fn check_field_attrs(fields: &[Field]) -> Result<()> {
139 let mut from_field = None;
140 let mut source_field = None;
141 let mut backtrace_field = None;
142 let mut has_backtrace = false;
143 for field in fields {
144 if let Some(from) = field.attrs.from {
145 if from_field.is_some() {
146 return Err(Error::new_spanned(from, "duplicate #[from] attribute"));
147 }
148 from_field = Some(field);
149 }
150 if let Some(source) = field.attrs.source {
151 if source_field.is_some() {
152 return Err(Error::new_spanned(source, "duplicate #[source] attribute"));
153 }
154 source_field = Some(field);
155 }
156 if let Some(backtrace) = field.attrs.backtrace {
157 if backtrace_field.is_some() {
158 return Err(Error::new_spanned(
159 backtrace,
160 "duplicate #[backtrace] attribute",
161 ));
162 }
163 backtrace_field = Some(field);
164 has_backtrace = true;
165 }
166 if let Some(transparent) = field.attrs.transparent {
167 return Err(Error::new_spanned(
168 transparent.original,
169 "#[error(transparent)] needs to go outside the enum or struct, not on an individual field",
170 ));
171 }
172 has_backtrace |= field.is_backtrace();
173 }
174 if let (Some(from_field), Some(source_field)) = (from_field, source_field) {
175 if !same_member(from_field, source_field) {
176 return Err(Error::new_spanned(
177 from_field.attrs.from,
178 "#[from] is only supported on the source field, not any other field",
179 ));
180 }
181 }
182 if let Some(from_field) = from_field {
183 let max_expected_fields = match backtrace_field {
184 Some(backtrace_field) => 1 + !same_member(from_field, backtrace_field) as usize,
185 None => 1 + has_backtrace as usize,
186 };
187 if fields.len() > max_expected_fields {
188 return Err(Error::new_spanned(
189 from_field.attrs.from,
190 "deriving From requires no fields other than source and backtrace",
191 ));
192 }
193 }
194 if let Some(source_field) = source_field.or(from_field) {
195 if contains_non_static_lifetime(source_field.ty) {
196 return Err(Error::new_spanned(
197 &source_field.original.ty,
198 "non-static lifetimes are not allowed in the source of an error, because std::error::Error requires the source is dyn Error + 'static",
199 ));
200 }
201 }
202 Ok(())
203}
204
205fn same_member(one: &Field, two: &Field) -> bool {
206 match (&one.member, &two.member) {
207 (Member::Named(one: &Ident), Member::Named(two: &Ident)) => one == two,
208 (Member::Unnamed(one: &Index), Member::Unnamed(two: &Index)) => one.index == two.index,
209 _ => unreachable!(),
210 }
211}
212
213fn contains_non_static_lifetime(ty: &Type) -> bool {
214 match ty {
215 Type::Path(ty) => {
216 let bracketed = match &ty.path.segments.last().unwrap().arguments {
217 PathArguments::AngleBracketed(bracketed) => bracketed,
218 _ => return false,
219 };
220 for arg in &bracketed.args {
221 match arg {
222 GenericArgument::Type(ty) if contains_non_static_lifetime(ty) => return true,
223 GenericArgument::Lifetime(lifetime) if lifetime.ident != "static" => {
224 return true
225 }
226 _ => {}
227 }
228 }
229 false
230 }
231 Type::Reference(ty) => ty
232 .lifetime
233 .as_ref()
234 .map_or(false, |lifetime| lifetime.ident != "static"),
235 _ => false, // maybe implement later if there are common other cases
236 }
237}
238