silent_db_macros/
lib.rs

1#![allow(dead_code)]
2
3use std::fmt::Display;
4
5use darling::{ast, FromDeriveInput, FromField, FromMeta};
6use proc_macro2::TokenStream;
7use quote::{quote, ToTokens};
8use syn::parse_macro_input;
9
10use crate::utils::{to_camel_case, to_snake_case};
11
12mod utils;
13
14#[proc_macro_derive(Table, attributes(table, field))]
15pub fn derive_table(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
16    let input = parse_macro_input!(item as syn::DeriveInput);
17    let table_attr = TableAttr::from_derive_input(&input).unwrap();
18    let mut tokens = TokenStream::new();
19    table_attr.to_tokens(&mut tokens);
20    tokens.into()
21}
22
23#[derive(Debug, FromDeriveInput)]
24#[darling(
25    attributes(table),
26    supports(struct_any),
27    forward_attrs(allow, doc, cfg)
28)]
29struct TableAttr {
30    ident: syn::Ident,
31    generics: syn::Generics,
32    data: ast::Data<(), FieldAttr>,
33    name: Option<String>,
34    comment: Option<String>,
35    #[darling(multiple)]
36    index: Vec<IndexAttr>,
37}
38impl ToTokens for TableAttr {
39    fn to_tokens(&self, tokens: &mut TokenStream) {
40        let TableAttr {
41            ref ident,
42            ref data,
43            name,
44            comment,
45            index,
46            ..
47        } = self;
48        let indices = index;
49        let struct_name = ident;
50
51        let comment = match comment {
52            Some(c) => quote! { Some(#c.to_string()) },
53            None => {
54                quote! { None }
55            }
56        };
57
58        let name = match name {
59            Some(c) => quote! { #c.to_string() },
60            None => {
61                let table_name = to_snake_case(&struct_name.to_string());
62                quote! { #table_name.to_string() }
63            }
64        };
65        let fields = data
66            .as_ref()
67            .take_struct()
68            .expect("Should never be enum")
69            .fields;
70
71        let mut fields_data: Vec<FieldToken> = vec![];
72        for field in fields {
73            fields_data.push(derive_field_attribute(
74                field,
75                field.ident.as_ref().unwrap().to_string(),
76            ));
77        }
78        let field_name_list = fields_data
79            .iter()
80            .map(|f| f.name.clone())
81            .collect::<Vec<String>>();
82        for field in fields_data {
83            tokens.extend(field.token_stream);
84        }
85
86        let fields_code = format!(
87            "vec![{}]",
88            field_name_list
89                .iter()
90                .map(|field| format!("{}::new().rc()", to_camel_case(field)))
91                .collect::<Vec<String>>()
92                .join(", ")
93        );
94
95        let indices_code = format!(
96            "vec![{}]",
97            indices
98                .iter()
99                .map(|index| {
100                    if !index.check_fields(&field_name_list) {
101                        panic!("Index fields is empty");
102                    }
103                    format!("Rc::new({})", index)
104                })
105                .collect::<Vec<String>>()
106                .join(", ")
107        );
108        let fields_token: TokenStream = fields_code.parse().unwrap();
109        let indices_token: TokenStream = indices_code.parse().unwrap();
110
111        // Generate the code for implementing the trait
112        let expanded = quote! {
113            impl TableManage for #struct_name {
114                fn manager() -> Box<dyn Table> {
115                    Box::new(TableManager {
116                        name: #name,
117                        fields: #fields_token,
118                        indices: #indices_token,
119                        comment: #comment,
120                    })
121                }
122            }
123        };
124
125        tokens.extend(expanded);
126    }
127}
128
129#[derive(Debug, FromMeta)]
130struct IndexAttr {
131    alias: Option<String>,
132    index_type: String,
133    fields: String,
134    sort: Option<String>,
135}
136
137impl IndexAttr {
138    fn get_index_type(&self) -> String {
139        match self.index_type.as_str() {
140            "unique" => "IndexType::Unique".to_string(),
141            "fulltext" => "IndexType::FullText".to_string(),
142            "spatial" => "IndexType::Spatial".to_string(),
143            _ => "IndexType::Index".to_string(),
144        }
145    }
146
147    fn get_sort(&self) -> String {
148        if self.sort == Some("desc".to_string()) {
149            "IndexSort::DESC".to_string()
150        } else {
151            self.sort.clone().unwrap_or("IndexSort::ASC".to_string())
152        }
153    }
154    pub(crate) fn check_fields(&self, fields: &[String]) -> bool {
155        let index_fields = self
156            .fields
157            .clone()
158            .split(',')
159            .map(|s| s.to_string())
160            .collect::<Vec<String>>();
161        index_fields.iter().all(|f| fields.contains(f))
162    }
163}
164
165impl Display for IndexAttr {
166    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167        let alias = match &self.alias {
168            Some(alias) => format!("Some(\"{}\".to_string())", alias),
169            None => "None".to_string(),
170        };
171        let fields = self
172            .fields
173            .split(',')
174            .map(|f| format!("\"{}\".to_string()", f))
175            .collect::<Vec<String>>()
176            .join(",");
177
178        write!(
179            f,
180            "Index {{alias:{}, index_type: {}, fields: vec![{}],sort: {}}}",
181            alias,
182            self.get_index_type(),
183            fields,
184            self.get_sort()
185        )
186    }
187}
188
/// Output of `derive_field_attribute` for a single struct field.
#[derive(Debug)]
struct FieldToken {
    /// Name reported for the field: the explicit `#[field(name = ...)]` value if given,
    /// otherwise the snake_case Rust field name. Index definitions are validated against it.
    name: String,
    /// Generated wrapper type for the field, plus its `Query` impl and constructor impls.
    token_stream: TokenStream,
}
194
/// Options parsed from a struct field's `#[field(...)]` attribute.
///
/// Every option except `field_type` is optional. Unset options are not written into the
/// generated field descriptor, which then takes them from its `Default` value.
#[derive(Debug, FromField)]
#[darling(attributes(field))]
struct FieldAttr {
    /// Rust identifier of the field; `None` for tuple-struct fields.
    ident: Option<syn::Ident>,
    /// Declared Rust type of the field.
    ty: syn::Type,
    /// Descriptor type name, pasted into the generated code as the wrapped type and used
    /// as the struct-literal path in `new()`.
    field_type: String,
    /// Overrides the field name; defaults to the snake_case Rust field name.
    name: Option<String>,
    /// Default value, emitted as `default: Some(<value>.to_string())`.
    default: Option<String>,
    /// Copied to the descriptor's `nullable` field.
    nullable: Option<bool>,
    /// Copied to the descriptor's `primary_key` field.
    primary_key: Option<bool>,
    /// Copied to the descriptor's `auto_increment` field.
    auto_increment: Option<bool>,
    /// Copied to the descriptor's `unique` field.
    unique: Option<bool>,
    /// Field comment, emitted as `comment: Some(<value>.to_string())`.
    comment: Option<String>,
    /// Copied to the descriptor's `max_digits` field.
    max_digits: Option<u8>,
    /// Copied to the descriptor's `decimal_places` field.
    decimal_places: Option<u8>,
    /// Emitted as the descriptor's `length` field (note the different name).
    max_length: Option<u16>,
}
212
213fn derive_field_attribute(field_attr: &FieldAttr, field_name: String) -> FieldToken {
214    let FieldAttr {
215        field_type,
216        name,
217        default,
218        nullable,
219        primary_key,
220        auto_increment,
221        unique,
222        comment,
223        max_digits,
224        decimal_places,
225        max_length,
226        ..
227    } = field_attr;
228
229    let snake_field_name = to_snake_case(&field_name.to_string());
230    let camel_field_name = to_camel_case(&field_name.to_string());
231
232    let args = quote! {
233        name: Self::get_field(),
234    };
235    // 设置字段名称
236    let args = match name.clone() {
237        Some(c) => quote! { name: #c.to_string(), },
238        None => quote! { #args },
239    };
240    // 设置字段默认值
241    let args = match default {
242        Some(c) => quote! { #args
243        default: Some(#c.to_string()), },
244        None => quote! { #args },
245    };
246    // 设置字段是否为空
247    let args = match nullable {
248        Some(c) => quote! { #args
249        nullable: #c, },
250        None => quote! { #args },
251    };
252    // 设置字段是否为主键
253    let args = match primary_key {
254        Some(c) => quote! { #args
255        primary_key: #c, },
256        None => quote! { #args },
257    };
258    // 设置字段是否唯一
259    let args = match unique {
260        Some(c) => quote! { #args
261        unique: #c, },
262        None => quote! { #args },
263    };
264    // 设置字段注释
265    let args = match comment {
266        Some(c) => quote! { #args
267        comment: Some(#c.to_string()), },
268        None => quote! { #args },
269    };
270    // 设置字段是否自增
271    let args = match auto_increment {
272        Some(c) => quote! { #args
273        auto_increment: #c, },
274        None => quote! { #args },
275    };
276    // 设置字段最大位数
277    let args = match max_digits {
278        Some(c) => quote! { #args
279        max_digits: #c, },
280        None => quote! { #args },
281    };
282    // 设置字段小数位数
283    let args = match decimal_places {
284        Some(c) => quote! { #args
285        decimal_places: #c, },
286        None => quote! { #args },
287    };
288    // 设置字段长度
289    let args = match max_length {
290        Some(c) => quote! { #args
291        length: #c, },
292        None => quote! { #args },
293    };
294
295    let code = format!(
296        r#"
297    pub struct {camel_field_name}({field_type});
298    
299    impl Query for {camel_field_name} {{
300        fn get_field() -> String {{
301            "{snake_field_name}".to_string()
302        }}
303    }}
304    
305    impl {camel_field_name} {{
306        pub fn new() -> Self {{
307            {camel_field_name}({field_type} {{
308                {args}
309                ..Default::default()
310            }})
311        }}
312        pub fn rc(&self) -> Rc<{field_type}> {{
313            Rc::new(self.0.clone())
314        }}
315    }}
316    "#,
317        camel_field_name = camel_field_name,
318        field_type = field_type,
319        snake_field_name = snake_field_name,
320        args = args
321    );
322    FieldToken {
323        name: name.clone().unwrap_or(snake_field_name),
324        token_stream: code.parse().unwrap(),
325    }
326}