wherror_impl/
valid.rs

1use crate::ast::{Enum, Field, Input, Struct, Variant};
2use crate::attr::Attrs;
3use syn::{Error, GenericArgument, PathArguments, Result, Type};
4
5impl Input<'_> {
6    pub(crate) fn validate(&self) -> Result<()> {
7        match self {
8            Input::Struct(input) => input.validate(),
9            Input::Enum(input) => input.validate(),
10        }
11    }
12}
13
14impl Struct<'_> {
15    fn validate(&self) -> Result<()> {
16        check_non_field_attrs(&self.attrs)?;
17        if let Some(transparent) = self.attrs.transparent {
18            if self.fields.len() != 1 {
19                return Err(Error::new_spanned(
20                    transparent.original,
21                    "#[error(transparent)] requires exactly one field",
22                ));
23            }
24            if let Some(source) = self.fields.iter().find_map(|f| f.attrs.source) {
25                return Err(Error::new_spanned(
26                    source.original,
27                    "transparent error struct can't contain #[source]",
28                ));
29            }
30        }
31        if let Some(fmt) = &self.attrs.fmt {
32            return Err(Error::new_spanned(
33                fmt.original,
34                "#[error(fmt = ...)] is only supported in enums; for a struct, handwrite your own Display impl",
35            ));
36        }
37        check_field_attrs(&self.fields)?;
38        for field in &self.fields {
39            field.validate()?;
40        }
41        Ok(())
42    }
43}
44
45impl Enum<'_> {
46    fn validate(&self) -> Result<()> {
47        check_non_field_attrs(&self.attrs)?;
48        let has_display = self.has_display();
49        for variant in &self.variants {
50            variant.validate()?;
51            if has_display
52                && variant.attrs.display.is_none()
53                && variant.attrs.transparent.is_none()
54                && variant.attrs.fmt.is_none()
55            {
56                return Err(Error::new_spanned(
57                    variant.original,
58                    "missing #[error(\"...\")] display attribute",
59                ));
60            }
61        }
62        Ok(())
63    }
64}
65
66impl Variant<'_> {
67    fn validate(&self) -> Result<()> {
68        check_non_field_attrs(&self.attrs)?;
69        if self.attrs.transparent.is_some() {
70            if self.fields.len() != 1 {
71                return Err(Error::new_spanned(
72                    self.original,
73                    "#[error(transparent)] requires exactly one field",
74                ));
75            }
76            if let Some(source) = self.fields.iter().find_map(|f| f.attrs.source) {
77                return Err(Error::new_spanned(
78                    source.original,
79                    "transparent variant can't contain #[source]",
80                ));
81            }
82        }
83        check_field_attrs(&self.fields)?;
84        for field in &self.fields {
85            field.validate()?;
86        }
87        Ok(())
88    }
89}
90
91impl Field<'_> {
92    fn validate(&self) -> Result<()> {
93        if let Some(unexpected_display_attr) = if let Some(display) = &self.attrs.display {
94            Some(display.original)
95        } else if let Some(fmt) = &self.attrs.fmt {
96            Some(fmt.original)
97        } else {
98            None
99        } {
100            return Err(Error::new_spanned(
101                unexpected_display_attr,
102                "not expected here; the #[error(...)] attribute belongs on top of a struct or an enum variant",
103            ));
104        }
105        Ok(())
106    }
107}
108
109fn check_non_field_attrs(attrs: &Attrs) -> Result<()> {
110    if let Some(from) = &attrs.from {
111        return Err(Error::new_spanned(
112            from.original,
113            "not expected here; the #[from] attribute belongs on a specific field",
114        ));
115    }
116    if let Some(source) = &attrs.source {
117        return Err(Error::new_spanned(
118            source.original,
119            "not expected here; the #[source] attribute belongs on a specific field",
120        ));
121    }
122    if let Some(backtrace) = &attrs.backtrace {
123        return Err(Error::new_spanned(
124            backtrace,
125            "not expected here; the #[backtrace] attribute belongs on a specific field",
126        ));
127    }
128    if attrs.transparent.is_some() {
129        if let Some(display) = &attrs.display {
130            return Err(Error::new_spanned(
131                display.original,
132                "cannot have both #[error(transparent)] and a display attribute",
133            ));
134        }
135        if let Some(fmt) = &attrs.fmt {
136            return Err(Error::new_spanned(
137                fmt.original,
138                "cannot have both #[error(transparent)] and #[error(fmt = ...)]",
139            ));
140        }
141    } else if let (Some(display), Some(_)) = (&attrs.display, &attrs.fmt) {
142        return Err(Error::new_spanned(
143            display.original,
144            "cannot have both #[error(fmt = ...)] and a format arguments attribute",
145        ));
146    }
147
148    Ok(())
149}
150
151fn check_field_attrs(fields: &[Field]) -> Result<()> {
152    let mut from_field = None;
153    let mut source_field = None;
154    let mut location_field: Option<&Field> = None;
155    let mut backtrace_field = None;
156    let mut has_backtrace = false;
157    let mut has_location = false;
158    for field in fields {
159        if let Some(from) = field.attrs.from {
160            if from_field.is_some() {
161                return Err(Error::new_spanned(
162                    from.original,
163                    "duplicate #[from] attribute",
164                ));
165            }
166            from_field = Some(field);
167        }
168        if let Some(source) = field.attrs.source {
169            if source_field.is_some() {
170                return Err(Error::new_spanned(
171                    source.original,
172                    "duplicate #[source] attribute",
173                ));
174            }
175            source_field = Some(field);
176        }
177        if let Some(backtrace) = field.attrs.backtrace {
178            if backtrace_field.is_some() {
179                return Err(Error::new_spanned(
180                    backtrace,
181                    "duplicate #[backtrace] attribute",
182                ));
183            }
184            backtrace_field = Some(field);
185            has_backtrace = true;
186        }
187        if let Some(location) = field.attrs.location {
188            if location_field.is_some() {
189                return Err(Error::new_spanned(
190                    location,
191                    "duplicate #[location] attribute",
192                ));
193            }
194
195            location_field = Some(field);
196            has_location = true;
197        }
198        if let Some(transparent) = field.attrs.transparent {
199            return Err(Error::new_spanned(
200                transparent.original,
201                "#[error(transparent)] needs to go outside the enum or struct, not on an individual field",
202            ));
203        }
204        has_backtrace |= field.is_backtrace();
205        has_location |= field.is_location();
206    }
207    if let (Some(from_field), Some(source_field)) = (from_field, source_field) {
208        if from_field.member != source_field.member {
209            return Err(Error::new_spanned(
210                from_field.attrs.from.unwrap().original,
211                "#[from] is only supported on the source field, not any other field",
212            ));
213        }
214    }
215    if let Some(from_field) = from_field {
216        let extra_fields = has_backtrace as usize + has_location as usize;
217        let max_expected_fields = match (backtrace_field, location_field) {
218            (Some(backtrace_field), Some(_)) => {
219                2 + (from_field.member != backtrace_field.member) as usize
220            }
221            (Some(backtrace_field), None) => {
222                1 + (from_field.member != backtrace_field.member) as usize
223            }
224            (None, Some(_)) => 1 + extra_fields,
225            (None, None) => 1 + extra_fields,
226        };
227        if fields.len() > max_expected_fields {
228            return Err(Error::new_spanned(
229                from_field.attrs.from.unwrap().original,
230                "deriving From requires no fields other than source, backtrace, and location",
231            ));
232        }
233    }
234    if let Some(source_field) = source_field.or(from_field) {
235        if contains_non_static_lifetime(source_field.ty) {
236            return Err(Error::new_spanned(
237                &source_field.original.ty,
238                "non-static lifetimes are not allowed in the source of an error, because std::error::Error requires the source is dyn Error + 'static",
239            ));
240        }
241    }
242    Ok(())
243}
244
245fn contains_non_static_lifetime(ty: &Type) -> bool {
246    match ty {
247        Type::Path(ty) => {
248            let bracketed = match &ty.path.segments.last().unwrap().arguments {
249                PathArguments::AngleBracketed(bracketed) => bracketed,
250                _ => return false,
251            };
252            for arg in &bracketed.args {
253                match arg {
254                    GenericArgument::Type(ty) if contains_non_static_lifetime(ty) => return true,
255                    GenericArgument::Lifetime(lifetime) if lifetime.ident != "static" => {
256                        return true
257                    }
258                    _ => {}
259                }
260            }
261            false
262        }
263        Type::Reference(ty) => ty
264            .lifetime
265            .as_ref()
266            .map_or(false, |lifetime| lifetime.ident != "static"),
267        _ => false, // maybe implement later if there are common other cases
268    }
269}