1use crate::ast::{Enum, Field, Input, Struct, Variant};
2use crate::attr::Attrs;
3use syn::{Error, GenericArgument, PathArguments, Result, Type};
4
5impl Input<'_> {
6 pub(crate) fn validate(&self) -> Result<()> {
7 match self {
8 Input::Struct(input) => input.validate(),
9 Input::Enum(input) => input.validate(),
10 }
11 }
12}
13
14impl Struct<'_> {
15 fn validate(&self) -> Result<()> {
16 check_non_field_attrs(&self.attrs)?;
17 if let Some(transparent) = self.attrs.transparent {
18 if self.fields.len() != 1 {
19 return Err(Error::new_spanned(
20 transparent.original,
21 "#[error(transparent)] requires exactly one field",
22 ));
23 }
24 if let Some(source) = self.fields.iter().find_map(|f| f.attrs.source) {
25 return Err(Error::new_spanned(
26 source.original,
27 "transparent error struct can't contain #[source]",
28 ));
29 }
30 }
31 if let Some(fmt) = &self.attrs.fmt {
32 return Err(Error::new_spanned(
33 fmt.original,
34 "#[error(fmt = ...)] is only supported in enums; for a struct, handwrite your own Display impl",
35 ));
36 }
37 check_field_attrs(&self.fields)?;
38 for field in &self.fields {
39 field.validate()?;
40 }
41 Ok(())
42 }
43}
44
45impl Enum<'_> {
46 fn validate(&self) -> Result<()> {
47 check_non_field_attrs(&self.attrs)?;
48 let has_display = self.has_display();
49 for variant in &self.variants {
50 variant.validate()?;
51 if has_display
54 && variant.attrs.display.is_none()
55 && variant.attrs.transparent.is_none()
56 && variant.attrs.fmt.is_none()
57 && variant.attrs.debug.is_none()
58 {
59 if !self.attrs.debug.is_some() {
61 return Err(Error::new_spanned(
62 variant.original,
63 "missing #[error(\"...\")] display attribute",
64 ));
65 }
66 }
67 }
68 Ok(())
69 }
70}
71
72impl Variant<'_> {
73 fn validate(&self) -> Result<()> {
74 check_non_field_attrs(&self.attrs)?;
75 if self.attrs.transparent.is_some() {
76 if self.fields.len() != 1 {
77 return Err(Error::new_spanned(
78 self.original,
79 "#[error(transparent)] requires exactly one field",
80 ));
81 }
82 if let Some(source) = self.fields.iter().find_map(|f| f.attrs.source) {
83 return Err(Error::new_spanned(
84 source.original,
85 "transparent variant can't contain #[source]",
86 ));
87 }
88 }
89 check_field_attrs(&self.fields)?;
90 for field in &self.fields {
91 field.validate()?;
92 }
93 Ok(())
94 }
95}
96
97impl Field<'_> {
98 fn validate(&self) -> Result<()> {
99 if let Some(unexpected_display_attr) = if let Some(display) = &self.attrs.display {
100 Some(display.original)
101 } else if let Some(fmt) = &self.attrs.fmt {
102 Some(fmt.original)
103 } else if let Some(debug) = &self.attrs.debug {
104 Some(debug.original)
105 } else {
106 None
107 } {
108 return Err(Error::new_spanned(
109 unexpected_display_attr,
110 "not expected here; the #[error(...)] attribute belongs on top of a struct or an enum variant",
111 ));
112 }
113 Ok(())
114 }
115}
116
117fn check_non_field_attrs(attrs: &Attrs) -> Result<()> {
118 if let Some(from) = &attrs.from {
119 return Err(Error::new_spanned(
120 from.original,
121 "not expected here; the #[from] attribute belongs on a specific field",
122 ));
123 }
124 if let Some(source) = &attrs.source {
125 return Err(Error::new_spanned(
126 source.original,
127 "not expected here; the #[source] attribute belongs on a specific field",
128 ));
129 }
130 if let Some(backtrace) = &attrs.backtrace {
131 return Err(Error::new_spanned(
132 backtrace,
133 "not expected here; the #[backtrace] attribute belongs on a specific field",
134 ));
135 }
136 if attrs.transparent.is_some() {
137 if let Some(display) = &attrs.display {
138 return Err(Error::new_spanned(
139 display.original,
140 "cannot have both #[error(transparent)] and a display attribute",
141 ));
142 }
143 if let Some(fmt) = &attrs.fmt {
144 return Err(Error::new_spanned(
145 fmt.original,
146 "cannot have both #[error(transparent)] and #[error(fmt = ...)]",
147 ));
148 }
149 if let Some(debug) = &attrs.debug {
150 return Err(Error::new_spanned(
151 debug.original,
152 "cannot have both #[error(transparent)] and #[error(debug)]",
153 ));
154 }
155 } else if let (Some(display), Some(_)) = (&attrs.display, &attrs.fmt) {
156 return Err(Error::new_spanned(
157 display.original,
158 "cannot have both #[error(fmt = ...)] and a format arguments attribute",
159 ));
160 } else if let (Some(display), Some(_)) = (&attrs.display, &attrs.debug) {
161 return Err(Error::new_spanned(
162 display.original,
163 "cannot have both #[error(debug)] and a display attribute",
164 ));
165 } else if let (Some(fmt), Some(_)) = (&attrs.fmt, &attrs.debug) {
166 return Err(Error::new_spanned(
167 fmt.original,
168 "cannot have both #[error(fmt = ...)] and #[error(debug)]",
169 ));
170 }
171
172 Ok(())
173}
174
175fn check_field_attrs(fields: &[Field]) -> Result<()> {
176 let mut from_field = None;
177 let mut source_field = None;
178 let mut location_field: Option<&Field> = None;
179 let mut backtrace_field = None;
180 let mut has_backtrace = false;
181 let mut has_location = false;
182 for field in fields {
183 if let Some(from) = field.attrs.from {
184 if from_field.is_some() {
185 return Err(Error::new_spanned(
186 from.original,
187 "duplicate #[from] attribute",
188 ));
189 }
190 from_field = Some(field);
191 }
192 if let Some(source) = field.attrs.source {
193 if source_field.is_some() {
194 return Err(Error::new_spanned(
195 source.original,
196 "duplicate #[source] attribute",
197 ));
198 }
199 source_field = Some(field);
200 }
201 if let Some(backtrace) = field.attrs.backtrace {
202 if backtrace_field.is_some() {
203 return Err(Error::new_spanned(
204 backtrace,
205 "duplicate #[backtrace] attribute",
206 ));
207 }
208 backtrace_field = Some(field);
209 has_backtrace = true;
210 }
211 if let Some(location) = field.attrs.location {
212 if location_field.is_some() {
213 return Err(Error::new_spanned(
214 location,
215 "duplicate #[location] attribute",
216 ));
217 }
218
219 location_field = Some(field);
220 has_location = true;
221 }
222 if let Some(transparent) = field.attrs.transparent {
223 return Err(Error::new_spanned(
224 transparent.original,
225 "#[error(transparent)] needs to go outside the enum or struct, not on an individual field",
226 ));
227 }
228 has_backtrace |= field.is_backtrace();
229 has_location |= field.is_location();
230 }
231 if let (Some(from_field), Some(source_field)) = (from_field, source_field) {
232 if from_field.member != source_field.member {
233 return Err(Error::new_spanned(
234 from_field.attrs.from.unwrap().original,
235 "#[from] is only supported on the source field, not any other field",
236 ));
237 }
238 }
239 if let Some(from_field) = from_field {
240 let extra_fields = has_backtrace as usize + has_location as usize;
241 let max_expected_fields = match (backtrace_field, location_field) {
242 (Some(backtrace_field), Some(_)) => {
243 2 + (from_field.member != backtrace_field.member) as usize
244 }
245 (Some(backtrace_field), None) => {
246 1 + (from_field.member != backtrace_field.member) as usize
247 }
248 (None, Some(_)) => 1 + extra_fields,
249 (None, None) => 1 + extra_fields,
250 };
251 if fields.len() > max_expected_fields {
252 return Err(Error::new_spanned(
253 from_field.attrs.from.unwrap().original,
254 "deriving From requires no fields other than source, backtrace, and location",
255 ));
256 }
257 }
258 if let Some(source_field) = source_field.or(from_field) {
259 if contains_non_static_lifetime(source_field.ty) {
260 return Err(Error::new_spanned(
261 &source_field.original.ty,
262 "non-static lifetimes are not allowed in the source of an error, because std::error::Error requires the source is dyn Error + 'static",
263 ));
264 }
265 }
266 Ok(())
267}
268
269fn contains_non_static_lifetime(ty: &Type) -> bool {
270 match ty {
271 Type::Path(ty) => {
272 let bracketed = match &ty.path.segments.last().unwrap().arguments {
273 PathArguments::AngleBracketed(bracketed) => bracketed,
274 _ => return false,
275 };
276 for arg in &bracketed.args {
277 match arg {
278 GenericArgument::Type(ty) if contains_non_static_lifetime(ty) => return true,
279 GenericArgument::Lifetime(lifetime) if lifetime.ident != "static" => {
280 return true
281 }
282 _ => {}
283 }
284 }
285 false
286 }
287 Type::Reference(ty) => ty
288 .lifetime
289 .as_ref()
290 .map_or(false, |lifetime| lifetime.ident != "static"),
291 _ => false, }
293}