turbosql_impl/
lib.rs

1//! This crate provides Turbosql's procedural macros.
2//!
3//! Please refer to the `turbosql` crate for how to set this up.
4
5#![forbid(unsafe_code)]
6
/// Compile error reported by `extract_columns` for struct fields of type `u64` or `Option<u64>`.
const SQLITE_U64_ERROR: &str = r##"SQLite cannot natively store unsigned 64-bit integers, so Turbosql does not support u64 fields. Use i64, u32, f64, or a string or binary format instead. (see https://github.com/trevyn/turbosql/issues/3 )"##;
8
9use once_cell::sync::Lazy;
10use proc_macro2::Span;
11use proc_macro_error::{abort, abort_call_site, proc_macro_error};
12use quote::{format_ident, quote, ToTokens};
13use rusqlite::{params, Connection, Statement};
14use serde::{Deserialize, Serialize};
15use std::collections::BTreeMap;
16use syn::{
17	parse::{Parse, ParseStream},
18	punctuated::Punctuated,
19	spanned::Spanned,
20	*,
21};
22
// File name of the migrations file; `migrations_toml_path` places it in the parent of
// the `target` directory. Builds with the `test` feature use a separate file so test
// runs do not touch the real `migrations.toml`.
#[cfg(not(feature = "test"))]
const MIGRATIONS_FILENAME: &str = "migrations.toml";
#[cfg(feature = "test")]
const MIGRATIONS_FILENAME: &str = "test.migrations.toml";
27
28mod delete;
29mod insert;
30mod update;
31
32#[derive(Debug, Clone)]
33struct Table {
34	ident: Ident,
35	span: Span,
36	name: String,
37	columns: Vec<Column>,
38}
39
40#[derive(Clone, Serialize, Deserialize, Debug)]
41struct MiniTable {
42	name: String,
43	columns: Vec<MiniColumn>,
44}
45
46impl ToTokens for Table {
47	fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
48		let ident = &self.ident;
49		tokens.extend(quote!(#ident));
50	}
51}
52
/// One persisted field of a `#[derive(Turbosql)]` struct.
#[derive(Debug, Clone)]
struct Column {
	/// The field's identifier.
	ident: Ident,
	/// Span of the field's type.
	span: Span,
	/// Column name (identical to the field name).
	name: String,
	/// The field type as stringified by `quote!`, e.g. `"Option < i64 >"`.
	rust_type: String,
	/// SQLite column type, e.g. `"INTEGER NOT NULL"`, `"TEXT"`, `"INTEGER PRIMARY KEY"`.
	sql_type: &'static str,
	/// SQL literal text used as the column's `DEFAULT` in its `ALTER TABLE` migration.
	sql_default: Option<String>,
}

/// Serializable copy of a [`Column`], stored in `migrations.toml` so `select!` can
/// generate column lists without seeing the struct definition.
#[derive(Clone, Serialize, Deserialize, Debug)]
struct MiniColumn {
	name: String,
	rust_type: String,
	sql_type: String,
}
69
// Match the `quote!`-stringified forms of `Option<[u8; N]>` and `[u8; N]` for any
// length N, so `extract_columns` can map fixed-size byte arrays to BLOB columns.
static OPTION_U8_ARRAY_RE: Lazy<regex::Regex> =
	Lazy::new(|| regex::Regex::new(r"^Option\s*<\s*\[\s*u8\s*;\s*\d+\s*\]\s*>$").unwrap());
static U8_ARRAY_RE: Lazy<regex::Regex> =
	Lazy::new(|| regex::Regex::new(r"^\[\s*u8\s*;\s*\d+\s*\]$").unwrap());
74
75#[derive(Clone, Debug)]
76struct SingleColumn {
77	table: Ident,
78	column: Ident,
79}
80
81#[derive(Clone, Debug)]
82enum Content {
83	Type(Type),
84	#[allow(dead_code)]
85	SingleColumn(SingleColumn),
86}
87
88impl ToTokens for Content {
89	fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
90		match self {
91			Content::Type(ty) => ty.to_tokens(tokens),
92			Content::SingleColumn(_) => unimplemented!(),
93		}
94	}
95}
96
97impl Content {
98	fn ty(&self) -> Result<Type> {
99		match self {
100			Content::Type(ty) => Ok(ty.clone()),
101			Content::SingleColumn(_) => unimplemented!(), // parse_str(&c.rust_type),
102		}
103	}
104	fn is_primitive(&self) -> bool {
105		match self {
106			Content::Type(Type::Path(TypePath { path, .. })) => {
107				["f32", "f64", "i8", "u8", "i16", "u16", "i32", "u32", "i64", "String", "bool", "Blob"]
108					.contains(&path.segments.last().unwrap().ident.to_string().as_str())
109			}
110			Content::Type(Type::Array(_)) => true,
111			_ => abort_call_site!("Unsupported content type in is_primitive {:#?}", self),
112		}
113	}
114	fn table_ident(&self) -> &Ident {
115		match self {
116			Content::Type(Type::Path(TypePath { path, .. })) => &path.segments.last().unwrap().ident,
117			Content::SingleColumn(c) => &c.table,
118			_ => abort_call_site!("Unsupported content type in table_ident {:#?}", self),
119		}
120	}
121}
122
123#[derive(Clone, Debug)]
124struct ResultType {
125	container: Option<Ident>,
126	content: Content,
127}
128
129impl ResultType {
130	fn ty(&self) -> Result<Type> {
131		let content = self.content.ty()?;
132		match self.container {
133			Some(ref container) => Ok(parse_quote!(#container < #content >)),
134			None => Ok(content),
135		}
136	}
137}
138
139impl Parse for ResultType {
140	fn parse(input: ParseStream) -> Result<Self> {
141		input.parse::<Type>().map(|ty| match ty {
142			Type::Path(TypePath { qself: None, ref path }) => {
143				let path = path.segments.last().unwrap();
144				let mut container = None;
145				let content = match &path.ident {
146					ident if ["Vec", "Option"].contains(&ident.to_string().as_str()) => {
147						container = Some(ident.clone());
148						match path.arguments {
149							PathArguments::AngleBracketed(AngleBracketedGenericArguments { ref args, .. }) => {
150								(args.len() != 1).then(|| abort!(args, "Expected 1 argument, found {}", args.len()));
151								match args.first().unwrap() {
152									GenericArgument::Type(ty) => Content::Type(ty.clone()),
153									ty => abort!(ty, "Expected type, found {:?}", ty),
154								}
155							}
156							ref args => abort!(args, "Expected angle bracketed arguments, found {:?}", args),
157						}
158					}
159					_ => Content::Type(ty),
160				};
161				ResultType { container, content }
162			}
163			Type::Array(array) => ResultType { container: None, content: Content::Type(Type::Array(array)) },
164			ty => abort!(ty, "Unknown type {:?}", ty),
165		})
166	}
167}
168
169#[derive(Debug)]
170struct MembersAndCasters {
171	//  members: Vec<(Ident, Ident, usize)>,
172	//  struct_members: Vec<proc_macro2::TokenStream>,
173	row_casters: Vec<proc_macro2::TokenStream>,
174}
175
176impl MembersAndCasters {
177	fn create(members: Vec<(Ident, Ident, usize)>) -> MembersAndCasters {
178		// let struct_members: Vec<_> = members.iter().map(|(name, ty, _i)| quote!(#name: #ty)).collect();
179		let row_casters = members
180			.iter()
181			.map(|(name, _ty, i)| {
182				if name.to_string().ends_with("__serialized") {
183					let name = name.to_string();
184					let real_name = format_ident!("{}", name.strip_suffix("__serialized").unwrap());
185					quote!(#real_name: {
186						let string: String = row.get(#i)?;
187						::turbosql::serde_json::from_str(&string)?
188					})
189				} else {
190					quote!(#name: row.get(#i)?)
191				}
192			})
193			.collect::<Vec<_>>();
194
195		Self { row_casters }
196	}
197}
198
/// Unfinished stub: prints `columns` to stdout and always returns `None`.
///
/// The commented-out code sketches the intended behavior: split `name_type` column
/// names into members and build casters from them.
/// NOTE(review): no caller appears in this file; confirm before relying on it.
fn _extract_explicit_members(columns: &[String]) -> Option<MembersAndCasters> {
	// let members: Vec<_> = columns
	//  .iter()
	//  .enumerate()
	//  .filter_map(|(i, cap)| {
	//   let col_name = cap;
	//   let mut parts: Vec<_> = col_name.split('_').collect();
	//   if parts.len() < 2 {
	//    return None;
	//   }
	//   let ty = parts.pop()?;
	//   let name = parts.join("_");
	//   Some((format_ident!("{}", name), format_ident!("{}", ty), i))
	//  })
	//  .collect();

	println!("extractexplicitmembers: {:#?}", columns);

	// MembersAndCasters::create(members);
	// parse_str::<Ident>

	None
}
222
223fn _extract_stmt_members(stmt: &Statement, span: &Span) -> MembersAndCasters {
224	let members: Vec<_> = stmt
225		.column_names()
226		.iter()
227		.enumerate()
228		.map(|(i, col_name)| {
229			let mut parts: Vec<_> = col_name.split('_').collect();
230
231			if parts.len() < 2 {
232				abort!(
233					span,
234					"SQL column name {:#?} must include a type annotation, e.g. {}_String or {}_i64.",
235					col_name,
236					col_name,
237					col_name
238				)
239			}
240
241			let ty = parts.pop().unwrap();
242
243			match ty {
244				"i64" | "String" => (),
245				_ => abort!(span, "Invalid type annotation \"_{}\", try e.g. _String or _i64.", ty),
246			}
247
248			let name = parts.join("_");
249
250			(format_ident!("{}", name), format_ident!("{}", ty), i)
251		})
252		.collect();
253
254	// let struct_members: Vec<_> = members.iter().map(|(name, ty, _i)| quote!(#name: #ty)).collect();
255	// let row_casters: Vec<_> =
256	//  members.iter().map(|(name, _ty, i)| quote!(#name: row.get(#i).unwrap())).collect();
257
258	MembersAndCasters::create(members)
259}
260
// Statement kinds, passed as the const generic parameter of `do_parse_tokens`:
// `select!` (may prepend `SELECT` or generate `SELECT ... FROM`), `execute!`, and
// `update!` (may prepend `UPDATE`).
const SELECT: usize = 1;
const EXECUTE: usize = 2;
const UPDATE: usize = 3;
264
265#[derive(Debug)]
266struct StatementInfo {
267	positional_parameter_count: usize,
268	named_parameters: Vec<String>,
269	column_names: Vec<String>,
270}
271
272impl StatementInfo {
273	fn membersandcasters(&self) -> Result<MembersAndCasters> {
274		Ok(MembersAndCasters::create(
275			self
276				.column_names
277				.iter()
278				.enumerate()
279				.map(|(i, col_name)| Ok((parse_str::<Ident>(col_name)?, format_ident!("None"), i)))
280				.collect::<Result<Vec<_>>>()?,
281		))
282	}
283}
284
/// Contents of `migrations.toml`.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
struct MigrationsToml {
	/// Ordered SQL migrations; `create` only ever appends. Entries starting with `--`
	/// are treated as commented out and skipped when replayed.
	migrations_append_only: Option<Vec<String>>,
	/// Pretty-printed schema from replaying the migrations; written by `create` for
	/// humans to read.
	output_generated_schema_for_your_information_do_not_edit: Option<String>,
	/// Column info for every derived struct, keyed by table name.
	output_generated_tables_do_not_edit: Option<BTreeMap<String, MiniTable>>,
}
291
292fn migrations_to_tempdb(migrations: &[String]) -> Connection {
293	let tempdb = rusqlite::Connection::open_in_memory().unwrap();
294
295	tempdb
296		.execute_batch(if cfg!(feature = "sqlite-compat-no-strict-tables") {
297			"CREATE TABLE _turbosql_migrations (rowid INTEGER PRIMARY KEY, migration TEXT NOT NULL);"
298		} else {
299			"CREATE TABLE _turbosql_migrations (rowid INTEGER PRIMARY KEY, migration TEXT NOT NULL) STRICT;"
300		})
301		.unwrap();
302
303	migrations.iter().filter(|m| !m.starts_with("--")).for_each(|m| {
304		match tempdb.execute(m, params![]) {
305			Ok(_) => (),
306			Err(rusqlite::Error::ExecuteReturnedResults) => (), // pragmas
307			Err(e) => abort_call_site!("Running migrations on temp db: {:?} {:?}", m, e),
308		}
309	});
310
311	tempdb
312}
313
314fn migrations_to_schema(migrations: &[String]) -> rusqlite::Result<String> {
315	Ok(
316		migrations_to_tempdb(migrations)
317			.prepare("SELECT sql FROM sqlite_master WHERE type='table' ORDER BY sql")?
318			.query_map(params![], |row| row.get(0))?
319			.collect::<rusqlite::Result<Vec<String>>>()?
320			.join("\n"),
321	)
322}
323
324fn read_migrations_toml() -> MigrationsToml {
325	let lockfile = std::fs::File::create(std::env::temp_dir().join("migrations.toml.lock")).unwrap();
326	fs2::FileExt::lock_exclusive(&lockfile).unwrap();
327
328	let migrations_toml_path = migrations_toml_path();
329	let migrations_toml_path_lossy = migrations_toml_path.to_string_lossy();
330
331	match migrations_toml_path.exists() {
332		true => {
333			let toml_str = std::fs::read_to_string(&migrations_toml_path)
334				.unwrap_or_else(|e| abort_call_site!("Unable to read {}: {:?}", migrations_toml_path_lossy, e));
335
336			let toml_decoded: MigrationsToml = toml::from_str(&toml_str).unwrap_or_else(|e| {
337				abort_call_site!("Unable to decode toml in {}: {:?}", migrations_toml_path_lossy, e)
338			});
339
340			toml_decoded
341		}
342		false => MigrationsToml::default(),
343	}
344}
345
346fn validate_sql<S: AsRef<str>>(sql: S) -> rusqlite::Result<StatementInfo> {
347	let tempdb = migrations_to_tempdb(&read_migrations_toml().migrations_append_only.unwrap());
348
349	let stmt = tempdb.prepare(sql.as_ref())?;
350	let mut positional_parameter_count = stmt.parameter_count();
351	let mut named_parameters = Vec::new();
352
353	for idx in 1..=stmt.parameter_count() {
354		if let Some(parameter_name) = stmt.parameter_name(idx) {
355			named_parameters.push(parameter_name.to_string());
356			positional_parameter_count -= 1;
357		}
358	}
359
360	Ok(StatementInfo {
361		positional_parameter_count,
362		named_parameters,
363		column_names: stmt.column_names().into_iter().map(str::to_string).collect(),
364	})
365}
366
367fn validate_sql_or_abort<S: AsRef<str> + std::fmt::Debug>(sql: S) -> StatementInfo {
368	validate_sql(sql.as_ref()).unwrap_or_else(|e| {
369		abort_call_site!(r#"Error validating SQL statement: "{}". SQL: {:?}"#, e, sql)
370	})
371}
372
373fn parse_interpolated_sql(
374	input: ParseStream,
375) -> Result<(Option<String>, Punctuated<Expr, Token![,]>, proc_macro2::TokenStream)> {
376	if input.is_empty() {
377		return Ok(Default::default());
378	}
379
380	let sql_token = input.parse::<LitStr>()?;
381	let mut sql = sql_token.value();
382
383	if let Ok(comma_token) = input.parse::<Token![,]>() {
384		let punctuated_tokens = input.parse_terminated(Expr::parse, Token![,])?;
385		return Ok((
386			Some(sql),
387			punctuated_tokens.clone(),
388			quote!(#sql_token #comma_token #punctuated_tokens),
389		));
390	}
391
392	let mut params = Punctuated::new();
393
394	loop {
395		while input.peek(LitStr) {
396			sql.push(' ');
397			sql.push_str(&input.parse::<LitStr>()?.value());
398		}
399
400		if input.is_empty() {
401			break;
402		}
403
404		params.push(input.parse()?);
405		sql.push_str(" ? ");
406		if input.parse::<Token![,]>().is_ok() {
407			sql.push(',');
408		}
409	}
410
411	Ok((Some(sql), params, Default::default()))
412}
413
414fn do_parse_tokens<const T: usize>(input: ParseStream) -> Result<proc_macro2::TokenStream> {
415	let span = input.span();
416	let result_type = input.parse::<ResultType>().ok();
417	let (mut sql, params, sql_and_parameters_tokens) = parse_interpolated_sql(input)?;
418
419	// Try validating SQL as-is
420
421	let mut stmt_info = sql.as_ref().and_then(|s| validate_sql(s).ok());
422
423	// Try adding SELECT or UPDATE if it didn't validate
424
425	if let (true, Some(orig_sql), None) = (T == SELECT || T == UPDATE, &sql, &stmt_info) {
426		let sql_modified = format!("{} {}", if T == SELECT { "SELECT" } else { "UPDATE" }, orig_sql);
427		if let Ok(stmt_info_modified) = validate_sql(&sql_modified) {
428			sql = Some(sql_modified);
429			stmt_info = Some(stmt_info_modified);
430		}
431	}
432
433	// rust-analyzer just gets the result type
434
435	if is_rust_analyzer() {
436		return Ok(if let Some(ty) = result_type {
437			let ty = ty.ty()?;
438			quote!(Ok({let x: #ty = Default::default(); x}))
439		} else {
440			quote!()
441		});
442	}
443
444	// If it still didn't validate and we have a non-inferred result type, try adding SELECT ... FROM
445
446	let (sql, stmt_info) = match (result_type.clone(), sql, stmt_info) {
447		//
448		// Have result type and SQL did not validate, try generating SELECT ... FROM
449		(Some(ResultType { content, .. }), sql, None) => {
450			let table_type = content.table_ident().to_string();
451			let table_name = table_type.to_lowercase();
452
453			let table = {
454				let t = match read_migrations_toml().output_generated_tables_do_not_edit {
455					Some(m) => m.get(&table_name).cloned(),
456					None => None,
457				};
458
459				match t {
460					Some(t) => t,
461					None => {
462						abort!(
463							span,
464							"Table {:?} not found. Does struct {} exist and have #[derive(Turbosql, Default)]?",
465							table_name,
466							table_type
467						);
468					}
469				}
470			};
471
472			let column_names_str = table
473				.columns
474				.iter()
475				.filter_map(|c| {
476					if match &content {
477						Content::SingleColumn(col) => col.column == c.name,
478						_ => true,
479					} {
480						if c.sql_type.starts_with("TEXT")
481							&& c.rust_type != "Option < String >"
482							&& c.rust_type != "String"
483						{
484							Some(format!("{} AS {}__serialized", c.name, c.name))
485						} else {
486							Some(c.name.clone())
487						}
488					} else {
489						None
490					}
491				})
492				.collect::<Vec<_>>()
493				.join(", ");
494
495			let sql = format!("SELECT {} FROM {} {}", column_names_str, table_name, sql.unwrap_or_default());
496
497			(sql.clone(), validate_sql_or_abort(sql))
498		}
499
500		// Otherwise, everything is validated, just unwrap
501		(_, Some(sql), Some(stmt_info)) => (sql, stmt_info),
502		(_, Some(sql), _) => abort_call_site!("sql did not validate: {}", sql),
503		_ => abort_call_site!("no predicate and no result type found"),
504	};
505
506	if !stmt_info.named_parameters.is_empty() {
507		abort_call_site!("SQLite named parameters not currently supported.");
508	}
509
510	if params.len() != stmt_info.positional_parameter_count {
511		abort!(
512			sql_and_parameters_tokens,
513			"Expected {} bound parameter{}, got {}: {:?}",
514			stmt_info.positional_parameter_count,
515			if stmt_info.positional_parameter_count == 1 { "" } else { "s" },
516			params.len(),
517			sql
518		);
519	}
520
521	if !input.is_empty() {
522		return Err(input.error("Expected parameters"));
523	}
524
525	let params = if stmt_info.named_parameters.is_empty() {
526		quote! { ::turbosql::params![#params] }
527	} else {
528		let param_quotes = stmt_info.named_parameters.iter().map(|p| {
529			let var_ident = format_ident!("{}", &p[1..]);
530			quote!(#p: &#var_ident,)
531		});
532		quote! { ::turbosql::named_params![#(#param_quotes),*] }
533	};
534
535	// if we return no columns, this should be an execute or update
536
537	if stmt_info.column_names.is_empty() {
538		if T == SELECT {
539			abort_call_site!("No rows returned from SQL, use execute! instead.");
540		}
541
542		return Ok(quote! {
543		{
544			(|| -> std::result::Result<usize, ::turbosql::Error> {
545				::turbosql::__TURBOSQL_DB.with(|db| {
546					let db = db.borrow_mut();
547					let mut stmt = db.prepare_cached(#sql)?;
548					Ok(stmt.execute(#params)?)
549				})
550			})()
551		}
552		});
553	}
554
555	if T != SELECT {
556		abort_call_site!("Rows returned from SQL, use select! instead.");
557	}
558
559	// Decide how to handle selected rows depending on content type.
560
561	let Some(ResultType { container, content }) = result_type else {
562		abort_call_site!("unknown result_type")
563	};
564
565	let handle_row;
566	let content_ty;
567
568	if content.is_primitive() {
569		handle_row = quote! { row.get(0)? };
570		content_ty = quote! { #content };
571	} else {
572		let MembersAndCasters { row_casters } = stmt_info
573			.membersandcasters()
574			.unwrap_or_else(|_| abort_call_site!("stmt_info.membersandcasters failed"));
575
576		handle_row = quote! {
577			#[allow(clippy::needless_update)]
578			#content {
579				#(#row_casters),*,
580				..Default::default()
581			}
582		};
583		content_ty = quote! { #content };
584	}
585
586	// Content::SingleColumn(col) => {
587	// 	handle_row = quote! { row.get(0)? };
588	// 	let rust_ty: Type = parse_str(
589	// 		&read_migrations_toml()
590	// 			.output_generated_tables_do_not_edit
591	// 			.unwrap()
592	// 			.get(&col.table.to_string().to_lowercase())
593	// 			.unwrap()
594	// 			.columns
595	// 			.iter()
596	// 			.find(|c| col.column == c.name)
597	// 			.unwrap()
598	// 			.rust_type,
599	// 	)?;
600	// 	content_ty = quote! { #rust_ty };
601	// }
602	// };
603
604	// Decide how to handle the iterator over rows depending on container.
605
606	let return_type;
607	let handle_result;
608
609	match container {
610		Some(ident) if ident == "Vec" => {
611			return_type = quote! { Vec<#content_ty> };
612			handle_result = quote! { result.collect::<Vec<_>>() };
613		}
614		Some(ident) if ident == "Option" => {
615			return_type = quote! { Option<#content_ty> };
616			handle_result = quote! { result.next() };
617		}
618		None => {
619			return_type = quote! { #content_ty };
620			handle_result =
621				quote! { result.next().ok_or(::turbosql::rusqlite::Error::QueryReturnedNoRows)? };
622		}
623		_ => unreachable!("No other container type is possible"),
624	}
625
626	// Put it all together
627
628	Ok(quote! {
629		{
630			(|| -> std::result::Result<#return_type, ::turbosql::Error> {
631				::turbosql::__TURBOSQL_DB.with(|db| {
632					let db = db.borrow_mut();
633					let mut stmt = db.prepare_cached(#sql)?;
634					let mut result = stmt.query_and_then(#params, |row| -> std::result::Result<#content_ty, ::turbosql::Error> {
635						Ok(#handle_row)
636					})?.flatten();
637					Ok(#handle_result)
638				})
639			})()
640		}
641	})
642}
643
/// Executes a SQL statement. On success, returns the number of rows that were changed or inserted or deleted.
///
/// The SQL is validated at compile time against the schema in `migrations.toml`.
/// The macro expands to an expression of type `Result<usize, turbosql::Error>`.
/// Using it with a statement that returns rows is a compile error (use `select!`).
#[proc_macro]
#[proc_macro_error]
pub fn execute(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
	parse_macro_input!(input with do_parse_tokens::<EXECUTE>).into()
}
650
/// Executes a SQL SELECT statement with optionally automatic `SELECT` and `FROM` clauses.
///
/// The leading type sets the result, e.g. `i64`, `Person`, `Option<Person>`, `Vec<Person>`:
/// - a bare type yields a single row, or a `QueryReturnedNoRows` error if there is none;
/// - `Option<_>` yields at most one row;
/// - `Vec<_>` yields every row.
///
/// The SQL is validated at compile time. A statement returning no rows is a compile
/// error (use `execute!`).
#[proc_macro]
#[proc_macro_error]
pub fn select(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
	parse_macro_input!(input with do_parse_tokens::<SELECT>).into()
}
657
/// Executes a SQL statement with optionally automatic `UPDATE` clause. On success, returns the number of rows that were changed.
///
/// `UPDATE` is prepended when the SQL does not validate as written. The macro expands
/// to an expression of type `Result<usize, turbosql::Error>`.
#[proc_macro]
#[proc_macro_error]
pub fn update(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
	parse_macro_input!(input with do_parse_tokens::<UPDATE>).into()
}
664
665/// Derive this on a `struct` to create a corresponding SQLite table and `Turbosql` trait methods.
666#[proc_macro_derive(Turbosql, attributes(turbosql))]
667#[proc_macro_error]
668pub fn turbosql_derive_macro(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
669	let input = parse_macro_input!(input as DeriveInput);
670	let table_span = input.span();
671	let table_ident = input.ident;
672	let table_name = table_ident.to_string().to_lowercase();
673
674	let dummy_impl = quote! {
675		impl ::turbosql::Turbosql for #table_ident {
676			fn insert(&self) -> Result<i64, ::turbosql::Error> { unimplemented!() }
677			fn insert_mut(&mut self) -> Result<i64, ::turbosql::Error> { unimplemented!() }
678			fn insert_batch<T: AsRef<Self>>(rows: &[T]) -> Result<(), ::turbosql::Error> { unimplemented!() }
679			fn update(&self) -> Result<usize, ::turbosql::Error> { unimplemented!() }
680			fn update_batch<T: AsRef<Self>>(rows: &[T]) -> Result<(), ::turbosql::Error> { unimplemented!() }
681			fn delete(&self) -> Result<usize, ::turbosql::Error> { unimplemented!() }
682		}
683	};
684
685	if is_rust_analyzer() {
686		return dummy_impl.into();
687	}
688
689	proc_macro_error::set_dummy(dummy_impl);
690
691	let Data::Struct(DataStruct { fields: Fields::Named(ref fields), .. }) = input.data else {
692		abort_call_site!("The Turbosql derive macro only supports structs with named fields");
693	};
694
695	let table = Table {
696		ident: table_ident,
697		span: table_span,
698		name: table_name.clone(),
699		columns: extract_columns(fields),
700	};
701
702	let minitable = MiniTable {
703		name: table_name,
704		columns: table
705			.columns
706			.iter()
707			.map(|c| MiniColumn {
708				name: c.name.clone(),
709				sql_type: c.sql_type.to_string(),
710				rust_type: c.rust_type.clone(),
711			})
712			.collect(),
713	};
714
715	create(&table, &minitable);
716
717	// create trait functions
718
719	let fn_insert = insert::insert(&table);
720	let fn_update = update::update(&table);
721	let fn_delete = delete::delete(&table);
722
723	// output tokenstream
724
725	quote! {
726		#[cfg(not(target_arch = "wasm32"))]
727		impl ::turbosql::Turbosql for #table {
728			#fn_insert
729			#fn_update
730			#fn_delete
731		}
732	}
733	.into()
734}
735
736/// Convert syn::FieldsNamed to our Column type.
737fn extract_columns(fields: &FieldsNamed) -> Vec<Column> {
738	let columns = fields
739		.named
740		.iter()
741		.filter_map(|f| {
742			let mut sql_default = None;
743
744			for attr in &f.attrs {
745				if attr.path().is_ident("turbosql") {
746					for meta in attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated).unwrap() {
747						match &meta {
748							Meta::Path(path) if path.is_ident("skip") => {
749								return None;
750							}
751							Meta::NameValue(MetaNameValue { path, value: Expr::Lit(ExprLit { lit, .. }), .. })
752								if path.is_ident("sql_default") =>
753							{
754								match lit {
755									Lit::Bool(value) => sql_default = Some(value.value().to_string()),
756									Lit::Int(token) => sql_default = Some(token.to_string()),
757									Lit::Float(token) => sql_default = Some(token.to_string()),
758									Lit::Str(token) => sql_default = Some(format!("'{}'", token.value())),
759									Lit::ByteStr(token) => {
760										use std::fmt::Write;
761										sql_default = Some(format!(
762											"x'{}'",
763											token.value().iter().fold(String::new(), |mut o, b| {
764												let _ = write!(o, "{b:02x}");
765												o
766											})
767										))
768									}
769									_ => (),
770								}
771							}
772							_ => (),
773						}
774					}
775				}
776			}
777
778			let ident = &f.ident;
779			let name = ident.as_ref().unwrap().to_string();
780
781			let ty = &f.ty;
782			let ty_str = quote!(#ty).to_string();
783
784			let (sql_type, default_example) = match (
785				name.as_str(),
786				if OPTION_U8_ARRAY_RE.is_match(&ty_str) {
787					"Option < [u8; _] >"
788				} else if U8_ARRAY_RE.is_match(&ty_str) {
789					"[u8; _]"
790				} else {
791					ty_str.as_str()
792				},
793			) {
794				("rowid", "Option < i64 >") => ("INTEGER PRIMARY KEY", "NULL"),
795				(_, "Option < i8 >") => ("INTEGER", "0"),
796				(_, "i8") => ("INTEGER NOT NULL", "0"),
797				(_, "Option < u8 >") => ("INTEGER", "0"),
798				(_, "u8") => ("INTEGER NOT NULL", "0"),
799				(_, "Option < i16 >") => ("INTEGER", "0"),
800				(_, "i16") => ("INTEGER NOT NULL", "0"),
801				(_, "Option < u16 >") => ("INTEGER", "0"),
802				(_, "u16") => ("INTEGER NOT NULL", "0"),
803				(_, "Option < i32 >") => ("INTEGER", "0"),
804				(_, "i32") => ("INTEGER NOT NULL", "0"),
805				(_, "Option < u32 >") => ("INTEGER", "0"),
806				(_, "u32") => ("INTEGER NOT NULL", "0"),
807				(_, "Option < i64 >") => ("INTEGER", "0"),
808				(_, "i64") => ("INTEGER NOT NULL", "0"),
809				(_, "Option < u64 >") => abort!(ty, SQLITE_U64_ERROR),
810				(_, "u64") => abort!(ty, SQLITE_U64_ERROR),
811				(_, "Option < f64 >") => ("REAL", "0.0"),
812				(_, "f64") => ("REAL NOT NULL", "0.0"),
813				(_, "Option < f32 >") => ("REAL", "0.0"),
814				(_, "f32") => ("REAL NOT NULL", "0.0"),
815				(_, "Option < bool >") => ("INTEGER", "false"),
816				(_, "bool") => ("INTEGER NOT NULL", "false"),
817				(_, "Option < String >") => ("TEXT", "\"\""),
818				(_, "String") => ("TEXT NOT NULL", "''"),
819				// SELECT LENGTH(blob_column) ... will be null if blob is null
820				(_, "Option < Blob >") => ("BLOB", "b\"\""),
821				(_, "Blob") => ("BLOB NOT NULL", "''"),
822				(_, "Option < Vec < u8 > >") => ("BLOB", "b\"\""),
823				(_, "Vec < u8 >") => ("BLOB NOT NULL", "''"),
824				(_, "Option < [u8; _] >") => ("BLOB", "b\"\\x00\\x01\\xff\""),
825				(_, "[u8; _]") => ("BLOB NOT NULL", "''"),
826				_ => {
827					// JSON-serialized
828					if ty_str.starts_with("Option < ") {
829						("TEXT", "\"\"")
830					} else {
831						("TEXT NOT NULL", "''")
832					}
833				}
834			};
835
836			if sql_default.is_none() && sql_type.ends_with("NOT NULL") {
837				sql_default = Some(default_example.into());
838				// abort!(f, "Field `{}` has no default value and is not nullable. Either add a default value with e.g. #[turbosql(sql_default = {default_example})] or make it Option<{ty_str}>.", name);
839			}
840
841			Some(Column {
842				ident: ident.clone().unwrap(),
843				span: ty.span(),
844				rust_type: ty_str,
845				name,
846				sql_type,
847				sql_default,
848			})
849		})
850		.collect::<Vec<_>>();
851
852	// Make sure we have a rowid column, to keep a persistent rowid for blob access.
853	// see https://www.sqlite.org/rowidtable.html :
854	// "If the rowid is not aliased by INTEGER PRIMARY KEY then it is not persistent and might change."
855
856	if !matches!(
857		columns.iter().find(|c| c.name == "rowid"),
858		Some(Column { sql_type: "INTEGER PRIMARY KEY", .. })
859	) {
860		abort_call_site!("derive(Turbosql) structs must include a 'rowid: Option<i64>' field")
861	};
862
863	columns
864}
865
866use std::fs;
867
868/// CREATE TABLE
869fn create(table: &Table, minitable: &MiniTable) {
870	// create the migrations
871
872	let sql = makesql_create(table);
873
874	rusqlite::Connection::open_in_memory().unwrap().execute(&sql, params![]).unwrap_or_else(|e| {
875		abort_call_site!("Error validating auto-generated CREATE TABLE statement: {} {:#?}", sql, e)
876	});
877
878	let target_migrations = make_migrations(table);
879
880	// read in the existing migrations from toml
881
882	let lockfile = std::fs::File::create(std::env::temp_dir().join("migrations.toml.lock")).unwrap();
883	fs2::FileExt::lock_exclusive(&lockfile).unwrap();
884
885	let migrations_toml_path = migrations_toml_path();
886	let migrations_toml_path_lossy = migrations_toml_path.to_string_lossy();
887
888	let old_toml_str = if migrations_toml_path.exists() {
889		fs::read_to_string(&migrations_toml_path)
890			.unwrap_or_else(|e| abort_call_site!("Unable to read {}: {:?}", migrations_toml_path_lossy, e))
891	} else {
892		String::new()
893	};
894
895	let source_migrations_toml: MigrationsToml = toml::from_str(&old_toml_str).unwrap_or_else(|e| {
896		abort_call_site!("Unable to decode toml in {}: {:?}", migrations_toml_path_lossy, e)
897	});
898
899	// add any migrations that aren't already present
900
901	let mut output_migrations = source_migrations_toml.migrations_append_only.unwrap_or_default();
902
903	#[allow(clippy::search_is_some)]
904	target_migrations.iter().for_each(|target_m| {
905		if output_migrations
906			.iter()
907			.find(|source_m| (source_m == &target_m) || (source_m == &&format!("--{}", target_m)))
908			.is_none()
909		{
910			output_migrations.push(target_m.clone());
911		}
912	});
913
914	let mut tables = source_migrations_toml.output_generated_tables_do_not_edit.unwrap_or_default();
915	tables.insert(table.name.clone(), minitable.clone());
916
917	// save to toml
918
919	let mut new_toml_str = String::new();
920	let serializer = toml::Serializer::pretty(&mut new_toml_str);
921
922	MigrationsToml {
923		output_generated_schema_for_your_information_do_not_edit: Some(format!(
924			"  {}\n",
925			migrations_to_schema(&output_migrations)
926				.unwrap()
927				.replace('\n', "\n  ")
928				.replace('(', "(\n    ")
929				.replace(", ", ",\n    ")
930				.replace(')', "\n  )")
931		)),
932		migrations_append_only: Some(output_migrations),
933		output_generated_tables_do_not_edit: Some(tables),
934	}
935	.serialize(serializer)
936	.unwrap_or_else(|e| abort_call_site!("Unable to serialize migrations toml: {:?}", e));
937
938	let new_toml_str = format!("# This file is auto-generated by Turbosql.\n# It is used to create and apply automatic schema migrations.\n# It should be checked into source control.\n# Modifying it by hand may be dangerous; see the docs.\n\n{}", &new_toml_str);
939
940	// Only write migrations.toml file if it has actually changed;
941	// this keeps file mod date clean so cargo doesn't pathologically rebuild
942
943	if old_toml_str.replace("\r\n", "\n") != new_toml_str {
944		#[cfg(not(feature = "test"))]
945		if std::env::var("CI").is_ok() || std::env::var("TURBOSQL_LOCKED_MODE").is_ok() {
946			abort_call_site!("Change in `{}` detected with CI or TURBOSQL_LOCKED_MODE environment variable set. Make sure your `migrations.toml` file is committed and up-to-date.", migrations_toml_path_lossy);
947		};
948		fs::write(&migrations_toml_path, new_toml_str)
949			.unwrap_or_else(|e| abort_call_site!("Unable to write {}: {:?}", migrations_toml_path_lossy, e));
950	}
951}
952
953fn makesql_create(table: &Table) -> String {
954	let columns =
955		table.columns.iter().map(|c| format!("{} {}", c.name, c.sql_type)).collect::<Vec<_>>().join(",");
956
957	if cfg!(feature = "sqlite-compat-no-strict-tables") {
958		format!("CREATE TABLE {} ({})", table.name, columns)
959	} else {
960		format!("CREATE TABLE {} ({}) STRICT", table.name, columns)
961	}
962}
963
964fn make_migrations(table: &Table) -> Vec<String> {
965	let sql = if cfg!(feature = "sqlite-compat-no-strict-tables") {
966		format!("CREATE TABLE {} (rowid INTEGER PRIMARY KEY)", table.name)
967	} else {
968		format!("CREATE TABLE {} (rowid INTEGER PRIMARY KEY) STRICT", table.name)
969	};
970
971	let mut vec = vec![sql];
972
973	let mut alters = table
974		.columns
975		.iter()
976		.filter_map(|c| match (c.name.as_str(), c.sql_type, &c.sql_default) {
977			("rowid", "INTEGER PRIMARY KEY", _) => None,
978			(_, _, None) => Some(format!("ALTER TABLE {} ADD COLUMN {} {}", table.name, c.name, c.sql_type)),
979			(_, _, Some(sql_default)) => Some(format!(
980				"ALTER TABLE {} ADD COLUMN {} {} DEFAULT {}",
981				table.name, c.name, c.sql_type, sql_default
982			)),
983		})
984		.collect::<Vec<_>>();
985
986	vec.append(&mut alters);
987
988	vec
989}
990
991fn migrations_toml_path() -> std::path::PathBuf {
992	let mut path = std::path::PathBuf::from(env!("OUT_DIR"));
993	while path.file_name() != Some(std::ffi::OsStr::new("target")) {
994		path.pop();
995	}
996	path.pop();
997	path.push(MIGRATIONS_FILENAME);
998	path
999}
1000
/// Returns `true` when the current executable's file stem starts with `rust-analyzer`.
///
/// Callers use this to skip compile-time work and emit placeholder code. Returns
/// `false` if the executable path or its file name cannot be determined, instead of
/// panicking inside a macro.
fn is_rust_analyzer() -> bool {
	std::env::current_exe()
		.ok()
		.and_then(|exe| exe.file_stem().map(|stem| stem.to_string_lossy().starts_with("rust-analyzer")))
		.unwrap_or(false)
}
1009
#[cfg(test)]
mod tests {
	use super::*;

	/// Option fields map to nullable SQL types, `rowid` becomes the primary key,
	/// and `#[turbosql(skip)]` fields are dropped.
	#[test]
	fn test_extract_columns() {
		let fields_named = parse_quote!({
			rowid: Option<i64>,
			name: Option<String>,
			age: Option<u32>,
			awesomeness: Option<f64>,
			#[turbosql(skip)]
			skipped: Option<bool>
		});

		let columns = extract_columns(&fields_named);

		let expected = [
			("rowid", "Option < i64 >", "INTEGER PRIMARY KEY"),
			("name", "Option < String >", "TEXT"),
			("age", "Option < u32 >", "INTEGER"),
			("awesomeness", "Option < f64 >", "REAL"),
		];

		assert_eq!(columns.len(), expected.len());
		for (column, &(name, rust_type, sql_type)) in columns.iter().zip(expected.iter()) {
			assert_eq!(column.name, name);
			assert_eq!(column.rust_type, rust_type);
			assert_eq!(column.sql_type, sql_type);
		}

		assert!(!columns.iter().any(|c| c.name == "skipped"));
	}
}