rusql_alchemy_macro/
lib.rs1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Data, DeriveInput, Fields, Lit};
4
5#[proc_macro_derive(Model, attributes(model))]
6pub fn model_derive(input: TokenStream) -> TokenStream {
7 let input = parse_macro_input!(input as DeriveInput);
8 let name = input.ident;
9
10 let fields = match input.data {
11 Data::Struct(ref data) => match data.fields {
12 Fields::Named(ref fields) => &fields.named,
13 _ => panic!("Model derive macro only supports structs with named fields"),
14 },
15 _ => panic!("Model derive macro only supports structs"),
16 };
17
18 let mut schema_fields = Vec::new();
19 let mut create_args = Vec::new();
20 let mut update_args = Vec::new();
21
22 let mut the_primary_key = quote! {};
23
24 for field in fields {
25 let field_name = field.ident.as_ref().unwrap();
26 let field_type = match &field.ty {
27 syn::Type::Path(type_path) => type_path.path.segments.last().unwrap().ident.to_string(),
28 _ => panic!("Unsupported field type"),
29 };
30
31 let mut is_nullable = true;
32 let mut is_primary_key = false;
33 let mut is_auto = false;
34 let mut is_unique = false;
35 let mut is_default = false;
36 let mut size = None;
37 let mut default = quote! {};
38 let mut foreign_key = quote! {};
39
40 for attr in &field.attrs {
41 if attr.path.is_ident("model") {
42 let meta = attr.parse_meta().unwrap();
43 if let syn::Meta::List(ref list) = meta {
44 for nested in &list.nested {
45 if let syn::NestedMeta::Meta(syn::Meta::NameValue(ref nv)) = nested {
46 if nv.path.is_ident("primary_key") {
47 if let Lit::Bool(ref lit) = nv.lit {
48 the_primary_key = quote! { #field_name.clone() };
49 is_primary_key = lit.value;
50 }
51 } else if nv.path.is_ident("auto") {
52 if let Lit::Bool(ref lit) = nv.lit {
53 is_auto = lit.value;
54 }
55 } else if nv.path.is_ident("size") {
56 if let Lit::Int(ref lit) = nv.lit {
57 size = Some(lit.clone());
58 }
59 } else if nv.path.is_ident("unique") {
60 if let Lit::Bool(ref lit) = nv.lit {
61 is_unique = lit.value;
62 }
63 } else if nv.path.is_ident("null") {
64 if let Lit::Bool(ref lit) = nv.lit {
65 is_nullable = lit.value;
66 }
67 } else if nv.path.is_ident("default") {
68 is_default = true;
69 if let Lit::Str(ref str) = nv.lit {
70 default = if str.value() == "now" {
71 if field_type == "Date" {
72 quote! { default current_date}
73 } else if field_type == "DateTime" {
74 quote! { default current_timestamp}
75 } else {
76 panic!("'now' is work only with Date or DateTime");
77 }
78 } else {
79 let str = format!("'{str}'", str = str.value());
80 quote! { default #str }
81 }
82 } else if let Lit::Bool(ref bool) = nv.lit {
83 default = if bool.value {
84 quote! {default 1}
85 } else {
86 quote! {default 0}
87 };
88 } else if let Lit::Int(ref int) = nv.lit {
89 default = quote! { default #int }
90 }
91 } else if nv.path.is_ident("foreign_key") {
92 if let Lit::Str(ref lit) = nv.lit {
93 let fk = lit.value();
94 let foreign_key_parts: Vec<&str> = fk.split('.').collect();
95 if foreign_key_parts.len() != 2 {
96 panic!("Invalid foreign key");
97 }
98 let foreign_key_table = foreign_key_parts[0];
99 let foreign_key_field = foreign_key_parts[1];
100
101 foreign_key = quote! {
102 references #foreign_key_table(#foreign_key_field)
103 };
104 }
105 }
106 }
107 }
108 }
109 }
110 }
111
112 let field_schema = {
113 let base_type = match field_type.as_str() {
114 "Serial" => quote! { serial },
115 "Integer" => quote! { integer },
116 "String" => {
117 if let Some(size) = size {
118 quote! {varchar(#size)}
119 } else {
120 quote! {varchar(255)}
121 }
122 }
123 "Float" => quote! { float },
124 "Text" => quote! { text },
125 "Date" => quote! { varchar(10) },
126 "Boolean" => quote! { integer },
127 "DateTime" => quote! { varchar(40) },
128 p_type => panic!(
129 "Unexpected field type: '{}'. Expected one of: 'Serial', 'Integer', 'String', 'Float', 'Text', 'Date', 'Boolean', 'DateTime'. Please check the field type.",
130 p_type
131 ),
132 };
133
134 let primary_key = if is_primary_key {
135 let auto = if is_auto {
136 quote! { autoincrement }
137 } else if field_type.as_str() == "Serial" {
138 quote! {}
139 } else {
140 create_args.push(quote! { #field_name });
141 quote! {}
142 };
143 quote! { primary key #auto}
144 } else {
145 create_args.push(quote! { #field_name });
146 update_args.push(quote! { #field_name });
147 quote! {}
148 };
149
150 if is_default {
151 create_args.pop();
152 }
153
154 let nullable = if is_nullable {
155 quote! {}
156 } else {
157 quote! {not null}
158 };
159 let unique = if is_unique {
160 quote! { unique }
161 } else {
162 quote! {}
163 };
164
165 quote! { #field_name #base_type #primary_key #unique #default #nullable #foreign_key }
166 };
167
168 schema_fields.push(field_schema);
169 }
170
171 let primary_key = {
172 let pk = the_primary_key.to_string().replace(".clone()", "");
173 quote! {
174 const PK: &'static str = #pk;
175 }
176 };
177
178 let schema = {
179 let fields = schema_fields
180 .iter()
181 .map(|f| f.to_string())
182 .collect::<Vec<_>>()
183 .join(", ");
184
185 let schema = format!("create table if not exists {name} ({fields});").replace('"', "");
186
187 quote! {
188 const SCHEMA: &'static str = #schema;
189 }
190 };
191
192 let create = quote! {
193 async fn save(&self, conn: &Connection) -> bool {
194 Self::create(
195 kwargs!(
196 #(#create_args = self.#create_args),*
197 ),
198 conn,
199 )
200 .await
201 }
202 };
203
204 let update = quote! {
205 async fn update(&self, conn: &Connection) -> bool {
206 Self::set(
207 self.#the_primary_key,
208 kwargs!(
209 #(#update_args = self.#update_args),*
210 ),
211 conn,
212 )
213 .await
214 }
215 };
216
217 let delete = {
218 let query =
219 format!("delete from {name} where {the_primary_key}=?1;").replace(".clone()", "");
220 quote! {
221 async fn delete(&self, conn: &Connection) -> bool {
222 let placeholder = rusql_alchemy::PLACEHOLDER.to_string();
223 sqlx::query(&#query.replace("?", &placeholder).replace("$", &placeholder))
224 .bind(self.#the_primary_key)
225 .execute(conn)
226 .await
227 .is_ok()
228 }
229 }
230 };
231
232 let expanded = quote! {
233 #[async_trait]
234 impl Model for #name {
235 const NAME: &'static str = stringify!(#name);
236 #schema
237 #primary_key
238 #create
239 #update
240 #delete
241 }
242 };
243
244 TokenStream::from(expanded)
245}