// with_id_derive/lib.rs

1use proc_macro::TokenStream;
2use std::ops::Deref;
3use quote::quote;
4use syn::{spanned::Spanned, DeriveInput, Data, DataStruct, Field, parse_macro_input};
5extern crate syn;
6
7
8fn get_by_attr(data_struct: &DataStruct) -> Option<&Field> {
9    data_struct.fields.iter().find(|field| {
10        field.attrs.iter().any(|attr| attr.path().is_ident("id"))
11    })
12}
13
14
15fn get_by_name(data_struct: &DataStruct) -> Option<&Field> {
16    data_struct.fields.iter().find(|field| {
17        field.ident.as_ref().map_or(false, |ident| ident == "id")
18    })
19}
20
21
22
23fn get_id_field(ast: &DeriveInput)->Result<&Field,syn::Error>{
24
25    let id_field =
26        match ast.data {
27            Data::Struct(ref data_struct) => get_by_attr(data_struct),
28            _ => Err(syn::Error::new(ast.span(),
29                "WithId can only be derived for structs",
30            ))?
31        };
32
33
34    // If no #[id] attribute is present, try to find a field named "id"
35    let id_field =match id_field {
36        Some(field) => Some(field),
37        None => {
38            match ast.data {
39                Data::Struct(ref data_struct) => get_by_name(data_struct),
40                _ => Err(syn::Error::new(
41                    ast.span(),
42                    "WithId can only be derived for structs",
43                ))?
44            }
45        }
46    };
47    // If no "id" field is present, return an error
48    return match id_field {
49        Some(field) => Ok(field),
50        None => {
51            Err( syn::Error::new(
52                ast.span(),
53                "struct must have a field marked with #[id] attribute or named 'id'",
54            ))?
55        }
56    };
57}
58
59
60#[proc_macro_derive(WithStringId, attributes(id))]
61pub fn with_string_id_derive(input: TokenStream) -> TokenStream {
62    // Parse the input tokens into a syntax tree
63    let ast = parse_macro_input!(input as DeriveInput);
64
65    let name = &ast.ident;
66    // Try to find the field marked with the #[id] attribute
67    let id_field =  match get_id_field(&ast) {
68        Ok(field) => field,
69        Err(err) => return err.to_compile_error().into()
70    };
71
72
73    let id_field_name = id_field.ident.as_ref().unwrap();
74
75    let lifetimes= ast.generics.lifetimes();
76    let lifetimes_count = ast.generics.lifetimes().count();
77    let lifetime_params = if lifetimes_count == 0 {
78        quote!{}
79    } else {
80        quote! { <#(#lifetimes),*> }
81    };
82
83
84    // Generate the implementation for the trait
85    let gen =
86        quote! {
87                    impl#lifetime_params WithStringId for #name#lifetime_params {
88                        fn id(&self) -> String {
89                            self.#id_field_name.to_string()
90                        }
91                    }
92        };
93    // Return the generated implementation
94    gen.into()
95}
96
97
98#[proc_macro_derive(WithId, attributes(id))]
99pub fn with_id_derive(input: TokenStream) -> TokenStream {
100    // Parse the input tokens into a syntax tree
101    let ast = parse_macro_input!(input as DeriveInput);
102
103    let name = &ast.ident;
104    // Try to find the field marked with the #[id] attribute
105    let id_field =  match get_id_field(&ast) {
106        Ok(field) => field,
107        Err(err) => return err.to_compile_error().into()
108    };
109
110    let id_field_name = id_field.ident.as_ref().unwrap();
111    let id_field_ty = &id_field.ty;
112    let lifetimes= ast.generics.lifetimes();
113    let lifetimes_count = ast.generics.lifetimes().count();
114    let lifetime_params = if lifetimes_count == 0 {
115        quote!{}
116    } else {
117        quote! { <#(#lifetimes),*> }
118    };
119    // Generate the implementation for the trait
120    let gen =
121        quote! {
122                    impl#lifetime_params WithId<#id_field_ty> for #name#lifetime_params {
123                        fn id(&self) -> #id_field_ty {
124                            self.#id_field_name.clone()
125                        }
126                    }
127        };
128    // Return the generated implementation
129    gen.into()
130}
131
132
133#[proc_macro_derive(WithRefId, attributes(id))]
134pub fn with_ref_id_derive(input: TokenStream) -> TokenStream {
135    // Parse the input tokens into a syntax tree
136    let ast = parse_macro_input!(input as DeriveInput);
137
138    let name = &ast.ident;
139    // Try to find the field marked with the #[id] attribute
140    let id_field =  match get_id_field(&ast) {
141        Ok(field) => field,
142        Err(err) => return err.to_compile_error().into()
143    };
144
145    let id_field_name = id_field.ident.as_ref().unwrap();
146    let id_field_ty = &id_field.ty;
147
148    let lifetimes= ast.generics.lifetimes();
149    let lifetimes_count = ast.generics.lifetimes().count();
150    let lifetime_params = if lifetimes_count == 0 {
151        quote!{}
152    } else {
153        quote! { <#(#lifetimes),*> }
154    };
155
156
157    let gen = if let syn::Type::Path(type_path) = id_field_ty {
158        if let Some(segment) = type_path.path.segments.first() {
159            if segment.ident == "String" {
160                quote! {
161                    impl#lifetime_params WithRefId<str> for #name#lifetime_params {
162                        fn id(&self) -> &str {
163                            self.#id_field_name.as_str()
164                        }
165                    }
166                }
167            }else{
168                quote! {
169                    impl#lifetime_params WithRefId<#id_field_ty> for #name#lifetime_params {
170                        fn id(&self) -> &#id_field_ty {
171                            &self.#id_field_name
172                        }
173                    }
174                }
175            }
176        }else{
177            return syn::Error::new(id_field_ty.span(), "unexpected error: id field has an empty path").to_compile_error().into();
178        }
179    }else if let syn::Type::Reference(type_reference) = id_field_ty  {
180        let ref_type = type_reference.elem.deref();
181            quote! {
182                    impl#lifetime_params WithRefId<#ref_type> for #name#lifetime_params {
183                        fn id(&self) -> &#ref_type {
184                            self.#id_field_name
185                        }
186                    }
187                }
188    }else{
189        return syn::Error::new(id_field_ty.span(), "unexpected error: id field is not a path or reference type").to_compile_error().into();
190    };
191
192    // Return the generated implementation
193    gen.into()
194}