1use proc_macro::TokenStream;
2use std::ops::Deref;
3use quote::quote;
4use syn::{spanned::Spanned, DeriveInput, Data, DataStruct, Field, parse_macro_input};
5extern crate syn;
6
7
8fn get_by_attr(data_struct: &DataStruct) -> Option<&Field> {
9 data_struct.fields.iter().find(|field| {
10 field.attrs.iter().any(|attr| attr.path().is_ident("id"))
11 })
12}
13
14
15fn get_by_name(data_struct: &DataStruct) -> Option<&Field> {
16 data_struct.fields.iter().find(|field| {
17 field.ident.as_ref().map_or(false, |ident| ident == "id")
18 })
19}
20
21
22
23fn get_id_field(ast: &DeriveInput)->Result<&Field,syn::Error>{
24
25 let id_field =
26 match ast.data {
27 Data::Struct(ref data_struct) => get_by_attr(data_struct),
28 _ => Err(syn::Error::new(ast.span(),
29 "WithId can only be derived for structs",
30 ))?
31 };
32
33
34 let id_field =match id_field {
36 Some(field) => Some(field),
37 None => {
38 match ast.data {
39 Data::Struct(ref data_struct) => get_by_name(data_struct),
40 _ => Err(syn::Error::new(
41 ast.span(),
42 "WithId can only be derived for structs",
43 ))?
44 }
45 }
46 };
47 return match id_field {
49 Some(field) => Ok(field),
50 None => {
51 Err( syn::Error::new(
52 ast.span(),
53 "struct must have a field marked with #[id] attribute or named 'id'",
54 ))?
55 }
56 };
57}
58
59
60#[proc_macro_derive(WithStringId, attributes(id))]
61pub fn with_string_id_derive(input: TokenStream) -> TokenStream {
62 let ast = parse_macro_input!(input as DeriveInput);
64
65 let name = &ast.ident;
66 let id_field = match get_id_field(&ast) {
68 Ok(field) => field,
69 Err(err) => return err.to_compile_error().into()
70 };
71
72
73 let id_field_name = id_field.ident.as_ref().unwrap();
74
75 let lifetimes= ast.generics.lifetimes();
76 let lifetimes_count = ast.generics.lifetimes().count();
77 let lifetime_params = if lifetimes_count == 0 {
78 quote!{}
79 } else {
80 quote! { <#(#lifetimes),*> }
81 };
82
83
84 let gen =
86 quote! {
87 impl#lifetime_params WithStringId for #name#lifetime_params {
88 fn id(&self) -> String {
89 self.#id_field_name.to_string()
90 }
91 }
92 };
93 gen.into()
95}
96
97
98#[proc_macro_derive(WithId, attributes(id))]
99pub fn with_id_derive(input: TokenStream) -> TokenStream {
100 let ast = parse_macro_input!(input as DeriveInput);
102
103 let name = &ast.ident;
104 let id_field = match get_id_field(&ast) {
106 Ok(field) => field,
107 Err(err) => return err.to_compile_error().into()
108 };
109
110 let id_field_name = id_field.ident.as_ref().unwrap();
111 let id_field_ty = &id_field.ty;
112 let lifetimes= ast.generics.lifetimes();
113 let lifetimes_count = ast.generics.lifetimes().count();
114 let lifetime_params = if lifetimes_count == 0 {
115 quote!{}
116 } else {
117 quote! { <#(#lifetimes),*> }
118 };
119 let gen =
121 quote! {
122 impl#lifetime_params WithId<#id_field_ty> for #name#lifetime_params {
123 fn id(&self) -> #id_field_ty {
124 self.#id_field_name.clone()
125 }
126 }
127 };
128 gen.into()
130}
131
132
133#[proc_macro_derive(WithRefId, attributes(id))]
134pub fn with_ref_id_derive(input: TokenStream) -> TokenStream {
135 let ast = parse_macro_input!(input as DeriveInput);
137
138 let name = &ast.ident;
139 let id_field = match get_id_field(&ast) {
141 Ok(field) => field,
142 Err(err) => return err.to_compile_error().into()
143 };
144
145 let id_field_name = id_field.ident.as_ref().unwrap();
146 let id_field_ty = &id_field.ty;
147
148 let lifetimes= ast.generics.lifetimes();
149 let lifetimes_count = ast.generics.lifetimes().count();
150 let lifetime_params = if lifetimes_count == 0 {
151 quote!{}
152 } else {
153 quote! { <#(#lifetimes),*> }
154 };
155
156
157 let gen = if let syn::Type::Path(type_path) = id_field_ty {
158 if let Some(segment) = type_path.path.segments.first() {
159 if segment.ident == "String" {
160 quote! {
161 impl#lifetime_params WithRefId<str> for #name#lifetime_params {
162 fn id(&self) -> &str {
163 self.#id_field_name.as_str()
164 }
165 }
166 }
167 }else{
168 quote! {
169 impl#lifetime_params WithRefId<#id_field_ty> for #name#lifetime_params {
170 fn id(&self) -> &#id_field_ty {
171 &self.#id_field_name
172 }
173 }
174 }
175 }
176 }else{
177 return syn::Error::new(id_field_ty.span(), "unexpected error: id field has an empty path").to_compile_error().into();
178 }
179 }else if let syn::Type::Reference(type_reference) = id_field_ty {
180 let ref_type = type_reference.elem.deref();
181 quote! {
182 impl#lifetime_params WithRefId<#ref_type> for #name#lifetime_params {
183 fn id(&self) -> &#ref_type {
184 self.#id_field_name
185 }
186 }
187 }
188 }else{
189 return syn::Error::new(id_field_ty.span(), "unexpected error: id field is not a path or reference type").to_compile_error().into();
190 };
191
192 gen.into()
194}