1use crate::ast::{Enum, Field, Input, Struct, Variant};
2use crate::attr::Attrs;
3use syn::{Error, GenericArgument, PathArguments, Result, Type};
4
5impl Input<'_> {
6 pub(crate) fn validate(&self) -> Result<()> {
7 match self {
8 Input::Struct(input) => input.validate(),
9 Input::Enum(input) => input.validate(),
10 }
11 }
12}
13
14impl Struct<'_> {
15 fn validate(&self) -> Result<()> {
16 check_non_field_attrs(&self.attrs)?;
17 if let Some(transparent) = self.attrs.transparent {
18 if self.fields.len() != 1 {
19 return Err(Error::new_spanned(
20 transparent.original,
21 "#[error(transparent)] requires exactly one field",
22 ));
23 }
24 if let Some(source) = self.fields.iter().find_map(|f| f.attrs.source) {
25 return Err(Error::new_spanned(
26 source.original,
27 "transparent error struct can't contain #[source]",
28 ));
29 }
30 }
31 if let Some(fmt) = &self.attrs.fmt {
32 return Err(Error::new_spanned(
33 fmt.original,
34 "#[error(fmt = ...)] is only supported in enums; for a struct, handwrite your own Display impl",
35 ));
36 }
37 check_field_attrs(&self.fields)?;
38 for field in &self.fields {
39 field.validate()?;
40 }
41 Ok(())
42 }
43}
44
45impl Enum<'_> {
46 fn validate(&self) -> Result<()> {
47 check_non_field_attrs(&self.attrs)?;
48 let has_display = self.has_display();
49 for variant in &self.variants {
50 variant.validate()?;
51 if has_display
52 && variant.attrs.display.is_none()
53 && variant.attrs.transparent.is_none()
54 && variant.attrs.fmt.is_none()
55 {
56 return Err(Error::new_spanned(
57 variant.original,
58 "missing #[error(\"...\")] display attribute",
59 ));
60 }
61 }
62 Ok(())
63 }
64}
65
66impl Variant<'_> {
67 fn validate(&self) -> Result<()> {
68 check_non_field_attrs(&self.attrs)?;
69 if self.attrs.transparent.is_some() {
70 if self.fields.len() != 1 {
71 return Err(Error::new_spanned(
72 self.original,
73 "#[error(transparent)] requires exactly one field",
74 ));
75 }
76 if let Some(source) = self.fields.iter().find_map(|f| f.attrs.source) {
77 return Err(Error::new_spanned(
78 source.original,
79 "transparent variant can't contain #[source]",
80 ));
81 }
82 }
83 check_field_attrs(&self.fields)?;
84 for field in &self.fields {
85 field.validate()?;
86 }
87 Ok(())
88 }
89}
90
91impl Field<'_> {
92 fn validate(&self) -> Result<()> {
93 if let Some(unexpected_display_attr) = if let Some(display) = &self.attrs.display {
94 Some(display.original)
95 } else if let Some(fmt) = &self.attrs.fmt {
96 Some(fmt.original)
97 } else {
98 None
99 } {
100 return Err(Error::new_spanned(
101 unexpected_display_attr,
102 "not expected here; the #[error(...)] attribute belongs on top of a struct or an enum variant",
103 ));
104 }
105 Ok(())
106 }
107}
108
109fn check_non_field_attrs(attrs: &Attrs) -> Result<()> {
110 if let Some(from) = &attrs.from {
111 return Err(Error::new_spanned(
112 from.original,
113 "not expected here; the #[from] attribute belongs on a specific field",
114 ));
115 }
116 if let Some(source) = &attrs.source {
117 return Err(Error::new_spanned(
118 source.original,
119 "not expected here; the #[source] attribute belongs on a specific field",
120 ));
121 }
122 if let Some(backtrace) = &attrs.backtrace {
123 return Err(Error::new_spanned(
124 backtrace,
125 "not expected here; the #[backtrace] attribute belongs on a specific field",
126 ));
127 }
128 if attrs.transparent.is_some() {
129 if let Some(display) = &attrs.display {
130 return Err(Error::new_spanned(
131 display.original,
132 "cannot have both #[error(transparent)] and a display attribute",
133 ));
134 }
135 if let Some(fmt) = &attrs.fmt {
136 return Err(Error::new_spanned(
137 fmt.original,
138 "cannot have both #[error(transparent)] and #[error(fmt = ...)]",
139 ));
140 }
141 } else if let (Some(display), Some(_)) = (&attrs.display, &attrs.fmt) {
142 return Err(Error::new_spanned(
143 display.original,
144 "cannot have both #[error(fmt = ...)] and a format arguments attribute",
145 ));
146 }
147
148 Ok(())
149}
150
151fn check_field_attrs(fields: &[Field]) -> Result<()> {
152 let mut from_field = None;
153 let mut source_field = None;
154 let mut location_field: Option<&Field> = None;
155 let mut backtrace_field = None;
156 let mut has_backtrace = false;
157 let mut has_location = false;
158 for field in fields {
159 if let Some(from) = field.attrs.from {
160 if from_field.is_some() {
161 return Err(Error::new_spanned(
162 from.original,
163 "duplicate #[from] attribute",
164 ));
165 }
166 from_field = Some(field);
167 }
168 if let Some(source) = field.attrs.source {
169 if source_field.is_some() {
170 return Err(Error::new_spanned(
171 source.original,
172 "duplicate #[source] attribute",
173 ));
174 }
175 source_field = Some(field);
176 }
177 if let Some(backtrace) = field.attrs.backtrace {
178 if backtrace_field.is_some() {
179 return Err(Error::new_spanned(
180 backtrace,
181 "duplicate #[backtrace] attribute",
182 ));
183 }
184 backtrace_field = Some(field);
185 has_backtrace = true;
186 }
187 if let Some(location) = field.attrs.location {
188 if location_field.is_some() {
189 return Err(Error::new_spanned(
190 location,
191 "duplicate #[location] attribute",
192 ));
193 }
194
195 location_field = Some(field);
196 has_location = true;
197 }
198 if let Some(transparent) = field.attrs.transparent {
199 return Err(Error::new_spanned(
200 transparent.original,
201 "#[error(transparent)] needs to go outside the enum or struct, not on an individual field",
202 ));
203 }
204 has_backtrace |= field.is_backtrace();
205 has_location |= field.is_location();
206 }
207 if let (Some(from_field), Some(source_field)) = (from_field, source_field) {
208 if from_field.member != source_field.member {
209 return Err(Error::new_spanned(
210 from_field.attrs.from.unwrap().original,
211 "#[from] is only supported on the source field, not any other field",
212 ));
213 }
214 }
215 if let Some(from_field) = from_field {
216 let extra_fields = has_backtrace as usize + has_location as usize;
217 let max_expected_fields = match (backtrace_field, location_field) {
218 (Some(backtrace_field), Some(_)) => {
219 2 + (from_field.member != backtrace_field.member) as usize
220 }
221 (Some(backtrace_field), None) => {
222 1 + (from_field.member != backtrace_field.member) as usize
223 }
224 (None, Some(_)) => 1 + extra_fields,
225 (None, None) => 1 + extra_fields,
226 };
227 if fields.len() > max_expected_fields {
228 return Err(Error::new_spanned(
229 from_field.attrs.from.unwrap().original,
230 "deriving From requires no fields other than source, backtrace, and location",
231 ));
232 }
233 }
234 if let Some(source_field) = source_field.or(from_field) {
235 if contains_non_static_lifetime(source_field.ty) {
236 return Err(Error::new_spanned(
237 &source_field.original.ty,
238 "non-static lifetimes are not allowed in the source of an error, because std::error::Error requires the source is dyn Error + 'static",
239 ));
240 }
241 }
242 Ok(())
243}
244
245fn contains_non_static_lifetime(ty: &Type) -> bool {
246 match ty {
247 Type::Path(ty) => {
248 let bracketed = match &ty.path.segments.last().unwrap().arguments {
249 PathArguments::AngleBracketed(bracketed) => bracketed,
250 _ => return false,
251 };
252 for arg in &bracketed.args {
253 match arg {
254 GenericArgument::Type(ty) if contains_non_static_lifetime(ty) => return true,
255 GenericArgument::Lifetime(lifetime) if lifetime.ident != "static" => {
256 return true
257 }
258 _ => {}
259 }
260 }
261 false
262 }
263 Type::Reference(ty) => ty
264 .lifetime
265 .as_ref()
266 .map_or(false, |lifetime| lifetime.ident != "static"),
267 _ => false, }
269}