sqlx_model_macros/
lib.rs

1use heck::{
2    ToKebabCase, ToLowerCamelCase, ToPascalCase, ToShoutySnakeCase, ToSnakeCase, ToUpperCamelCase,
3};
4use proc_macro::TokenStream;
5use quote::quote;
6use syn::{
7    parse_macro_input, Attribute, Data, DataEnum, DataStruct, DeriveInput, Fields, Meta, NestedMeta,
8};
9
10fn get_sqlx_field_rename(attrs: &[Attribute]) -> Option<String> {
11    for attr in attrs.iter() {
12        let meta = attr
13            .parse_meta()
14            .map_err(|e| syn::Error::new_spanned(attr, e))
15            .unwrap();
16        if let Meta::List(list) = meta {
17            for cattr in list.nested.iter() {
18                if let NestedMeta::Meta(Meta::NameValue(ref attr_ident)) = cattr {
19                    let name = attr_ident.clone();
20                    let name = name.path.get_ident().unwrap().to_string();
21                    let name = name.as_str();
22                    let ident = attr_ident.clone();
23                    if name == "rename" {
24                        let rename = match ident.lit {
25                            syn::Lit::Str(val) => val,
26                            _ => unreachable!("rename be string"),
27                        }
28                        .value();
29                        return Some(rename);
30                    }
31                }
32            }
33        }
34    }
35    None
36}
37fn change_sqlx_field_rename(change_type: &Option<String>, field_name: String) -> String {
38    if let Some(str) = change_type {
39        match str.as_str() {
40            "lowercase" => {
41                return field_name.to_lowercase();
42            }
43            "snake_case" => {
44                return field_name.to_snake_case();
45            }
46            "UPPERCASE" => {
47                return field_name.to_uppercase();
48            }
49            "SCREAMING_SNAKE_CASE" => {
50                return field_name.to_shouty_snake_case();
51            }
52            "kebab-case" => {
53                return field_name.to_kebab_case();
54            }
55            "camelCase" => {
56                return field_name.to_lower_camel_case();
57            }
58            "UpperCamelCase" => {
59                return field_name.to_upper_camel_case();
60            }
61            "PascalCase" => {
62                return field_name.to_pascal_case();
63            }
64            _ => {}
65        }
66    }
67    field_name
68}
69
70#[proc_macro_attribute]
71// model 自动填充方法
72pub fn sqlx_model(args: TokenStream, item: TokenStream) -> TokenStream {
73    let input = parse_macro_input!(item as DeriveInput);
74    let struct_name = &input.ident;
75
76    let mut db_type = None;
77    let mut table_name = None;
78    let mut rename_all = None;
79    let mut table_pk = vec![];
80    let args = syn::parse_macro_input!(args as syn::AttributeArgs);
81    // for attr in args
82    //     .iter()
83    //     .filter(|e| e.path.is_ident("sqlx_model") || e.path.is_ident("sqlx"))
84    // {
85    //     let meta = attr
86    //         .parse_meta()
87    //         .map_err(|e| syn::Error::new_spanned(attr, e))
88    //         .unwrap();
89    //     if let Meta::List(list) = meta {
90    for cattr in args.iter() {
91        if let NestedMeta::Meta(Meta::NameValue(ref attr_ident)) = cattr {
92            let name = attr_ident.clone();
93            let name = name.path.get_ident().unwrap().to_string();
94            let name = name.as_str();
95            let ident = attr_ident.clone();
96            match name {
97                "db_type" => {
98                    let val = match ident.lit {
99                        syn::Lit::Str(val) => val,
100                        _ => unreachable!("table name must be string"),
101                    }
102                    .value();
103                    db_type = Some(val);
104                }
105                "table_name" => {
106                    let val = match ident.lit {
107                        syn::Lit::Str(val) => val,
108                        _ => unreachable!("table name must be string"),
109                    }
110                    .value();
111                    table_name = Some(val);
112                }
113                "table_pk" => {
114                    let val = match ident.lit {
115                        syn::Lit::Str(val) => val,
116                        _ => unreachable!("table pk field must be string"),
117                    }
118                    .value();
119                    table_pk.push(val);
120                }
121                "rename_all" => {
122                    if let syn::Lit::Str(val) = ident.lit {
123                        let str = &*val.value();
124                        rename_all = Some(str.to_owned());
125                    }
126                }
127                _ => {}
128            }
129        }
130    }
131    //     }
132    // }
133    let db_type = quote::format_ident!("{}", db_type.expect("database type not set"));
134    let table_name = table_name.unwrap_or_else(|| {
135        let mut name = struct_name.to_string();
136        if name.clone().drain(0..5).collect::<String>() == "Model" {
137            name = name.drain(5..).collect::<String>();
138        }
139        if name.clone().drain(name.len() - 5..).collect::<String>() == "Model" {
140            name = name.drain(0..name.len() - 5).collect::<String>();
141        }
142        name.chars()
143            .enumerate()
144            .map(|(i, e)| {
145                if i != 0 && e as u8 >= 65 && e as u8 <= 90 {
146                    format!("_{}", e.to_ascii_lowercase())
147                } else {
148                    e.to_ascii_lowercase().to_string()
149                }
150            })
151            .collect::<Vec<String>>()
152            .join("")
153    });
154    let expanded = match &input.data {
155        Data::Struct(DataStruct { ref fields, .. }) => {
156            if let Fields::Named(ref fields_name) = fields {
157                let change_fields: Vec<_> = fields_name
158                    .named
159                    .iter()
160                    .map(|field| {
161                        let field_name = field.ident.as_ref().unwrap();
162                        let str_field_name = match get_sqlx_field_rename(&field.attrs) {
163                            Some(str) => str,
164                            _ => change_sqlx_field_rename(&rename_all, field_name.to_string()),
165                        };
166                        let field_type = field.ty.clone();
167                        quote! {
168                            #field_name[#str_field_name]:#field_type
169                        }
170                    })
171                    .collect();
172                let bind_fields: Vec<_> = fields_name
173                    .named
174                    .iter()
175                    .map(|field| {
176                        let field_name = field.ident.as_ref().unwrap();
177                        let str_field_name = match get_sqlx_field_rename(&field.attrs) {
178                            Some(str) => str,
179                            _ => change_sqlx_field_rename(&rename_all, field_name.to_string()),
180                        };
181                        quote! {
182                            #field_name[#str_field_name]
183                        }
184                    })
185                    .collect();
186                let change_struct = quote::format_ident!("{}Ref", struct_name);
187                let mut pk_fields = vec![];
188                for field in fields_name.named.iter() {
189                    let field_name = field.ident.as_ref().unwrap();
190                    if table_pk.contains(&field_name.to_string()) {
191                        let str_field_name = match get_sqlx_field_rename(&field.attrs) {
192                            Some(str) => str,
193                            _ => change_sqlx_field_rename(&rename_all, field_name.to_string()),
194                        };
195                        pk_fields.push(quote! {
196                            #field_name[#str_field_name]
197                        });
198                    }
199                }
200                if pk_fields.is_empty() {
201                    if let Some(field) = fields_name.named.iter().next() {
202                        let field_name = field.ident.as_ref().unwrap();
203                        let str_field_name = match get_sqlx_field_rename(&field.attrs) {
204                            Some(str) => str,
205                            _ => change_sqlx_field_rename(&rename_all, field_name.to_string()),
206                        };
207                        pk_fields.push(quote! {
208                            #field_name[#str_field_name]
209                        });
210                    }
211                }
212                let implemented_show = quote! {
213                    #input
214                    sqlx_model::model_table_value_bind_define!(sqlx::#db_type,#struct_name,#table_name,{#(#bind_fields),*},{#(#pk_fields),*});
215                    sqlx_model::model_table_ref_define!(sqlx::#db_type,#struct_name,#change_struct,{#(#change_fields),*});
216                };
217                implemented_show
218            } else {
219                panic!("sorry, may it's a complicated struct.");
220            }
221        }
222        _ => panic!("sorry, Show is not implemented for union or enum type."),
223    };
224    expanded.into()
225}
226
227#[proc_macro_attribute]
228// model 自动填充方法
229pub fn sqlx_model_status(args: TokenStream, item: TokenStream) -> TokenStream {
230    let input = parse_macro_input!(item as DeriveInput);
231    let struct_name = &input.ident;
232    let args = syn::parse_macro_input!(args as syn::AttributeArgs);
233    let mut field_type = None;
234    // for attr in input
235    //     .attrs
236    //     .iter()
237    //     .filter(|e| e.path.is_ident("sqlx_model_status"))
238    // {
239    //     let meta = attr
240    //         .parse_meta()
241    //         .map_err(|e| syn::Error::new_spanned(attr, e))
242    //         .unwrap();
243    //     if let Meta::List(list) = meta {
244    for cattr in args.iter() {
245        if let NestedMeta::Meta(Meta::NameValue(ref attr_ident)) = cattr {
246            let name = attr_ident.clone();
247            let name = name.path.get_ident().unwrap().to_string();
248            let name = name.as_str();
249            let ident = attr_ident.clone();
250            if name == "field_type" {
251                field_type = Some(
252                    match ident.lit {
253                        syn::Lit::Str(val) => val,
254                        _ => unreachable!("status type must be string"),
255                    }
256                    .value(),
257                );
258            }
259        }
260    }
261    let field_type = field_type.expect("status type not set");
262    //     }
263    // }
264    let field_type = quote::format_ident!("{}", field_type);
265    let expanded = match input.data {
266        Data::Enum(DataEnum { ref variants, .. }) => {
267            let fields: Vec<_> = variants
268                .iter()
269                .map(|field| {
270                    let field_name = field.ident.clone();
271                    quote! {
272                        #struct_name::#field_name
273                    }
274                })
275                .collect();
276            quote! {
277                #input
278                sqlx_model::model_enum_status_define!(#struct_name,#field_type,{#(#fields),*});
279            }
280        }
281        _ => panic!("sorry, Show is not implemented for union or enum type."),
282    };
283    expanded.into()
284}