wherror_impl/
valid.rs

1use crate::ast::{Enum, Field, Input, Struct, Variant};
2use crate::attr::Attrs;
3use syn::{Error, GenericArgument, PathArguments, Result, Type};
4
5impl Input<'_> {
6    pub(crate) fn validate(&self) -> Result<()> {
7        match self {
8            Input::Struct(input) => input.validate(),
9            Input::Enum(input) => input.validate(),
10        }
11    }
12}
13
14impl Struct<'_> {
15    fn validate(&self) -> Result<()> {
16        check_non_field_attrs(&self.attrs)?;
17        if let Some(transparent) = self.attrs.transparent {
18            if self.fields.len() != 1 {
19                return Err(Error::new_spanned(
20                    transparent.original,
21                    "#[error(transparent)] requires exactly one field",
22                ));
23            }
24            if let Some(source) = self.fields.iter().find_map(|f| f.attrs.source) {
25                return Err(Error::new_spanned(
26                    source.original,
27                    "transparent error struct can't contain #[source]",
28                ));
29            }
30        }
31        if let Some(fmt) = &self.attrs.fmt {
32            return Err(Error::new_spanned(
33                fmt.original,
34                "#[error(fmt = ...)] is only supported in enums; for a struct, handwrite your own Display impl",
35            ));
36        }
37        check_field_attrs(&self.fields)?;
38        for field in &self.fields {
39            field.validate()?;
40        }
41        Ok(())
42    }
43}
44
45impl Enum<'_> {
46    fn validate(&self) -> Result<()> {
47        check_non_field_attrs(&self.attrs)?;
48        let has_display = self.has_display();
49        for variant in &self.variants {
50            variant.validate()?;
51            // Only require explicit display attributes if the enum has some display capability
52            // but this specific variant lacks any display mechanism
53            if has_display
54                && variant.attrs.display.is_none()
55                && variant.attrs.transparent.is_none()
56                && variant.attrs.fmt.is_none()
57                && variant.attrs.debug.is_none()
58            {
59                // Deny if the enum lacks #[error(debug)] fallback
60                if !self.attrs.debug.is_some() {
61                    return Err(Error::new_spanned(
62                        variant.original,
63                        "missing #[error(\"...\")] display attribute",
64                    ));
65                }
66            }
67        }
68        Ok(())
69    }
70}
71
72impl Variant<'_> {
73    fn validate(&self) -> Result<()> {
74        check_non_field_attrs(&self.attrs)?;
75        if self.attrs.transparent.is_some() {
76            if self.fields.len() != 1 {
77                return Err(Error::new_spanned(
78                    self.original,
79                    "#[error(transparent)] requires exactly one field",
80                ));
81            }
82            if let Some(source) = self.fields.iter().find_map(|f| f.attrs.source) {
83                return Err(Error::new_spanned(
84                    source.original,
85                    "transparent variant can't contain #[source]",
86                ));
87            }
88        }
89        check_field_attrs(&self.fields)?;
90        for field in &self.fields {
91            field.validate()?;
92        }
93        Ok(())
94    }
95}
96
97impl Field<'_> {
98    fn validate(&self) -> Result<()> {
99        if let Some(unexpected_display_attr) = if let Some(display) = &self.attrs.display {
100            Some(display.original)
101        } else if let Some(fmt) = &self.attrs.fmt {
102            Some(fmt.original)
103        } else if let Some(debug) = &self.attrs.debug {
104            Some(debug.original)
105        } else {
106            None
107        } {
108            return Err(Error::new_spanned(
109                unexpected_display_attr,
110                "not expected here; the #[error(...)] attribute belongs on top of a struct or an enum variant",
111            ));
112        }
113        Ok(())
114    }
115}
116
117fn check_non_field_attrs(attrs: &Attrs) -> Result<()> {
118    if let Some(from) = &attrs.from {
119        return Err(Error::new_spanned(
120            from.original,
121            "not expected here; the #[from] attribute belongs on a specific field",
122        ));
123    }
124    if let Some(source) = &attrs.source {
125        return Err(Error::new_spanned(
126            source.original,
127            "not expected here; the #[source] attribute belongs on a specific field",
128        ));
129    }
130    if let Some(backtrace) = &attrs.backtrace {
131        return Err(Error::new_spanned(
132            backtrace,
133            "not expected here; the #[backtrace] attribute belongs on a specific field",
134        ));
135    }
136    if attrs.transparent.is_some() {
137        if let Some(display) = &attrs.display {
138            return Err(Error::new_spanned(
139                display.original,
140                "cannot have both #[error(transparent)] and a display attribute",
141            ));
142        }
143        if let Some(fmt) = &attrs.fmt {
144            return Err(Error::new_spanned(
145                fmt.original,
146                "cannot have both #[error(transparent)] and #[error(fmt = ...)]",
147            ));
148        }
149        if let Some(debug) = &attrs.debug {
150            return Err(Error::new_spanned(
151                debug.original,
152                "cannot have both #[error(transparent)] and #[error(debug)]",
153            ));
154        }
155    } else if let (Some(display), Some(_)) = (&attrs.display, &attrs.fmt) {
156        return Err(Error::new_spanned(
157            display.original,
158            "cannot have both #[error(fmt = ...)] and a format arguments attribute",
159        ));
160    } else if let (Some(display), Some(_)) = (&attrs.display, &attrs.debug) {
161        return Err(Error::new_spanned(
162            display.original,
163            "cannot have both #[error(debug)] and a display attribute",
164        ));
165    } else if let (Some(fmt), Some(_)) = (&attrs.fmt, &attrs.debug) {
166        return Err(Error::new_spanned(
167            fmt.original,
168            "cannot have both #[error(fmt = ...)] and #[error(debug)]",
169        ));
170    }
171
172    Ok(())
173}
174
175fn check_field_attrs(fields: &[Field]) -> Result<()> {
176    let mut from_field = None;
177    let mut source_field = None;
178    let mut location_field: Option<&Field> = None;
179    let mut backtrace_field = None;
180    let mut has_backtrace = false;
181    let mut has_location = false;
182    for field in fields {
183        if let Some(from) = field.attrs.from {
184            if from_field.is_some() {
185                return Err(Error::new_spanned(
186                    from.original,
187                    "duplicate #[from] attribute",
188                ));
189            }
190            from_field = Some(field);
191        }
192        if let Some(source) = field.attrs.source {
193            if source_field.is_some() {
194                return Err(Error::new_spanned(
195                    source.original,
196                    "duplicate #[source] attribute",
197                ));
198            }
199            source_field = Some(field);
200        }
201        if let Some(backtrace) = field.attrs.backtrace {
202            if backtrace_field.is_some() {
203                return Err(Error::new_spanned(
204                    backtrace,
205                    "duplicate #[backtrace] attribute",
206                ));
207            }
208            backtrace_field = Some(field);
209            has_backtrace = true;
210        }
211        if let Some(location) = field.attrs.location {
212            if location_field.is_some() {
213                return Err(Error::new_spanned(
214                    location,
215                    "duplicate #[location] attribute",
216                ));
217            }
218
219            location_field = Some(field);
220            has_location = true;
221        }
222        if let Some(transparent) = field.attrs.transparent {
223            return Err(Error::new_spanned(
224                transparent.original,
225                "#[error(transparent)] needs to go outside the enum or struct, not on an individual field",
226            ));
227        }
228        has_backtrace |= field.is_backtrace();
229        has_location |= field.is_location();
230    }
231    if let (Some(from_field), Some(source_field)) = (from_field, source_field) {
232        if from_field.member != source_field.member {
233            return Err(Error::new_spanned(
234                from_field.attrs.from.unwrap().original,
235                "#[from] is only supported on the source field, not any other field",
236            ));
237        }
238    }
239    if let Some(from_field) = from_field {
240        let extra_fields = has_backtrace as usize + has_location as usize;
241        let max_expected_fields = match (backtrace_field, location_field) {
242            (Some(backtrace_field), Some(_)) => {
243                2 + (from_field.member != backtrace_field.member) as usize
244            }
245            (Some(backtrace_field), None) => {
246                1 + (from_field.member != backtrace_field.member) as usize
247            }
248            (None, Some(_)) => 1 + extra_fields,
249            (None, None) => 1 + extra_fields,
250        };
251        if fields.len() > max_expected_fields {
252            return Err(Error::new_spanned(
253                from_field.attrs.from.unwrap().original,
254                "deriving From requires no fields other than source, backtrace, and location",
255            ));
256        }
257    }
258    if let Some(source_field) = source_field.or(from_field) {
259        if contains_non_static_lifetime(source_field.ty) {
260            return Err(Error::new_spanned(
261                &source_field.original.ty,
262                "non-static lifetimes are not allowed in the source of an error, because std::error::Error requires the source is dyn Error + 'static",
263            ));
264        }
265    }
266    Ok(())
267}
268
269fn contains_non_static_lifetime(ty: &Type) -> bool {
270    match ty {
271        Type::Path(ty) => {
272            let bracketed = match &ty.path.segments.last().unwrap().arguments {
273                PathArguments::AngleBracketed(bracketed) => bracketed,
274                _ => return false,
275            };
276            for arg in &bracketed.args {
277                match arg {
278                    GenericArgument::Type(ty) if contains_non_static_lifetime(ty) => return true,
279                    GenericArgument::Lifetime(lifetime) if lifetime.ident != "static" => {
280                        return true
281                    }
282                    _ => {}
283                }
284            }
285            false
286        }
287        Type::Reference(ty) => ty
288            .lifetime
289            .as_ref()
290            .map_or(false, |lifetime| lifetime.ident != "static"),
291        _ => false, // maybe implement later if there are common other cases
292    }
293}