1#![doc(html_root_url = "https://docs.rs/sea-bae/0.2.0")]
95#![allow(clippy::let_and_return)]
96#![deny(
97 unused_variables,
98 dead_code,
99 unused_must_use,
100 unused_imports,
101 missing_docs
102)]
103
104extern crate proc_macro;
105
106use heck::ToSnakeCase;
107use proc_macro2::TokenStream;
108use proc_macro_error2::{abort, proc_macro_error};
109use quote::*;
110use syn::{spanned::Spanned, *};
111
112#[proc_macro_derive(FromAttributes, attributes())]
114#[proc_macro_error]
115pub fn from_attributes(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
116 let item = parse_macro_input!(input as ItemStruct);
117 FromAttributes::new(item).expand().into()
118}
119
120#[derive(Debug)]
121struct FromAttributes {
122 item: ItemStruct,
123 tokens: TokenStream,
124}
125
126impl FromAttributes {
127 fn new(item: ItemStruct) -> Self {
128 Self {
129 item,
130 tokens: TokenStream::new(),
131 }
132 }
133
134 fn expand(mut self) -> TokenStream {
135 self.expand_from_attributes_method();
136 self.expand_parse_impl();
137
138 if std::env::var("BAE_DEBUG").is_ok() {
139 eprintln!("{}", self.tokens);
140 }
141
142 self.tokens
143 }
144
145 fn struct_name(&self) -> &Ident {
146 &self.item.ident
147 }
148
149 fn attr_name(&self) -> LitStr {
150 let struct_name = self.struct_name();
151 let name = struct_name.to_string().to_snake_case();
152 LitStr::new(&name, struct_name.span())
153 }
154
155 fn expand_from_attributes_method(&mut self) {
156 let struct_name = self.struct_name();
157 let attr_name = self.attr_name();
158
159 let code = quote! {
160 impl #struct_name {
161 pub fn try_from_attributes(attrs: &[syn::Attribute]) -> syn::Result<Option<Self>> {
162 use syn::spanned::Spanned;
163
164 for attr in attrs {
165 if attr.path().is_ident(#attr_name) {
166 return Some(attr.parse_args::<Self>()).transpose()
167 }
168 }
169
170 Ok(None)
171 }
172
173 pub fn from_attributes(attrs: &[syn::Attribute]) -> syn::Result<Self> {
174 if let Some(attr) = Self::try_from_attributes(attrs)? {
175 Ok(attr)
176 } else {
177 Err(syn::Error::new(
178 proc_macro2::Span::call_site(),
179 &format!("missing attribute `#[{}]`", #attr_name),
180 ))
181 }
182 }
183 }
184 };
185 self.tokens.extend(code);
186 }
187
188 fn expand_parse_impl(&mut self) {
189 let struct_name = self.struct_name();
190 let attr_name = self.attr_name();
191
192 let variable_declarations = self.item.fields.iter().map(|field| {
193 let name = &field.ident;
194 quote! { let mut #name = std::option::Option::None; }
195 });
196
197 let match_arms = self.item.fields.iter().map(|field| {
198 let field_name = get_field_name(field);
199 let pattern = LitStr::new(&field_name.to_string(), field.span());
200
201 if field_is_switch(field) {
202 quote! {
203 #pattern => {
204 #field_name = std::option::Option::Some(());
205 }
206 }
207 } else {
208 quote! {
209 #pattern => {
210 input.parse::<syn::Token![=]>()?;
211 #field_name = std::option::Option::Some(input.parse()?);
212 }
213 }
214 }
215 });
216
217 let unwrap_mandatory_fields = self
218 .item
219 .fields
220 .iter()
221 .filter(|field| !field_is_optional(field))
222 .map(|field| {
223 let field_name = get_field_name(field);
224 let arg_name = LitStr::new(&field_name.to_string(), field.span());
225
226 quote! {
227 let #field_name = if let std::option::Option::Some(#field_name) = #field_name {
228 #field_name
229 } else {
230 return syn::Result::Err(
231 input.error(
232 &format!("`#[{}]` is missing `{}` argument", #attr_name, #arg_name),
233 )
234 );
235 };
236 }
237 });
238
239 let set_fields = self.item.fields.iter().map(|field| {
240 let field_name = get_field_name(field);
241 quote! { #field_name, }
242 });
243
244 let mut supported_args = self
245 .item
246 .fields
247 .iter()
248 .map(|field| get_field_name(field))
249 .map(|field_name| format!("`{}`", field_name))
250 .collect::<Vec<_>>();
251 supported_args.sort_unstable();
252 let supported_args = supported_args.join(", ");
253
254 let code = quote! {
255 impl syn::parse::Parse for #struct_name {
256 #[allow(unreachable_code, unused_imports, unused_variables)]
257 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
258 #(#variable_declarations)*
259
260 while !input.is_empty() {
261 let bae_attr_ident = input.parse::<syn::Ident>()?;
262
263 match &*bae_attr_ident.to_string() {
264 #(#match_arms)*
265 other => {
266 return syn::Result::Err(
267 syn::Error::new(
268 bae_attr_ident.span(),
269 &format!(
270 "`#[{}]` got unknown `{}` argument. Supported arguments are {}",
271 #attr_name,
272 other,
273 #supported_args,
274 ),
275 )
276 );
277 }
278 }
279
280 input.parse::<syn::Token![,]>().ok();
281 }
282
283 #(#unwrap_mandatory_fields)*
284
285 syn::Result::Ok(Self { #(#set_fields)* })
286 }
287 }
288 };
289 self.tokens.extend(code);
290 }
291}
292
293fn get_field_name(field: &Field) -> &Ident {
294 field
295 .ident
296 .as_ref()
297 .unwrap_or_else(|| abort!(field.span(), "Field without a name"))
298}
299
300fn field_is_optional(field: &Field) -> bool {
301 let type_path = if let Type::Path(type_path) = &field.ty {
302 type_path
303 } else {
304 return false;
305 };
306
307 let ident = &type_path
308 .path
309 .segments
310 .last()
311 .unwrap_or_else(|| abort!(field.span(), "Empty type path"))
312 .ident;
313
314 ident == "Option"
315}
316
317fn field_is_switch(field: &Field) -> bool {
318 let unit_type = syn::parse_str::<Type>("()").unwrap();
319 inner_type(&field.ty) == Some(&unit_type)
320}
321
322fn inner_type(ty: &Type) -> Option<&Type> {
323 let type_path = if let Type::Path(type_path) = ty {
324 type_path
325 } else {
326 return None;
327 };
328
329 let ty_args = &type_path
330 .path
331 .segments
332 .last()
333 .unwrap_or_else(|| abort!(ty.span(), "Empty type path"))
334 .arguments;
335
336 let ty_args = if let PathArguments::AngleBracketed(ty_args) = ty_args {
337 ty_args
338 } else {
339 return None;
340 };
341
342 let generic_arg = &ty_args
343 .args
344 .last()
345 .unwrap_or_else(|| abort!(ty_args.span(), "Empty generic argument"));
346
347 let ty = if let GenericArgument::Type(ty) = generic_arg {
348 ty
349 } else {
350 return None;
351 };
352
353 Some(ty)
354}
355
#[cfg(test)]
mod test {
    #[allow(unused_imports)]
    use super::*;

    /// UI tests driven by `trybuild`.
    ///
    /// Everything under `tests/compile_pass` must compile, and everything
    /// under `tests/compile_fail` must fail to compile.
    #[test]
    fn test_ui() {
        let cases = trybuild::TestCases::new();
        cases.pass("tests/compile_pass/*.rs");
        cases.compile_fail("tests/compile_fail/*.rs");
    }
}