sqlrender_impl/
lib.rs

1//! This crate provides SqlRender's procedural macros.
2//!
3//! Please refer to the `sqlrender` crate for how to set this up.
4
5#![forbid(unsafe_code)]
6
7use once_cell::sync::Lazy;
8use proc_macro2::Span;
9use proc_macro_error::{abort, abort_call_site, proc_macro_error};
10use quote::{quote, ToTokens};
11use syn::punctuated::Punctuated;
12use syn::spanned::Spanned;
13use syn::{parse_macro_input, Data, DeriveInput, Fields, FieldsNamed, Ident, Meta, Token};
14
15mod ddl;
16mod delete;
17mod insert;
18mod misc;
19mod select;
20mod update;
21
22#[derive(Debug, Clone)]
23struct Table {
24	ident: Ident,
25	span: Span,
26	name: String,
27	columns: Vec<Column>,
28}
29
30impl ToTokens for Table {
31	fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
32		let ident = &self.ident;
33		tokens.extend(quote!(#ident));
34	}
35}
36
37#[allow(dead_code)]
38#[derive(Debug, Clone)]
39struct Column {
40	ident: Ident,
41	span: Span,
42	name: String,
43	rust_type: String,
44	sql_type: &'static str,
45}
46
47static U8_ARRAY_RE: Lazy<regex::Regex> =
48	Lazy::new(|| regex::Regex::new(r"^Option < \[u8 ; \d+\] >$").unwrap());
49
50/// Derive this on a `struct` to create a corresponding table and `SqlRender` trait methods.
51#[proc_macro_derive(SqlRender, attributes(sqlrender))]
52#[proc_macro_error]
53pub fn sqlrender_derive_macro(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
54	if is_rust_analyzer() {
55		return quote!().into();
56	}
57
58	// parse tokenstream and set up table struct
59
60	let input = parse_macro_input!(input as DeriveInput);
61	let table_span = input.span();
62	let table_ident = input.ident;
63	let table_name = table_ident.to_string().to_lowercase();
64
65	let fields = match input.data {
66		Data::Struct(ref data) => match data.fields {
67			Fields::Named(ref fields) => fields,
68			Fields::Unnamed(_) | Fields::Unit => unimplemented!(),
69		},
70		Data::Enum(_) | Data::Union(_) => unimplemented!(),
71	};
72
73	let table = Table {
74		ident: table_ident,
75		span: table_span,
76		name: table_name.clone(),
77		columns: extract_columns(fields),
78	};
79
80	// create trait functions
81
82	let fn_select = select::select(&table);
83	let fn_insert = insert::insert(&table);
84	let fn_update = update::update(&table);
85	let fn_delete = delete::delete(&table);
86	let fn_ddl = ddl::ddl(&table);
87	let fn_misc = misc::misc(&table);
88
89	// output tokenstream
90
91	let output = quote! {
92		#[cfg(not(target_arch = "wasm32"))]
93		impl ::sqlrender::SqlRender for #table {
94			#fn_select
95			#fn_insert
96			#fn_update
97			#fn_delete
98			#fn_ddl
99			#fn_misc
100		}
101	};
102
103	output.into()
104}
105
106/// Convert syn::FieldsNamed to our Column type.
107fn extract_columns(fields: &FieldsNamed) -> Vec<Column> {
108	let columns = fields
109		.named
110		.iter()
111		.filter_map(|f| {
112			// Skip (skip) fields
113
114			for attr in &f.attrs {
115				if attr.path().is_ident("sqlrender") {
116					for meta in attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated).unwrap() {
117						match meta {
118							Meta::Path(path) if path.is_ident("skip") => {
119								// TODO: For skipped fields, Handle derive(Default) requirement better
120								// require Option and manifest None values
121								return None;
122							}
123							_ => ()
124						}
125					}
126				}
127			}
128
129			let ident = &f.ident;
130			let name = ident.as_ref().unwrap().to_string();
131
132			let ty = &f.ty;
133			let ty_str = quote!(#ty).to_string();
134
135			let sql_type = match (
136				name.as_str(),
137				if U8_ARRAY_RE.is_match(&ty_str) { "Option < [u8; _] >" } else { ty_str.as_str() },
138			) {
139				("id", "Option < u64 >") => "BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY",
140				(_, "Option < i8 >") => "INT",
141				(_, "Option < u8 >") => "INT",
142				(_, "Option < i16 >") => "INT",
143				(_, "Option < u16 >") => "INT",
144				(_, "Option < i32 >") => "INT",
145				(_, "Option < u32 >") => "INT",
146				(_, "Option < i64 >") => "BIGINT",
147				(_, "Option < u64 >") => "BIGINT UNSIGNED",
148				(_, "Option < f64 >") => "DOUBLE",
149				(_, "Option < f32 >") => "DOUBLE",
150				(_, "Option < bool >") => "TINYINT",
151				(_, "Option < String >") => "LONGTEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci",
152				(_, "Option < DateTime < FixedOffset > >") => "TIMESTAMP",
153				// SELECT LENGTH(blob_column) ... will be null if blob is null
154				(_, "Option < Blob >") => "BLOB",
155				(_, "Option < Vec < u8 > >") => "BLOB",
156				(_, "Option < [u8; _] >") => "BLOB",
157				_ => {
158					if ty_str.starts_with("Option < ") {
159						"TEXT" // JSON-serialized
160					} else {
161						abort!(
162							ty,
163							"SqlRender types must be wrapped in Option for forward/backward schema compatibility. Try: Option<{}>",
164							ty_str
165						)
166					}
167				}
168			};
169
170			Some(Column {
171				ident: ident.clone().unwrap(),
172				span: ty.span(),
173				rust_type: ty_str,
174				name,
175				sql_type,
176			})
177		})
178		.collect::<Vec<_>>();
179
180	if !matches!(
181		columns.iter().find(|c| c.name == "id"),
182		Some(Column { sql_type: "BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY", .. })
183	) {
184		abort_call_site!("derive(SqlRender) structs must include a 'id: Option<u64>' field")
185	};
186
187	columns
188}
189
/// Returns `true` when the running executable's file stem starts with `rust-analyzer`.
///
/// Returns `false` instead of panicking when the executable path or its file stem cannot be
/// determined, so an unusual host environment never aborts macro expansion.
fn is_rust_analyzer() -> bool {
	std::env::current_exe()
		.ok()
		.and_then(|exe| {
			exe.file_stem()
				.map(|stem| stem.to_string_lossy().starts_with("rust-analyzer"))
		})
		.unwrap_or(false)
}
198
#[cfg(test)]
mod tests {
	use super::*;
	use syn::parse_quote;

