1#![forbid(unsafe_code)]
6
7use once_cell::sync::Lazy;
8use proc_macro2::Span;
9use proc_macro_error::{abort, abort_call_site, proc_macro_error};
10use quote::{quote, ToTokens};
11use syn::punctuated::Punctuated;
12use syn::spanned::Spanned;
13use syn::{parse_macro_input, Data, DeriveInput, Fields, FieldsNamed, Ident, Meta, Token};
14
15mod ddl;
16mod delete;
17mod insert;
18mod misc;
19mod select;
20mod update;
21
22#[derive(Debug, Clone)]
23struct Table {
24 ident: Ident,
25 span: Span,
26 name: String,
27 columns: Vec<Column>,
28}
29
30impl ToTokens for Table {
31 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
32 let ident = &self.ident;
33 tokens.extend(quote!(#ident));
34 }
35}
36
37#[allow(dead_code)]
38#[derive(Debug, Clone)]
39struct Column {
40 ident: Ident,
41 span: Span,
42 name: String,
43 rust_type: String,
44 sql_type: &'static str,
45}
46
47static U8_ARRAY_RE: Lazy<regex::Regex> =
48 Lazy::new(|| regex::Regex::new(r"^Option < \[u8 ; \d+\] >$").unwrap());
49
50#[proc_macro_derive(SqlRender, attributes(sqlrender))]
52#[proc_macro_error]
53pub fn sqlrender_derive_macro(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
54 if is_rust_analyzer() {
55 return quote!().into();
56 }
57
58 let input = parse_macro_input!(input as DeriveInput);
61 let table_span = input.span();
62 let table_ident = input.ident;
63 let table_name = table_ident.to_string().to_lowercase();
64
65 let fields = match input.data {
66 Data::Struct(ref data) => match data.fields {
67 Fields::Named(ref fields) => fields,
68 Fields::Unnamed(_) | Fields::Unit => unimplemented!(),
69 },
70 Data::Enum(_) | Data::Union(_) => unimplemented!(),
71 };
72
73 let table = Table {
74 ident: table_ident,
75 span: table_span,
76 name: table_name.clone(),
77 columns: extract_columns(fields),
78 };
79
80 let fn_select = select::select(&table);
83 let fn_insert = insert::insert(&table);
84 let fn_update = update::update(&table);
85 let fn_delete = delete::delete(&table);
86 let fn_ddl = ddl::ddl(&table);
87 let fn_misc = misc::misc(&table);
88
89 let output = quote! {
92 #[cfg(not(target_arch = "wasm32"))]
93 impl ::sqlrender::SqlRender for #table {
94 #fn_select
95 #fn_insert
96 #fn_update
97 #fn_delete
98 #fn_ddl
99 #fn_misc
100 }
101 };
102
103 output.into()
104}
105
106fn extract_columns(fields: &FieldsNamed) -> Vec<Column> {
108 let columns = fields
109 .named
110 .iter()
111 .filter_map(|f| {
112 for attr in &f.attrs {
115 if attr.path().is_ident("sqlrender") {
116 for meta in attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated).unwrap() {
117 match meta {
118 Meta::Path(path) if path.is_ident("skip") => {
119 return None;
122 }
123 _ => ()
124 }
125 }
126 }
127 }
128
129 let ident = &f.ident;
130 let name = ident.as_ref().unwrap().to_string();
131
132 let ty = &f.ty;
133 let ty_str = quote!(#ty).to_string();
134
135 let sql_type = match (
136 name.as_str(),
137 if U8_ARRAY_RE.is_match(&ty_str) { "Option < [u8; _] >" } else { ty_str.as_str() },
138 ) {
139 ("id", "Option < u64 >") => "BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY",
140 (_, "Option < i8 >") => "INT",
141 (_, "Option < u8 >") => "INT",
142 (_, "Option < i16 >") => "INT",
143 (_, "Option < u16 >") => "INT",
144 (_, "Option < i32 >") => "INT",
145 (_, "Option < u32 >") => "INT",
146 (_, "Option < i64 >") => "BIGINT",
147 (_, "Option < u64 >") => "BIGINT UNSIGNED",
148 (_, "Option < f64 >") => "DOUBLE",
149 (_, "Option < f32 >") => "DOUBLE",
150 (_, "Option < bool >") => "TINYINT",
151 (_, "Option < String >") => "LONGTEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci",
152 (_, "Option < DateTime < FixedOffset > >") => "TIMESTAMP",
153 (_, "Option < Blob >") => "BLOB",
155 (_, "Option < Vec < u8 > >") => "BLOB",
156 (_, "Option < [u8; _] >") => "BLOB",
157 _ => {
158 if ty_str.starts_with("Option < ") {
159 "TEXT" } else {
161 abort!(
162 ty,
163 "SqlRender types must be wrapped in Option for forward/backward schema compatibility. Try: Option<{}>",
164 ty_str
165 )
166 }
167 }
168 };
169
170 Some(Column {
171 ident: ident.clone().unwrap(),
172 span: ty.span(),
173 rust_type: ty_str,
174 name,
175 sql_type,
176 })
177 })
178 .collect::<Vec<_>>();
179
180 if !matches!(
181 columns.iter().find(|c| c.name == "id"),
182 Some(Column { sql_type: "BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY", .. })
183 ) {
184 abort_call_site!("derive(SqlRender) structs must include a 'id: Option<u64>' field")
185 };
186
187 columns
188}
189
/// Returns `true` when the current executable's file stem starts with
/// `"rust-analyzer"`.
///
/// If the executable path cannot be determined, or it has no file stem, this
/// returns `false` instead of panicking. A lookup failure therefore never
/// breaks a normal compile.
fn is_rust_analyzer() -> bool {
    match std::env::current_exe() {
        Ok(exe) => exe
            .file_stem()
            .map_or(false, |stem| stem.to_string_lossy().starts_with("rust-analyzer")),
        // The host binary cannot be identified: assume a regular compiler run.
        Err(_) => false,
    }
}
198
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Supported `Option<T>` fields map to their SQL types in declaration
    /// order, and `#[sqlrender(skip)]` fields are dropped.
    #[test]
    fn test_extract_columns() {
        let fields_named: FieldsNamed = parse_quote!({
            id: Option<u64>,
            name: Option<String>,
            age: Option<u32>,
            awesomeness: Option<f64>,
            #[sqlrender(skip)]
            skipped: Option<bool>
        });

        let expected = [
            ("id", "Option < u64 >", "BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY"),
            (
                "name",
                "Option < String >",
                "LONGTEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci",
            ),
            ("age", "Option < u32 >", "INT"),
            ("awesomeness", "Option < f64 >", "DOUBLE"),
        ];

        let columns = extract_columns(&fields_named);
        assert_eq!(columns.len(), expected.len());

        for (column, (name, rust_type, sql_type)) in columns.iter().zip(expected) {
            assert_eq!(column.name, name);
            assert_eq!(column.rust_type, rust_type);
            assert_eq!(column.sql_type, sql_type);
        }

        assert!(columns.iter().all(|column| column.name != "skipped"));
    }
}