rusql_alchemy_macro/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Data, DeriveInput, Fields, Lit};
4
5#[proc_macro_derive(Model, attributes(model))]
6pub fn model_derive(input: TokenStream) -> TokenStream {
7    let input = parse_macro_input!(input as DeriveInput);
8    let name = input.ident;
9
10    let fields = match input.data {
11        Data::Struct(ref data) => match data.fields {
12            Fields::Named(ref fields) => &fields.named,
13            _ => panic!("Model derive macro only supports structs with named fields"),
14        },
15        _ => panic!("Model derive macro only supports structs"),
16    };
17
18    let mut schema_fields = Vec::new();
19    let mut create_args = Vec::new();
20    let mut update_args = Vec::new();
21
22    let mut the_primary_key = quote! {};
23
24    for field in fields {
25        let field_name = field.ident.as_ref().unwrap();
26        let field_type = match &field.ty {
27            syn::Type::Path(type_path) => type_path.path.segments.last().unwrap().ident.to_string(),
28            _ => panic!("Unsupported field type"),
29        };
30
31        let mut is_nullable = true;
32        let mut is_primary_key = false;
33        let mut is_auto = false;
34        let mut is_unique = false;
35        let mut is_default = false;
36        let mut size = None;
37        let mut default = quote! {};
38        let mut foreign_key = quote! {};
39
40        for attr in &field.attrs {
41            if attr.path.is_ident("model") {
42                let meta = attr.parse_meta().unwrap();
43                if let syn::Meta::List(ref list) = meta {
44                    for nested in &list.nested {
45                        if let syn::NestedMeta::Meta(syn::Meta::NameValue(ref nv)) = nested {
46                            if nv.path.is_ident("primary_key") {
47                                if let Lit::Bool(ref lit) = nv.lit {
48                                    the_primary_key = quote! { #field_name.clone() };
49                                    is_primary_key = lit.value;
50                                }
51                            } else if nv.path.is_ident("auto") {
52                                if let Lit::Bool(ref lit) = nv.lit {
53                                    is_auto = lit.value;
54                                }
55                            } else if nv.path.is_ident("size") {
56                                if let Lit::Int(ref lit) = nv.lit {
57                                    size = Some(lit.clone());
58                                }
59                            } else if nv.path.is_ident("unique") {
60                                if let Lit::Bool(ref lit) = nv.lit {
61                                    is_unique = lit.value;
62                                }
63                            } else if nv.path.is_ident("null") {
64                                if let Lit::Bool(ref lit) = nv.lit {
65                                    is_nullable = lit.value;
66                                }
67                            } else if nv.path.is_ident("default") {
68                                is_default = true;
69                                if let Lit::Str(ref str) = nv.lit {
70                                    default = if str.value() == "now" {
71                                        if field_type == "Date" {
72                                            quote! { default current_date}
73                                        } else if field_type == "DateTime" {
74                                            quote! { default current_timestamp}
75                                        } else {
76                                            panic!("'now' is work only with Date or DateTime");
77                                        }
78                                    } else {
79                                        let str = format!("'{str}'", str = str.value());
80                                        quote! { default #str }
81                                    }
82                                } else if let Lit::Bool(ref bool) = nv.lit {
83                                    default = if bool.value {
84                                        quote! {default 1}
85                                    } else {
86                                        quote! {default 0}
87                                    };
88                                } else if let Lit::Int(ref int) = nv.lit {
89                                    default = quote! { default #int }
90                                }
91                            } else if nv.path.is_ident("foreign_key") {
92                                if let Lit::Str(ref lit) = nv.lit {
93                                    let fk = lit.value();
94                                    let foreign_key_parts: Vec<&str> = fk.split('.').collect();
95                                    if foreign_key_parts.len() != 2 {
96                                        panic!("Invalid foreign key");
97                                    }
98                                    let foreign_key_table = foreign_key_parts[0];
99                                    let foreign_key_field = foreign_key_parts[1];
100
101                                    foreign_key = quote! {
102                                         references #foreign_key_table(#foreign_key_field)
103                                    };
104                                }
105                            }
106                        }
107                    }
108                }
109            }
110        }
111
112        let field_schema = {
113            let base_type = match field_type.as_str() {
114                "Serial" => quote! { serial },
115                "Integer" => quote! { integer },
116                "String" => {
117                    if let Some(size) = size {
118                        quote! {varchar(#size)}
119                    } else {
120                        quote! {varchar(255)}
121                    }
122                }
123                "Float" => quote! { float },
124                "Text" => quote! { text },
125                "Date" => quote! { varchar(10) },
126                "Boolean" => quote! { integer },
127                "DateTime" => quote! { varchar(40) },
128                 p_type => panic!(
129                    "Unexpected field type: '{}'. Expected one of: 'Serial', 'Integer', 'String', 'Float', 'Text', 'Date', 'Boolean', 'DateTime'. Please check the field type.",
130                    p_type
131                ),
132            };
133
134            let primary_key = if is_primary_key {
135                let auto = if is_auto {
136                    quote! { autoincrement }
137                } else if field_type.as_str() == "Serial" {
138                    quote! {}
139                } else {
140                    create_args.push(quote! { #field_name });
141                    quote! {}
142                };
143                quote! { primary key #auto}
144            } else {
145                create_args.push(quote! { #field_name });
146                update_args.push(quote! { #field_name });
147                quote! {}
148            };
149
150            if is_default {
151                create_args.pop();
152            }
153
154            let nullable = if is_nullable {
155                quote! {}
156            } else {
157                quote! {not null}
158            };
159            let unique = if is_unique {
160                quote! { unique }
161            } else {
162                quote! {}
163            };
164
165            quote! { #field_name #base_type #primary_key #unique #default #nullable #foreign_key }
166        };
167
168        schema_fields.push(field_schema);
169    }
170
171    let primary_key = {
172        let pk = the_primary_key.to_string().replace(".clone()", "");
173        quote! {
174            const PK: &'static str = #pk;
175        }
176    };
177
178    let schema = {
179        let fields = schema_fields
180            .iter()
181            .map(|f| f.to_string())
182            .collect::<Vec<_>>()
183            .join(", ");
184
185        let schema = format!("create table if not exists {name} ({fields});").replace('"', "");
186
187        quote! {
188            const SCHEMA: &'static str = #schema;
189        }
190    };
191
192    let create = quote! {
193        async fn save(&self, conn: &Connection) -> bool {
194            Self::create(
195                kwargs!(
196                    #(#create_args = self.#create_args),*
197                ),
198                conn,
199            )
200            .await
201        }
202    };
203
204    let update = quote! {
205        async fn update(&self, conn: &Connection) -> bool {
206            Self::set(
207                self.#the_primary_key,
208                kwargs!(
209                    #(#update_args = self.#update_args),*
210                ),
211                conn,
212            )
213            .await
214        }
215    };
216
217    let delete = {
218        let query =
219            format!("delete from {name} where {the_primary_key}=?1;").replace(".clone()", "");
220        quote! {
221            async fn delete(&self, conn: &Connection) -> bool {
222                let placeholder = rusql_alchemy::PLACEHOLDER.to_string();
223                sqlx::query(&#query.replace("?", &placeholder).replace("$", &placeholder))
224                    .bind(self.#the_primary_key)
225                    .execute(conn)
226                    .await
227                    .is_ok()
228            }
229        }
230    };
231
232    let expanded = quote! {
233        #[async_trait]
234        impl Model for #name {
235            const NAME: &'static str = stringify!(#name);
236            #schema
237            #primary_key
238            #create
239            #update
240            #delete
241        }
242    };
243
244    TokenStream::from(expanded)
245}