	/// Asserts that `columns` are exactly `expected`, in order, as `(name, rust_type, sql_type)`.
	fn assert_columns(columns: &[Column], expected: &[(&str, &str, &str)]) {
		let actual: Vec<(&str, &str, &str)> = columns
			.iter()
			.map(|c| (c.name.as_str(), c.rust_type.as_str(), c.sql_type))
			.collect();
		assert_eq!(actual, expected);
	}

	#[test]
	fn test_extract_columns() {
		let fields_named = parse_quote!({
			id: Option<u64>,
			name: Option<String>,
			age: Option<u32>,
			awesomeness: Option<f64>,
			#[sqlrender(skip)]
			skipped: Option<bool>
		});

		let columns = extract_columns(&fields_named);

		assert_columns(
			&columns,
			&[
				("id", "Option < u64 >", "BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY"),
				(
					"name",
					"Option < String >",
					"LONGTEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci",
				),
				("age", "Option < u32 >", "INT"),
				("awesomeness", "Option < f64 >", "DOUBLE"),
			],
		);
		assert!(!columns.iter().any(|c| c.name == "skipped"));
	}

	/// Pins the `quote!` stringification the type mapping relies on (spaced `>`, `[u8 ; N]`)
	/// plus the BLOB, TIMESTAMP and JSON-`TEXT` fallback arms.
	#[test]
	fn test_extract_columns_blob_datetime_and_fallback() {
		let fields_named = parse_quote!({
			id: Option<u64>,
			created: Option<DateTime<FixedOffset>>,
			active: Option<bool>,
			payload: Option<Vec<u8>>,
			hash: Option<[u8; 32]>,
			tags: Option<Vec<String>>,
			big: Option<i64>,
		});

		let columns = extract_columns(&fields_named);

		assert_columns(
			&columns,
			&[
				("id", "Option < u64 >", "BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY"),
				("created", "Option < DateTime < FixedOffset > >", "TIMESTAMP"),
				("active", "Option < bool >", "TINYINT"),
				("payload", "Option < Vec < u8 > >", "BLOB"),
				("hash", "Option < [u8 ; 32] >", "BLOB"),
				("tags", "Option < Vec < String > >", "TEXT"),
				("big", "Option < i64 >", "BIGINT"),
			],
		);
	}

	/// Outside a `#[proc_macro_error]` entry point an abort surfaces as a panic.
	#[test]
	#[should_panic]
	fn test_extract_columns_requires_id() {
		let fields_named = parse_quote!({
			name: Option<String>,
		});

		extract_columns(&fields_named);
	}
}