1#![forbid(unsafe_code)]
6
/// Compile-error message emitted by `extract_columns` when a `#[derive(Turbosql)]`
/// struct has a `u64` or `Option<u64>` field.
const SQLITE_U64_ERROR: &str = r##"SQLite cannot natively store unsigned 64-bit integers, so Turbosql does not support u64 fields. Use i64, u32, f64, or a string or binary format instead. (see https://github.com/trevyn/turbosql/issues/3 )"##;
8
9use once_cell::sync::Lazy;
10use proc_macro2::Span;
11use proc_macro_error::{abort, abort_call_site, proc_macro_error};
12use quote::{format_ident, quote, ToTokens};
13use rusqlite::{params, Connection, Statement};
14use serde::{Deserialize, Serialize};
15use std::collections::BTreeMap;
16use syn::{
17 parse::{Parse, ParseStream},
18 punctuated::Punctuated,
19 spanned::Spanned,
20 *,
21};
22
/// File name of the migrations file; `migrations_toml_path` places it in the
/// directory that contains `target`.
#[cfg(not(feature = "test"))]
const MIGRATIONS_FILENAME: &str = "migrations.toml";
/// Under the `test` feature a separate file is used so test runs do not touch the
/// real migrations file.
#[cfg(feature = "test")]
const MIGRATIONS_FILENAME: &str = "test.migrations.toml";
27
28mod delete;
29mod insert;
30mod update;
31
/// A `#[derive(Turbosql)]` struct as seen by the derive macro.
#[derive(Debug, Clone)]
struct Table {
  /// The struct's identifier (also what `ToTokens` emits).
  ident: Ident,
  /// Span of the whole derive input.
  span: Span,
  /// SQL table name: the struct name, lowercased.
  name: String,
  /// Columns extracted from the struct's named fields by `extract_columns`.
  columns: Vec<Column>,
}
39
/// Serializable summary of a [`Table`], stored in `migrations.toml` under
/// `output_generated_tables_do_not_edit` and read back by `select!` expansion to
/// build column lists.
#[derive(Clone, Serialize, Deserialize, Debug)]
struct MiniTable {
  /// Lowercased table name.
  name: String,
  /// Column summaries, in struct field order.
  columns: Vec<MiniColumn>,
}
45
46impl ToTokens for Table {
47 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
48 let ident = &self.ident;
49 tokens.extend(quote!(#ident));
50 }
51}
52
/// One database column derived from a named struct field by `extract_columns`.
#[derive(Debug, Clone)]
struct Column {
  /// The field's identifier.
  ident: Ident,
  /// Span of the field's type.
  span: Span,
  /// Column name: the field name as a string.
  name: String,
  /// The field type rendered as a token string, e.g. `"Option < i64 >"`.
  rust_type: String,
  /// SQLite column type and constraints, e.g. `"INTEGER NOT NULL"`.
  sql_type: &'static str,
  /// SQL literal for the `DEFAULT` clause of this column's `ADD COLUMN` migration, if any.
  sql_default: Option<String>,
}
62
/// Serializable summary of a [`Column`], stored inside a [`MiniTable`].
#[derive(Clone, Serialize, Deserialize, Debug)]
struct MiniColumn {
  /// Column name.
  name: String,
  /// Field type as a token string (e.g. `"Option < String >"`).
  rust_type: String,
  /// SQLite column type and constraints.
  sql_type: String,
}
69
/// Matches the token string of `Option<[u8; N]>` for any length `N`, so
/// `extract_columns` can map every such type to one nullable `BLOB` arm.
static OPTION_U8_ARRAY_RE: Lazy<regex::Regex> =
  Lazy::new(|| regex::Regex::new(r"^Option\s*<\s*\[\s*u8\s*;\s*\d+\s*\]\s*>$").unwrap());
/// Matches the token string of `[u8; N]` for any length `N` (a `BLOB NOT NULL` column).
static U8_ARRAY_RE: Lazy<regex::Regex> =
  Lazy::new(|| regex::Regex::new(r"^\[\s*u8\s*;\s*\d+\s*\]$").unwrap());
74
/// A reference to one column of one table.
///
/// NOTE(review): nothing in this file constructs this (the `Content::SingleColumn`
/// variant is `#[allow(dead_code)]`); presumably reserved for future syntax.
#[derive(Clone, Debug)]
struct SingleColumn {
  /// Table (struct) identifier.
  table: Ident,
  /// Column identifier, compared against column names in `do_parse_tokens`.
  column: Ident,
}
80
/// The element type a `select!` query produces.
#[derive(Clone, Debug)]
enum Content {
  /// A Rust type: either a primitive (read from column 0) or a table struct.
  Type(Type),
  /// A single table column. Not produced by any parser here; several `Content`
  /// methods treat it as `unimplemented!()`.
  #[allow(dead_code)]
  SingleColumn(SingleColumn),
}
87
88impl ToTokens for Content {
89 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
90 match self {
91 Content::Type(ty) => ty.to_tokens(tokens),
92 Content::SingleColumn(_) => unimplemented!(),
93 }
94 }
95}
96
97impl Content {
98 fn ty(&self) -> Result<Type> {
99 match self {
100 Content::Type(ty) => Ok(ty.clone()),
101 Content::SingleColumn(_) => unimplemented!(), }
103 }
104 fn is_primitive(&self) -> bool {
105 match self {
106 Content::Type(Type::Path(TypePath { path, .. })) => {
107 ["f32", "f64", "i8", "u8", "i16", "u16", "i32", "u32", "i64", "String", "bool", "Blob"]
108 .contains(&path.segments.last().unwrap().ident.to_string().as_str())
109 }
110 Content::Type(Type::Array(_)) => true,
111 _ => abort_call_site!("Unsupported content type in is_primitive {:#?}", self),
112 }
113 }
114 fn table_ident(&self) -> &Ident {
115 match self {
116 Content::Type(Type::Path(TypePath { path, .. })) => &path.segments.last().unwrap().ident,
117 Content::SingleColumn(c) => &c.table,
118 _ => abort_call_site!("Unsupported content type in table_ident {:#?}", self),
119 }
120 }
121}
122
/// The result type requested by a macro invocation: an optional container identifier
/// (`Vec` or `Option`, the only ones its `Parse` impl accepts) around a [`Content`].
#[derive(Clone, Debug)]
struct ResultType {
  /// `Vec` / `Option`, or `None` for a bare content type.
  container: Option<Ident>,
  /// The element type.
  content: Content,
}
128
129impl ResultType {
130 fn ty(&self) -> Result<Type> {
131 let content = self.content.ty()?;
132 match self.container {
133 Some(ref container) => Ok(parse_quote!(#container < #content >)),
134 None => Ok(content),
135 }
136 }
137}
138
impl Parse for ResultType {
  /// Parses a leading result type.
  ///
  /// `Vec<T>` / `Option<T>` (judged by the last path segment) become container + `T`;
  /// any other unqualified path type or an array type becomes bare content. If no type
  /// can be parsed at all, the `Err` is returned. A parsed type of another form
  /// (including qualified `<T as X>::Y` paths), or a `Vec`/`Option` without exactly one
  /// type argument, aborts compilation.
  fn parse(input: ParseStream) -> Result<Self> {
    input.parse::<Type>().map(|ty| match ty {
      Type::Path(TypePath { qself: None, ref path }) => {
        // Only the last segment matters, so `std::vec::Vec<T>` is treated like `Vec<T>`.
        let path = path.segments.last().unwrap();
        let mut container = None;
        let content = match &path.ident {
          ident if ["Vec", "Option"].contains(&ident.to_string().as_str()) => {
            container = Some(ident.clone());
            match path.arguments {
              PathArguments::AngleBracketed(AngleBracketedGenericArguments { ref args, .. }) => {
                // Exactly one generic argument is required.
                (args.len() != 1).then(|| abort!(args, "Expected 1 argument, found {}", args.len()));
                match args.first().unwrap() {
                  GenericArgument::Type(ty) => Content::Type(ty.clone()),
                  ty => abort!(ty, "Expected type, found {:?}", ty),
                }
              }
              ref args => abort!(args, "Expected angle bracketed arguments, found {:?}", args),
            }
          }
          _ => Content::Type(ty),
        };
        ResultType { container, content }
      }
      Type::Array(array) => ResultType { container: None, content: Content::Type(Type::Array(array)) },
      ty => abort!(ty, "Unknown type {:?}", ty),
    })
  }
}
168
/// Generated code for building a struct value from a query row.
#[derive(Debug)]
struct MembersAndCasters {
  /// One `field: <expr>` initializer per result column, spliced into a struct literal.
  row_casters: Vec<proc_macro2::TokenStream>,
}
175
176impl MembersAndCasters {
177 fn create(members: Vec<(Ident, Ident, usize)>) -> MembersAndCasters {
178 let row_casters = members
180 .iter()
181 .map(|(name, _ty, i)| {
182 if name.to_string().ends_with("__serialized") {
183 let name = name.to_string();
184 let real_name = format_ident!("{}", name.strip_suffix("__serialized").unwrap());
185 quote!(#real_name: {
186 let string: String = row.get(#i)?;
187 ::turbosql::serde_json::from_str(&string)?
188 })
189 } else {
190 quote!(#name: row.get(#i)?)
191 }
192 })
193 .collect::<Vec<_>>();
194
195 Self { row_casters }
196 }
197}
198
199fn _extract_explicit_members(columns: &[String]) -> Option<MembersAndCasters> {
200 println!("extractexplicitmembers: {:#?}", columns);
216
217 None
221}
222
223fn _extract_stmt_members(stmt: &Statement, span: &Span) -> MembersAndCasters {
224 let members: Vec<_> = stmt
225 .column_names()
226 .iter()
227 .enumerate()
228 .map(|(i, col_name)| {
229 let mut parts: Vec<_> = col_name.split('_').collect();
230
231 if parts.len() < 2 {
232 abort!(
233 span,
234 "SQL column name {:#?} must include a type annotation, e.g. {}_String or {}_i64.",
235 col_name,
236 col_name,
237 col_name
238 )
239 }
240
241 let ty = parts.pop().unwrap();
242
243 match ty {
244 "i64" | "String" => (),
245 _ => abort!(span, "Invalid type annotation \"_{}\", try e.g. _String or _i64.", ty),
246 }
247
248 let name = parts.join("_");
249
250 (format_ident!("{}", name), format_ident!("{}", ty), i)
251 })
252 .collect();
253
254 MembersAndCasters::create(members)
259}
260
// Values for the `T` const parameter of `do_parse_tokens`, identifying which macro
// (`select!`, `execute!`, `update!`) is being expanded.
const SELECT: usize = 1;
const EXECUTE: usize = 2;
const UPDATE: usize = 3;
264
/// What `validate_sql` learned by preparing a statement.
#[derive(Debug)]
struct StatementInfo {
  /// Number of parameters for which SQLite reports no name.
  positional_parameter_count: usize,
  /// Parameter names as reported by SQLite (`do_parse_tokens` strips the first
  /// character to derive a variable name).
  named_parameters: Vec<String>,
  /// Result column names, in order; empty for statements that return no rows.
  column_names: Vec<String>,
}
271
272impl StatementInfo {
273 fn membersandcasters(&self) -> Result<MembersAndCasters> {
274 Ok(MembersAndCasters::create(
275 self
276 .column_names
277 .iter()
278 .enumerate()
279 .map(|(i, col_name)| Ok((parse_str::<Ident>(col_name)?, format_ident!("None"), i)))
280 .collect::<Result<Vec<_>>>()?,
281 ))
282 }
283}
284
/// In-memory form of `migrations.toml`; field names are the TOML keys.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
struct MigrationsToml {
  /// SQL migrations applied in order. `create` only ever appends; entries starting
  /// with `--` are skipped when building the validation database.
  migrations_append_only: Option<Vec<String>>,
  /// Pretty-printed schema resulting from the migrations; informational only.
  output_generated_schema_for_your_information_do_not_edit: Option<String>,
  /// Table summaries keyed by lowercased table name, used to expand `select!` result types.
  output_generated_tables_do_not_edit: Option<BTreeMap<String, MiniTable>>,
}
291
292fn migrations_to_tempdb(migrations: &[String]) -> Connection {
293 let tempdb = rusqlite::Connection::open_in_memory().unwrap();
294
295 tempdb
296 .execute_batch(if cfg!(feature = "sqlite-compat-no-strict-tables") {
297 "CREATE TABLE _turbosql_migrations (rowid INTEGER PRIMARY KEY, migration TEXT NOT NULL);"
298 } else {
299 "CREATE TABLE _turbosql_migrations (rowid INTEGER PRIMARY KEY, migration TEXT NOT NULL) STRICT;"
300 })
301 .unwrap();
302
303 migrations.iter().filter(|m| !m.starts_with("--")).for_each(|m| {
304 match tempdb.execute(m, params![]) {
305 Ok(_) => (),
306 Err(rusqlite::Error::ExecuteReturnedResults) => (), Err(e) => abort_call_site!("Running migrations on temp db: {:?} {:?}", m, e),
308 }
309 });
310
311 tempdb
312}
313
314fn migrations_to_schema(migrations: &[String]) -> rusqlite::Result<String> {
315 Ok(
316 migrations_to_tempdb(migrations)
317 .prepare("SELECT sql FROM sqlite_master WHERE type='table' ORDER BY sql")?
318 .query_map(params![], |row| row.get(0))?
319 .collect::<rusqlite::Result<Vec<String>>>()?
320 .join("\n"),
321 )
322}
323
324fn read_migrations_toml() -> MigrationsToml {
325 let lockfile = std::fs::File::create(std::env::temp_dir().join("migrations.toml.lock")).unwrap();
326 fs2::FileExt::lock_exclusive(&lockfile).unwrap();
327
328 let migrations_toml_path = migrations_toml_path();
329 let migrations_toml_path_lossy = migrations_toml_path.to_string_lossy();
330
331 match migrations_toml_path.exists() {
332 true => {
333 let toml_str = std::fs::read_to_string(&migrations_toml_path)
334 .unwrap_or_else(|e| abort_call_site!("Unable to read {}: {:?}", migrations_toml_path_lossy, e));
335
336 let toml_decoded: MigrationsToml = toml::from_str(&toml_str).unwrap_or_else(|e| {
337 abort_call_site!("Unable to decode toml in {}: {:?}", migrations_toml_path_lossy, e)
338 });
339
340 toml_decoded
341 }
342 false => MigrationsToml::default(),
343 }
344}
345
346fn validate_sql<S: AsRef<str>>(sql: S) -> rusqlite::Result<StatementInfo> {
347 let tempdb = migrations_to_tempdb(&read_migrations_toml().migrations_append_only.unwrap());
348
349 let stmt = tempdb.prepare(sql.as_ref())?;
350 let mut positional_parameter_count = stmt.parameter_count();
351 let mut named_parameters = Vec::new();
352
353 for idx in 1..=stmt.parameter_count() {
354 if let Some(parameter_name) = stmt.parameter_name(idx) {
355 named_parameters.push(parameter_name.to_string());
356 positional_parameter_count -= 1;
357 }
358 }
359
360 Ok(StatementInfo {
361 positional_parameter_count,
362 named_parameters,
363 column_names: stmt.column_names().into_iter().map(str::to_string).collect(),
364 })
365}
366
367fn validate_sql_or_abort<S: AsRef<str> + std::fmt::Debug>(sql: S) -> StatementInfo {
368 validate_sql(sql.as_ref()).unwrap_or_else(|e| {
369 abort_call_site!(r#"Error validating SQL statement: "{}". SQL: {:?}"#, e, sql)
370 })
371}
372
373fn parse_interpolated_sql(
374 input: ParseStream,
375) -> Result<(Option<String>, Punctuated<Expr, Token![,]>, proc_macro2::TokenStream)> {
376 if input.is_empty() {
377 return Ok(Default::default());
378 }
379
380 let sql_token = input.parse::<LitStr>()?;
381 let mut sql = sql_token.value();
382
383 if let Ok(comma_token) = input.parse::<Token![,]>() {
384 let punctuated_tokens = input.parse_terminated(Expr::parse, Token![,])?;
385 return Ok((
386 Some(sql),
387 punctuated_tokens.clone(),
388 quote!(#sql_token #comma_token #punctuated_tokens),
389 ));
390 }
391
392 let mut params = Punctuated::new();
393
394 loop {
395 while input.peek(LitStr) {
396 sql.push(' ');
397 sql.push_str(&input.parse::<LitStr>()?.value());
398 }
399
400 if input.is_empty() {
401 break;
402 }
403
404 params.push(input.parse()?);
405 sql.push_str(" ? ");
406 if input.parse::<Token![,]>().is_ok() {
407 sql.push(',');
408 }
409 }
410
411 Ok((Some(sql), params, Default::default()))
412}
413
/// Shared implementation of `execute!`, `select!` and `update!`.
///
/// `T` is `SELECT`, `EXECUTE` or `UPDATE`. The input is an optional leading result type
/// (e.g. `Vec<Person>`, `i64`), followed by SQL in one of the forms accepted by
/// `parse_interpolated_sql`. The SQL is validated against an in-memory database built
/// from `migrations.toml`, and the bound-parameter count is checked.
///
/// The returned tokens are an immediately-invoked closure that runs the statement on
/// `::turbosql::__TURBOSQL_DB` and yields a `Result<_, ::turbosql::Error>`:
/// * `usize` (rows affected) for statements returning no columns;
/// * otherwise `T`, `Option<T>` or `Vec<T>` for `select!`.
fn do_parse_tokens<const T: usize>(input: ParseStream) -> Result<proc_macro2::TokenStream> {
  let span = input.span();
  // The result type is optional (e.g. `execute!` normally has none).
  let result_type = input.parse::<ResultType>().ok();
  let (mut sql, params, sql_and_parameters_tokens) = parse_interpolated_sql(input)?;

  let mut stmt_info = sql.as_ref().and_then(|s| validate_sql(s).ok());

  // If the SQL did not validate as written, retry with an implicit leading `SELECT`
  // (for `select!`) or `UPDATE` (for `update!`) keyword.
  if let (true, Some(orig_sql), None) = (T == SELECT || T == UPDATE, &sql, &stmt_info) {
    let sql_modified = format!("{} {}", if T == SELECT { "SELECT" } else { "UPDATE" }, orig_sql);
    if let Ok(stmt_info_modified) = validate_sql(&sql_modified) {
      sql = Some(sql_modified);
      stmt_info = Some(stmt_info_modified);
    }
  }

  // Under rust-analyzer, skip validation/codegen and emit a `Default` placeholder of the
  // requested type (or nothing).
  if is_rust_analyzer() {
    return Ok(if let Some(ty) = result_type {
      let ty = ty.ty()?;
      quote!(Ok({let x: #ty = Default::default(); x}))
    } else {
      quote!()
    });
  }

  let (sql, stmt_info) = match (result_type.clone(), sql, stmt_info) {
    // A result type but no valid full statement: treat the SQL (if any) as trailing
    // clauses and build `SELECT <columns> FROM <table> <sql>` from the table recorded
    // in migrations.toml.
    (Some(ResultType { content, .. }), sql, None) => {
      let table_type = content.table_ident().to_string();
      let table_name = table_type.to_lowercase();

      let table = {
        let t = match read_migrations_toml().output_generated_tables_do_not_edit {
          Some(m) => m.get(&table_name).cloned(),
          None => None,
        };

        match t {
          Some(t) => t,
          None => {
            abort!(
              span,
              "Table {:?} not found. Does struct {} exist and have #[derive(Turbosql, Default)]?",
              table_name,
              table_type
            );
          }
        }
      };

      let column_names_str = table
        .columns
        .iter()
        .filter_map(|c| {
          if match &content {
            Content::SingleColumn(col) => col.column == c.name,
            _ => true,
          } {
            // TEXT columns that are not `String`/`Option<String>` are aliased with a
            // `__serialized` suffix so `MembersAndCasters::create` decodes them as JSON.
            if c.sql_type.starts_with("TEXT")
              && c.rust_type != "Option < String >"
              && c.rust_type != "String"
            {
              Some(format!("{} AS {}__serialized", c.name, c.name))
            } else {
              Some(c.name.clone())
            }
          } else {
            None
          }
        })
        .collect::<Vec<_>>()
        .join(", ");

      let sql = format!("SELECT {} FROM {} {}", column_names_str, table_name, sql.unwrap_or_default());

      (sql.clone(), validate_sql_or_abort(sql))
    }

    (_, Some(sql), Some(stmt_info)) => (sql, stmt_info),
    (_, Some(sql), _) => abort_call_site!("sql did not validate: {}", sql),
    _ => abort_call_site!("no predicate and no result type found"),
  };

  if !stmt_info.named_parameters.is_empty() {
    abort_call_site!("SQLite named parameters not currently supported.");
  }

  if params.len() != stmt_info.positional_parameter_count {
    abort!(
      sql_and_parameters_tokens,
      "Expected {} bound parameter{}, got {}: {:?}",
      stmt_info.positional_parameter_count,
      if stmt_info.positional_parameter_count == 1 { "" } else { "s" },
      params.len(),
      sql
    );
  }

  if !input.is_empty() {
    return Err(input.error("Expected parameters"));
  }

  // NOTE(review): the `named_params!` branch is currently unreachable, because named
  // parameters are rejected above.
  let params = if stmt_info.named_parameters.is_empty() {
    quote! { ::turbosql::params![#params] }
  } else {
    let param_quotes = stmt_info.named_parameters.iter().map(|p| {
      let var_ident = format_ident!("{}", &p[1..]);
      quote!(#p: &#var_ident,)
    });
    quote! { ::turbosql::named_params![#(#param_quotes),*] }
  };

  // No result columns: generate an `execute` returning the affected-row count.
  if stmt_info.column_names.is_empty() {
    if T == SELECT {
      abort_call_site!("No rows returned from SQL, use execute! instead.");
    }

    return Ok(quote! {
      {
        (|| -> std::result::Result<usize, ::turbosql::Error> {
          ::turbosql::__TURBOSQL_DB.with(|db| {
            let db = db.borrow_mut();
            let mut stmt = db.prepare_cached(#sql)?;
            Ok(stmt.execute(#params)?)
          })
        })()
      }
    });
  }

  if T != SELECT {
    abort_call_site!("Rows returned from SQL, use select! instead.");
  }

  let Some(ResultType { container, content }) = result_type else {
    abort_call_site!("unknown result_type")
  };

  let handle_row;
  let content_ty;

  // Primitives are read from column 0. Anything else is built as a struct literal from
  // the result columns, with remaining fields filled by `Default`.
  if content.is_primitive() {
    handle_row = quote! { row.get(0)? };
    content_ty = quote! { #content };
  } else {
    let MembersAndCasters { row_casters } = stmt_info
      .membersandcasters()
      .unwrap_or_else(|_| abort_call_site!("stmt_info.membersandcasters failed"));

    handle_row = quote! {
      #[allow(clippy::needless_update)]
      #content {
        #(#row_casters),*,
        ..Default::default()
      }
    };
    content_ty = quote! { #content };
  }

  let return_type;
  let handle_result;

  match container {
    Some(ident) if ident == "Vec" => {
      return_type = quote! { Vec<#content_ty> };
      handle_result = quote! { result.collect::<Vec<_>>() };
    }
    Some(ident) if ident == "Option" => {
      return_type = quote! { Option<#content_ty> };
      handle_result = quote! { result.next() };
    }
    None => {
      return_type = quote! { #content_ty };
      handle_result =
        quote! { result.next().ok_or(::turbosql::rusqlite::Error::QueryReturnedNoRows)? };
    }
    _ => unreachable!("No other container type is possible"),
  }

  // NOTE(review): the generated `.flatten()` skips rows whose conversion returned `Err`.
  // Row/deserialization errors are therefore silently dropped, and for a bare result
  // type they surface as `QueryReturnedNoRows`. Confirm this is intended.
  Ok(quote! {
    {
      (|| -> std::result::Result<#return_type, ::turbosql::Error> {
        ::turbosql::__TURBOSQL_DB.with(|db| {
          let db = db.borrow_mut();
          let mut stmt = db.prepare_cached(#sql)?;
          let mut result = stmt.query_and_then(#params, |row| -> std::result::Result<#content_ty, ::turbosql::Error> {
            Ok(#handle_row)
          })?.flatten();
          Ok(#handle_result)
        })
      })()
    }
  })
}
643
644#[proc_macro]
646#[proc_macro_error]
647pub fn execute(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
648 parse_macro_input!(input with do_parse_tokens::<EXECUTE>).into()
649}
650
651#[proc_macro]
653#[proc_macro_error]
654pub fn select(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
655 parse_macro_input!(input with do_parse_tokens::<SELECT>).into()
656}
657
658#[proc_macro]
660#[proc_macro_error]
661pub fn update(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
662 parse_macro_input!(input with do_parse_tokens::<UPDATE>).into()
663}
664
/// `#[derive(Turbosql)]`: maps a struct with named fields to a SQLite table.
///
/// Extracts columns from the fields and records the table and its migrations in
/// `migrations.toml` (via `create`). Emits `impl ::turbosql::Turbosql`, gated on
/// `not(target_arch = "wasm32")`, containing the code from `insert::insert`,
/// `update::update` and `delete::delete`. Under rust-analyzer only a stub impl whose
/// methods are `unimplemented!()` is emitted.
#[proc_macro_derive(Turbosql, attributes(turbosql))]
#[proc_macro_error]
pub fn turbosql_derive_macro(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
  let input = parse_macro_input!(input as DeriveInput);
  let table_span = input.span();
  let table_ident = input.ident;
  let table_name = table_ident.to_string().to_lowercase();

  let dummy_impl = quote! {
    impl ::turbosql::Turbosql for #table_ident {
      fn insert(&self) -> Result<i64, ::turbosql::Error> { unimplemented!() }
      fn insert_mut(&mut self) -> Result<i64, ::turbosql::Error> { unimplemented!() }
      fn insert_batch<T: AsRef<Self>>(rows: &[T]) -> Result<(), ::turbosql::Error> { unimplemented!() }
      fn update(&self) -> Result<usize, ::turbosql::Error> { unimplemented!() }
      fn update_batch<T: AsRef<Self>>(rows: &[T]) -> Result<(), ::turbosql::Error> { unimplemented!() }
      fn delete(&self) -> Result<usize, ::turbosql::Error> { unimplemented!() }
    }
  };

  if is_rust_analyzer() {
    return dummy_impl.into();
  }

  // From here on, any abort also emits the stub impl alongside the diagnostic.
  proc_macro_error::set_dummy(dummy_impl);

  let Data::Struct(DataStruct { fields: Fields::Named(ref fields), .. }) = input.data else {
    abort_call_site!("The Turbosql derive macro only supports structs with named fields");
  };

  let table = Table {
    ident: table_ident,
    span: table_span,
    name: table_name.clone(),
    columns: extract_columns(fields),
  };

  let minitable = MiniTable {
    name: table_name,
    columns: table
      .columns
      .iter()
      .map(|c| MiniColumn {
        name: c.name.clone(),
        sql_type: c.sql_type.to_string(),
        rust_type: c.rust_type.clone(),
      })
      .collect(),
  };

  // Side effect: validates the schema and may rewrite migrations.toml.
  create(&table, &minitable);

  let fn_insert = insert::insert(&table);
  let fn_update = update::update(&table);
  let fn_delete = delete::delete(&table);

  quote! {
    #[cfg(not(target_arch = "wasm32"))]
    impl ::turbosql::Turbosql for #table {
      #fn_insert
      #fn_update
      #fn_delete
    }
  }
  .into()
}
735
736fn extract_columns(fields: &FieldsNamed) -> Vec<Column> {
738 let columns = fields
739 .named
740 .iter()
741 .filter_map(|f| {
742 let mut sql_default = None;
743
744 for attr in &f.attrs {
745 if attr.path().is_ident("turbosql") {
746 for meta in attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated).unwrap() {
747 match &meta {
748 Meta::Path(path) if path.is_ident("skip") => {
749 return None;
750 }
751 Meta::NameValue(MetaNameValue { path, value: Expr::Lit(ExprLit { lit, .. }), .. })
752 if path.is_ident("sql_default") =>
753 {
754 match lit {
755 Lit::Bool(value) => sql_default = Some(value.value().to_string()),
756 Lit::Int(token) => sql_default = Some(token.to_string()),
757 Lit::Float(token) => sql_default = Some(token.to_string()),
758 Lit::Str(token) => sql_default = Some(format!("'{}'", token.value())),
759 Lit::ByteStr(token) => {
760 use std::fmt::Write;
761 sql_default = Some(format!(
762 "x'{}'",
763 token.value().iter().fold(String::new(), |mut o, b| {
764 let _ = write!(o, "{b:02x}");
765 o
766 })
767 ))
768 }
769 _ => (),
770 }
771 }
772 _ => (),
773 }
774 }
775 }
776 }
777
778 let ident = &f.ident;
779 let name = ident.as_ref().unwrap().to_string();
780
781 let ty = &f.ty;
782 let ty_str = quote!(#ty).to_string();
783
784 let (sql_type, default_example) = match (
785 name.as_str(),
786 if OPTION_U8_ARRAY_RE.is_match(&ty_str) {
787 "Option < [u8; _] >"
788 } else if U8_ARRAY_RE.is_match(&ty_str) {
789 "[u8; _]"
790 } else {
791 ty_str.as_str()
792 },
793 ) {
794 ("rowid", "Option < i64 >") => ("INTEGER PRIMARY KEY", "NULL"),
795 (_, "Option < i8 >") => ("INTEGER", "0"),
796 (_, "i8") => ("INTEGER NOT NULL", "0"),
797 (_, "Option < u8 >") => ("INTEGER", "0"),
798 (_, "u8") => ("INTEGER NOT NULL", "0"),
799 (_, "Option < i16 >") => ("INTEGER", "0"),
800 (_, "i16") => ("INTEGER NOT NULL", "0"),
801 (_, "Option < u16 >") => ("INTEGER", "0"),
802 (_, "u16") => ("INTEGER NOT NULL", "0"),
803 (_, "Option < i32 >") => ("INTEGER", "0"),
804 (_, "i32") => ("INTEGER NOT NULL", "0"),
805 (_, "Option < u32 >") => ("INTEGER", "0"),
806 (_, "u32") => ("INTEGER NOT NULL", "0"),
807 (_, "Option < i64 >") => ("INTEGER", "0"),
808 (_, "i64") => ("INTEGER NOT NULL", "0"),
809 (_, "Option < u64 >") => abort!(ty, SQLITE_U64_ERROR),
810 (_, "u64") => abort!(ty, SQLITE_U64_ERROR),
811 (_, "Option < f64 >") => ("REAL", "0.0"),
812 (_, "f64") => ("REAL NOT NULL", "0.0"),
813 (_, "Option < f32 >") => ("REAL", "0.0"),
814 (_, "f32") => ("REAL NOT NULL", "0.0"),
815 (_, "Option < bool >") => ("INTEGER", "false"),
816 (_, "bool") => ("INTEGER NOT NULL", "false"),
817 (_, "Option < String >") => ("TEXT", "\"\""),
818 (_, "String") => ("TEXT NOT NULL", "''"),
819 (_, "Option < Blob >") => ("BLOB", "b\"\""),
821 (_, "Blob") => ("BLOB NOT NULL", "''"),
822 (_, "Option < Vec < u8 > >") => ("BLOB", "b\"\""),
823 (_, "Vec < u8 >") => ("BLOB NOT NULL", "''"),
824 (_, "Option < [u8; _] >") => ("BLOB", "b\"\\x00\\x01\\xff\""),
825 (_, "[u8; _]") => ("BLOB NOT NULL", "''"),
826 _ => {
827 if ty_str.starts_with("Option < ") {
829 ("TEXT", "\"\"")
830 } else {
831 ("TEXT NOT NULL", "''")
832 }
833 }
834 };
835
836 if sql_default.is_none() && sql_type.ends_with("NOT NULL") {
837 sql_default = Some(default_example.into());
838 }
840
841 Some(Column {
842 ident: ident.clone().unwrap(),
843 span: ty.span(),
844 rust_type: ty_str,
845 name,
846 sql_type,
847 sql_default,
848 })
849 })
850 .collect::<Vec<_>>();
851
852 if !matches!(
857 columns.iter().find(|c| c.name == "rowid"),
858 Some(Column { sql_type: "INTEGER PRIMARY KEY", .. })
859 ) {
860 abort_call_site!("derive(Turbosql) structs must include a 'rowid: Option<i64>' field")
861 };
862
863 columns
864}
865
866use std::fs;
867
/// Validates the generated `CREATE TABLE` for `table`, then merges the table's
/// migrations and its [`MiniTable`] summary into `migrations.toml`.
///
/// Target migrations not already present are appended to `migrations_append_only`.
/// A migration counts as present if it appears verbatim or commented out with a
/// leading `--`. The `output_generated_*` sections are regenerated.
///
/// The file is rewritten only when its contents change. If the `CI` or
/// `TURBOSQL_LOCKED_MODE` environment variable is set (and the `test` feature is off),
/// a change aborts compilation instead. Any I/O, TOML, or SQL validation failure also
/// aborts compilation.
fn create(table: &Table, minitable: &MiniTable) {
  let sql = makesql_create(table);

  // Sanity-check the full CREATE TABLE statement against a fresh in-memory database.
  rusqlite::Connection::open_in_memory().unwrap().execute(&sql, params![]).unwrap_or_else(|e| {
    abort_call_site!("Error validating auto-generated CREATE TABLE statement: {} {:#?}", sql, e)
  });

  let target_migrations = make_migrations(table);

  // The exclusive lock is held until `lockfile` is dropped at the end of this function,
  // so other holders of the same lock file (e.g. `read_migrations_toml`) cannot
  // interleave with this read-modify-write.
  let lockfile = std::fs::File::create(std::env::temp_dir().join("migrations.toml.lock")).unwrap();
  fs2::FileExt::lock_exclusive(&lockfile).unwrap();

  let migrations_toml_path = migrations_toml_path();
  let migrations_toml_path_lossy = migrations_toml_path.to_string_lossy();

  let old_toml_str = if migrations_toml_path.exists() {
    fs::read_to_string(&migrations_toml_path)
      .unwrap_or_else(|e| abort_call_site!("Unable to read {}: {:?}", migrations_toml_path_lossy, e))
  } else {
    String::new()
  };

  let source_migrations_toml: MigrationsToml = toml::from_str(&old_toml_str).unwrap_or_else(|e| {
    abort_call_site!("Unable to decode toml in {}: {:?}", migrations_toml_path_lossy, e)
  });

  let mut output_migrations = source_migrations_toml.migrations_append_only.unwrap_or_default();

  // Append each target migration not already present, either verbatim or commented out.
  #[allow(clippy::search_is_some)]
  target_migrations.iter().for_each(|target_m| {
    if output_migrations
      .iter()
      .find(|source_m| (source_m == &target_m) || (source_m == &&format!("--{}", target_m)))
      .is_none()
    {
      output_migrations.push(target_m.clone());
    }
  });

  let mut tables = source_migrations_toml.output_generated_tables_do_not_edit.unwrap_or_default();
  tables.insert(table.name.clone(), minitable.clone());

  let mut new_toml_str = String::new();
  let serializer = toml::Serializer::pretty(&mut new_toml_str);

  // The schema field is computed from `output_migrations` before that vector is moved
  // into `migrations_append_only` (struct fields are evaluated in source order).
  MigrationsToml {
    output_generated_schema_for_your_information_do_not_edit: Some(format!(
      " {}\n",
      migrations_to_schema(&output_migrations)
        .unwrap()
        .replace('\n', "\n ")
        .replace('(', "(\n ")
        .replace(", ", ",\n ")
        .replace(')', "\n )")
    )),
    migrations_append_only: Some(output_migrations),
    output_generated_tables_do_not_edit: Some(tables),
  }
  .serialize(serializer)
  .unwrap_or_else(|e| abort_call_site!("Unable to serialize migrations toml: {:?}", e));

  let new_toml_str = format!("# This file is auto-generated by Turbosql.\n# It is used to create and apply automatic schema migrations.\n# It should be checked into source control.\n# Modifying it by hand may be dangerous; see the docs.\n\n{}", &new_toml_str);

  // CRLF is normalized so a checkout with Windows line endings is not seen as a change.
  if old_toml_str.replace("\r\n", "\n") != new_toml_str {
    #[cfg(not(feature = "test"))]
    if std::env::var("CI").is_ok() || std::env::var("TURBOSQL_LOCKED_MODE").is_ok() {
      abort_call_site!("Change in `{}` detected with CI or TURBOSQL_LOCKED_MODE environment variable set. Make sure your `migrations.toml` file is committed and up-to-date.", migrations_toml_path_lossy);
    };
    fs::write(&migrations_toml_path, new_toml_str)
      .unwrap_or_else(|e| abort_call_site!("Unable to write {}: {:?}", migrations_toml_path_lossy, e));
  }
}
952
953fn makesql_create(table: &Table) -> String {
954 let columns =
955 table.columns.iter().map(|c| format!("{} {}", c.name, c.sql_type)).collect::<Vec<_>>().join(",");
956
957 if cfg!(feature = "sqlite-compat-no-strict-tables") {
958 format!("CREATE TABLE {} ({})", table.name, columns)
959 } else {
960 format!("CREATE TABLE {} ({}) STRICT", table.name, columns)
961 }
962}
963
964fn make_migrations(table: &Table) -> Vec<String> {
965 let sql = if cfg!(feature = "sqlite-compat-no-strict-tables") {
966 format!("CREATE TABLE {} (rowid INTEGER PRIMARY KEY)", table.name)
967 } else {
968 format!("CREATE TABLE {} (rowid INTEGER PRIMARY KEY) STRICT", table.name)
969 };
970
971 let mut vec = vec![sql];
972
973 let mut alters = table
974 .columns
975 .iter()
976 .filter_map(|c| match (c.name.as_str(), c.sql_type, &c.sql_default) {
977 ("rowid", "INTEGER PRIMARY KEY", _) => None,
978 (_, _, None) => Some(format!("ALTER TABLE {} ADD COLUMN {} {}", table.name, c.name, c.sql_type)),
979 (_, _, Some(sql_default)) => Some(format!(
980 "ALTER TABLE {} ADD COLUMN {} {} DEFAULT {}",
981 table.name, c.name, c.sql_type, sql_default
982 )),
983 })
984 .collect::<Vec<_>>();
985
986 vec.append(&mut alters);
987
988 vec
989}
990
991fn migrations_toml_path() -> std::path::PathBuf {
992 let mut path = std::path::PathBuf::from(env!("OUT_DIR"));
993 while path.file_name() != Some(std::ffi::OsStr::new("target")) {
994 path.pop();
995 }
996 path.pop();
997 path.push(MIGRATIONS_FILENAME);
998 path
999}
1000
/// Returns `true` when the current process is rust-analyzer, i.e. its executable's
/// file stem starts with `rust-analyzer`. In that case the macros emit lightweight
/// placeholder code.
///
/// Returns `false`, rather than panicking, if the executable path or its file stem
/// cannot be determined.
fn is_rust_analyzer() -> bool {
  std::env::current_exe()
    .ok()
    .and_then(|exe| {
      exe.file_stem().map(|stem| stem.to_string_lossy().starts_with("rust-analyzer"))
    })
    .unwrap_or(false)
}
1009
#[cfg(test)]
mod tests {
  use super::*;

  /// Checks the column mapping for common optional field types and that
  /// `#[turbosql(skip)]` fields are dropped.
  #[test]
  fn test_extract_columns() {
    let fields_named = parse_quote!({
      rowid: Option<i64>,
      name: Option<String>,
      age: Option<u32>,
      awesomeness: Option<f64>,
      #[turbosql(skip)]
      skipped: Option<bool>
    });

    let columns = extract_columns(&fields_named);

    let expected = [
      ("rowid", "Option < i64 >", "INTEGER PRIMARY KEY"),
      ("name", "Option < String >", "TEXT"),
      ("age", "Option < u32 >", "INTEGER"),
      ("awesomeness", "Option < f64 >", "REAL"),
    ];

    assert_eq!(columns.len(), expected.len());

    for (column, (name, rust_type, sql_type)) in columns.iter().zip(expected) {
      assert_eq!(column.name, name);
      assert_eq!(column.rust_type, rust_type);
      assert_eq!(column.sql_type, sql_type);
    }

    assert!(columns.iter().all(|c| c.name != "skipped"));
  }
}