1use proc_macro::TokenStream;
10use proc_macro2::TokenStream as TokenStream2;
11use quote::quote;
12use syn::{
13 parse_macro_input, spanned::Spanned, Data, DeriveInput, Fields, GenericArgument, LitStr,
14 PathArguments, Type, TypePath,
15};
16
17#[proc_macro_derive(Model, attributes(rustango))]
19pub fn derive_model(input: TokenStream) -> TokenStream {
20 let input = parse_macro_input!(input as DeriveInput);
21 expand(&input)
22 .unwrap_or_else(syn::Error::into_compile_error)
23 .into()
24}
25
26#[proc_macro]
61pub fn embed_migrations(input: TokenStream) -> TokenStream {
62 expand_embed_migrations(input.into())
63 .unwrap_or_else(syn::Error::into_compile_error)
64 .into()
65}
66
67fn expand_embed_migrations(input: TokenStream2) -> syn::Result<TokenStream2> {
68 let path_str = if input.is_empty() {
70 "./migrations".to_string()
71 } else {
72 let lit: LitStr = syn::parse2(input)?;
73 lit.value()
74 };
75
76 let manifest = std::env::var("CARGO_MANIFEST_DIR").map_err(|_| {
77 syn::Error::new(
78 proc_macro2::Span::call_site(),
79 "embed_migrations! must be invoked during a Cargo build (CARGO_MANIFEST_DIR not set)",
80 )
81 })?;
82 let abs = std::path::Path::new(&manifest).join(&path_str);
83
84 let mut entries: Vec<(String, std::path::PathBuf)> = Vec::new();
85 if abs.is_dir() {
86 let read = std::fs::read_dir(&abs).map_err(|e| {
87 syn::Error::new(
88 proc_macro2::Span::call_site(),
89 format!("embed_migrations!: cannot read {}: {e}", abs.display()),
90 )
91 })?;
92 for entry in read.flatten() {
93 let path = entry.path();
94 if !path.is_file() {
95 continue;
96 }
97 if path.extension().and_then(|s| s.to_str()) != Some("json") {
98 continue;
99 }
100 let Some(stem) = path.file_stem().and_then(|s| s.to_str()) else {
101 continue;
102 };
103 entries.push((stem.to_owned(), path));
104 }
105 }
106 entries.sort_by(|a, b| a.0.cmp(&b.0));
107
108 let mut chain_names: Vec<String> = Vec::with_capacity(entries.len());
121 let mut prev_refs: Vec<(String, Option<String>)> = Vec::with_capacity(entries.len());
122 for (stem, path) in &entries {
123 let raw = std::fs::read_to_string(path).map_err(|e| {
124 syn::Error::new(
125 proc_macro2::Span::call_site(),
126 format!(
127 "embed_migrations!: cannot read {} for chain validation: {e}",
128 path.display()
129 ),
130 )
131 })?;
132 let json: serde_json::Value = serde_json::from_str(&raw).map_err(|e| {
133 syn::Error::new(
134 proc_macro2::Span::call_site(),
135 format!(
136 "embed_migrations!: {} is not valid JSON: {e}",
137 path.display()
138 ),
139 )
140 })?;
141 let name = json
142 .get("name")
143 .and_then(|v| v.as_str())
144 .ok_or_else(|| {
145 syn::Error::new(
146 proc_macro2::Span::call_site(),
147 format!(
148 "embed_migrations!: {} is missing the `name` field",
149 path.display()
150 ),
151 )
152 })?
153 .to_owned();
154 if name != *stem {
155 return Err(syn::Error::new(
156 proc_macro2::Span::call_site(),
157 format!(
158 "embed_migrations!: file stem `{stem}` does not match the migration's \
159 `name` field `{name}` — rename the file or fix the JSON",
160 ),
161 ));
162 }
163 let prev = json
164 .get("prev")
165 .and_then(|v| v.as_str())
166 .map(str::to_owned);
167 chain_names.push(name.clone());
168 prev_refs.push((name, prev));
169 }
170
171 let name_set: std::collections::HashSet<&str> =
172 chain_names.iter().map(String::as_str).collect();
173 for (name, prev) in &prev_refs {
174 if let Some(p) = prev {
175 if !name_set.contains(p.as_str()) {
176 return Err(syn::Error::new(
177 proc_macro2::Span::call_site(),
178 format!(
179 "embed_migrations!: broken migration chain — `{name}` declares \
180 prev=`{p}` but no migration with that name exists in {}",
181 abs.display()
182 ),
183 ));
184 }
185 }
186 }
187
188 let pairs: Vec<TokenStream2> = entries
189 .iter()
190 .map(|(name, path)| {
191 let path_lit = path.display().to_string();
192 quote! { (#name, ::core::include_str!(#path_lit)) }
193 })
194 .collect();
195
196 Ok(quote! {
197 {
198 const __RUSTANGO_EMBEDDED: &[(&'static str, &'static str)] = &[#(#pairs),*];
199 __RUSTANGO_EMBEDDED
200 }
201 })
202}
203
204fn expand(input: &DeriveInput) -> syn::Result<TokenStream2> {
205 let struct_name = &input.ident;
206
207 let Data::Struct(data) = &input.data else {
208 return Err(syn::Error::new_spanned(
209 struct_name,
210 "Model can only be derived on structs",
211 ));
212 };
213 let Fields::Named(named) = &data.fields else {
214 return Err(syn::Error::new_spanned(
215 struct_name,
216 "Model requires a struct with named fields",
217 ));
218 };
219
220 let container = parse_container_attrs(input)?;
221 let table = container
222 .table
223 .unwrap_or_else(|| to_snake_case(&struct_name.to_string()));
224 let model_name = struct_name.to_string();
225
226 let collected = collect_fields(named)?;
227
228 if let Some((ref display, span)) = container.display {
230 if !collected.field_names.iter().any(|n| n == display) {
231 return Err(syn::Error::new(
232 span,
233 format!("`display = \"{display}\"` does not match any field on this struct"),
234 ));
235 }
236 }
237 let display = container.display.map(|(name, _)| name);
238
239 let model_impl = model_impl_tokens(
240 struct_name,
241 &model_name,
242 &table,
243 display.as_deref(),
244 &collected.field_schemas,
245 );
246 let module_ident = column_module_ident(struct_name);
247 let column_consts = column_const_tokens(&module_ident, &collected.column_entries);
248 let inherent_impl = inherent_impl_tokens(
249 struct_name,
250 &collected,
251 collected.primary_key.as_ref(),
252 &column_consts,
253 );
254 let column_module = column_module_tokens(&module_ident, struct_name, &collected.column_entries);
255 let from_row_impl = from_row_impl_tokens(struct_name, &collected.from_row_inits);
256
257 Ok(quote! {
258 #model_impl
259 #inherent_impl
260 #from_row_impl
261 #column_module
262
263 ::rustango::core::inventory::submit! {
264 ::rustango::core::ModelEntry {
265 schema: <#struct_name as ::rustango::core::Model>::SCHEMA,
266 }
267 }
268 })
269}
270
/// Per-field data needed to emit the typed column marker for that field.
struct ColumnEntry {
    /// The struct field's identifier; also used as the name of the associated
    /// column constant and as the stem of the marker type's name.
    ident: syn::Ident,
    /// The field's declared Rust type, emitted as `Column::Value`.
    value_ty: Type,
    /// The field name as a string, emitted as `Column::NAME`.
    name: String,
    /// The database column name, emitted as `Column::COLUMN`.
    column: String,
    /// Tokens for the `FieldType` variant, emitted as `Column::FIELD_TYPE`.
    field_type_tokens: TokenStream2,
}
284
/// Token fragments and metadata accumulated over all fields of the struct,
/// consumed by the various `*_tokens` generators.
struct CollectedFields {
    /// One `FieldSchema { .. }` literal per field, for `ModelSchema::fields`.
    field_schemas: Vec<TokenStream2>,
    /// One `field: try_get(row, "column")?` initializer per field, for `FromRow`.
    from_row_inits: Vec<TokenStream2>,
    /// Column-name literals for every field; used by `insert` when the model
    /// has no `Auto` field.
    insert_columns: Vec<TokenStream2>,
    /// `SqlValue` conversion expressions of `self.<field>` for every field;
    /// used by `insert` when the model has no `Auto` field.
    insert_values: Vec<TokenStream2>,
    /// Statements pushing onto `_columns` / `_values`; used by `insert` when
    /// the model has an `Auto` field. An `Auto` field only pushes when it is
    /// `Auto::Set`.
    insert_pushes: Vec<TokenStream2>,
    /// Column-name literals of the `Auto` fields, placed in `returning`.
    returning_cols: Vec<TokenStream2>,
    /// Statements that write returned values back into `self.<auto field>`.
    auto_assigns: Vec<TokenStream2>,
    /// `(field ident, column name)` of each `Auto` field; used to write
    /// returned values back into rows after `bulk_insert`.
    auto_field_idents: Vec<(syn::Ident, String)>,
    /// Per-row value pushes for the non-`Auto` fields only.
    bulk_pushes_no_auto: Vec<TokenStream2>,
    /// Per-row value pushes for every field.
    bulk_pushes_all: Vec<TokenStream2>,
    /// Column-name literals of the non-`Auto` fields only.
    bulk_columns_no_auto: Vec<TokenStream2>,
    /// Column-name literals of every field.
    bulk_columns_all: Vec<TokenStream2>,
    /// One loop per `Auto` field checking that rows after the first agree
    /// with `_first_unset`, returning `BulkAutoMixed` otherwise.
    bulk_auto_uniformity: Vec<TokenStream2>,
    /// Identifier of the first `Auto` field encountered, if any.
    first_auto_ident: Option<syn::Ident>,
    /// Whether any field is `Auto`.
    has_auto: bool,
    /// Whether the primary-key field is `Auto`.
    pk_is_auto: bool,
    /// `Assignment { .. }` literals for every non-primary-key field; used by `save`.
    update_assignments: Vec<TokenStream2>,
    /// `(field ident, column name)` of the primary-key field, if one is marked.
    primary_key: Option<(syn::Ident, String)>,
    /// Data for the typed column markers, one per field.
    column_entries: Vec<ColumnEntry>,
    /// Field names as strings; used to validate the container `display` attribute.
    field_names: Vec<String>,
}
345
346fn collect_fields(named: &syn::FieldsNamed) -> syn::Result<CollectedFields> {
347 let cap = named.named.len();
348 let mut out = CollectedFields {
349 field_schemas: Vec::with_capacity(cap),
350 from_row_inits: Vec::with_capacity(cap),
351 insert_columns: Vec::with_capacity(cap),
352 insert_values: Vec::with_capacity(cap),
353 insert_pushes: Vec::with_capacity(cap),
354 returning_cols: Vec::new(),
355 auto_assigns: Vec::new(),
356 auto_field_idents: Vec::new(),
357 bulk_pushes_no_auto: Vec::with_capacity(cap),
358 bulk_pushes_all: Vec::with_capacity(cap),
359 bulk_columns_no_auto: Vec::with_capacity(cap),
360 bulk_columns_all: Vec::with_capacity(cap),
361 bulk_auto_uniformity: Vec::new(),
362 first_auto_ident: None,
363 has_auto: false,
364 pk_is_auto: false,
365 update_assignments: Vec::with_capacity(cap),
366 primary_key: None,
367 column_entries: Vec::with_capacity(cap),
368 field_names: Vec::with_capacity(cap),
369 };
370
371 for field in &named.named {
372 let info = process_field(field)?;
373 out.field_names.push(info.ident.to_string());
374 out.field_schemas.push(info.schema);
375 out.from_row_inits.push(info.from_row_init);
376 let column = info.column.as_str();
377 let ident = info.ident;
378 out.insert_columns.push(quote!(#column));
379 out.insert_values.push(quote! {
380 ::core::convert::Into::<::rustango::core::SqlValue>::into(
381 ::core::clone::Clone::clone(&self.#ident)
382 )
383 });
384 if info.auto {
385 out.has_auto = true;
386 if out.first_auto_ident.is_none() {
387 out.first_auto_ident = Some(ident.clone());
388 }
389 out.returning_cols.push(quote!(#column));
390 out.auto_field_idents
391 .push((ident.clone(), info.column.clone()));
392 out.auto_assigns.push(quote! {
393 self.#ident = ::rustango::sql::sqlx::Row::try_get(&_returning_row, #column)?;
394 });
395 out.insert_pushes.push(quote! {
396 if let ::rustango::sql::Auto::Set(_v) = &self.#ident {
397 _columns.push(#column);
398 _values.push(::core::convert::Into::<::rustango::core::SqlValue>::into(
399 ::core::clone::Clone::clone(_v)
400 ));
401 }
402 });
403 out.bulk_columns_all.push(quote!(#column));
406 out.bulk_pushes_all.push(quote! {
407 _row_vals.push(::core::convert::Into::<::rustango::core::SqlValue>::into(
408 ::core::clone::Clone::clone(&_row.#ident)
409 ));
410 });
411 let ident_clone = ident.clone();
415 out.bulk_auto_uniformity.push(quote! {
416 for _r in rows.iter().skip(1) {
417 if matches!(_r.#ident_clone, ::rustango::sql::Auto::Unset) != _first_unset {
418 return ::core::result::Result::Err(
419 ::rustango::sql::ExecError::Sql(
420 ::rustango::sql::SqlError::BulkAutoMixed
421 )
422 );
423 }
424 }
425 });
426 } else {
427 out.insert_pushes.push(quote! {
428 _columns.push(#column);
429 _values.push(::core::convert::Into::<::rustango::core::SqlValue>::into(
430 ::core::clone::Clone::clone(&self.#ident)
431 ));
432 });
433 out.bulk_columns_no_auto.push(quote!(#column));
435 out.bulk_columns_all.push(quote!(#column));
436 let push_expr = quote! {
437 _row_vals.push(::core::convert::Into::<::rustango::core::SqlValue>::into(
438 ::core::clone::Clone::clone(&_row.#ident)
439 ));
440 };
441 out.bulk_pushes_no_auto.push(push_expr.clone());
442 out.bulk_pushes_all.push(push_expr);
443 }
444 if info.primary_key {
445 if out.primary_key.is_some() {
446 return Err(syn::Error::new_spanned(
447 field,
448 "only one field may be marked `#[rustango(primary_key)]`",
449 ));
450 }
451 out.primary_key = Some((ident.clone(), info.column.clone()));
452 if info.auto {
453 out.pk_is_auto = true;
454 }
455 } else {
456 out.update_assignments.push(quote! {
457 ::rustango::core::Assignment {
458 column: #column,
459 value: ::core::convert::Into::<::rustango::core::SqlValue>::into(
460 ::core::clone::Clone::clone(&self.#ident)
461 ),
462 }
463 });
464 }
465 out.column_entries.push(ColumnEntry {
466 ident: ident.clone(),
467 value_ty: info.value_ty.clone(),
468 name: ident.to_string(),
469 column: info.column.clone(),
470 field_type_tokens: info.field_type_tokens,
471 });
472 }
473 Ok(out)
474}
475
476fn model_impl_tokens(
477 struct_name: &syn::Ident,
478 model_name: &str,
479 table: &str,
480 display: Option<&str>,
481 field_schemas: &[TokenStream2],
482) -> TokenStream2 {
483 let display_tokens = if let Some(name) = display {
484 quote!(::core::option::Option::Some(#name))
485 } else {
486 quote!(::core::option::Option::None)
487 };
488 quote! {
489 impl ::rustango::core::Model for #struct_name {
490 const SCHEMA: &'static ::rustango::core::ModelSchema = &::rustango::core::ModelSchema {
491 name: #model_name,
492 table: #table,
493 fields: &[ #(#field_schemas),* ],
494 display: #display_tokens,
495 };
496 }
497 }
498}
499
500fn inherent_impl_tokens(
501 struct_name: &syn::Ident,
502 fields: &CollectedFields,
503 primary_key: Option<&(syn::Ident, String)>,
504 column_consts: &TokenStream2,
505) -> TokenStream2 {
506 let save_method = if fields.pk_is_auto {
507 let (pk_ident, pk_column) = primary_key
508 .expect("pk_is_auto implies primary_key is Some");
509 let pk_column_lit = pk_column.as_str();
510 let assignments = &fields.update_assignments;
511 Some(quote! {
512 pub async fn save(
530 &mut self,
531 pool: &::rustango::sql::sqlx::PgPool,
532 ) -> ::core::result::Result<(), ::rustango::sql::ExecError> {
533 if matches!(self.#pk_ident, ::rustango::sql::Auto::Unset) {
534 return self.insert(pool).await;
535 }
536 let _query = ::rustango::core::UpdateQuery {
537 model: <Self as ::rustango::core::Model>::SCHEMA,
538 set: ::std::vec![ #( #assignments ),* ],
539 where_clause: ::rustango::core::WhereExpr::Predicate(
540 ::rustango::core::Filter {
541 column: #pk_column_lit,
542 op: ::rustango::core::Op::Eq,
543 value: ::core::convert::Into::<::rustango::core::SqlValue>::into(
544 ::core::clone::Clone::clone(&self.#pk_ident)
545 ),
546 }
547 ),
548 };
549 let _ = ::rustango::sql::update(pool, &_query).await?;
550 ::core::result::Result::Ok(())
551 }
552 })
553 } else {
554 None
555 };
556
557 let pk_methods = primary_key.map(|(pk_ident, pk_column)| {
558 let pk_column_lit = pk_column.as_str();
559 quote! {
560 pub async fn delete(
568 &self,
569 pool: &::rustango::sql::sqlx::PgPool,
570 ) -> ::core::result::Result<u64, ::rustango::sql::ExecError> {
571 let query = ::rustango::core::DeleteQuery {
572 model: <Self as ::rustango::core::Model>::SCHEMA,
573 where_clause: ::rustango::core::WhereExpr::Predicate(
574 ::rustango::core::Filter {
575 column: #pk_column_lit,
576 op: ::rustango::core::Op::Eq,
577 value: ::core::convert::Into::<::rustango::core::SqlValue>::into(
578 ::core::clone::Clone::clone(&self.#pk_ident)
579 ),
580 }
581 ),
582 };
583 ::rustango::sql::delete(pool, &query).await
584 }
585 }
586 });
587
588 let insert_method = if fields.has_auto {
589 let pushes = &fields.insert_pushes;
590 let returning_cols = &fields.returning_cols;
591 let auto_assigns = &fields.auto_assigns;
592 quote! {
593 pub async fn insert(
602 &mut self,
603 pool: &::rustango::sql::sqlx::PgPool,
604 ) -> ::core::result::Result<(), ::rustango::sql::ExecError> {
605 let mut _columns: ::std::vec::Vec<&'static str> =
606 ::std::vec::Vec::new();
607 let mut _values: ::std::vec::Vec<::rustango::core::SqlValue> =
608 ::std::vec::Vec::new();
609 #( #pushes )*
610 let query = ::rustango::core::InsertQuery {
611 model: <Self as ::rustango::core::Model>::SCHEMA,
612 columns: _columns,
613 values: _values,
614 returning: ::std::vec![ #( #returning_cols ),* ],
615 };
616 let _returning_row = ::rustango::sql::insert_returning(pool, &query).await?;
617 #( #auto_assigns )*
618 ::core::result::Result::Ok(())
619 }
620 }
621 } else {
622 let insert_columns = &fields.insert_columns;
623 let insert_values = &fields.insert_values;
624 quote! {
625 pub async fn insert(
631 &self,
632 pool: &::rustango::sql::sqlx::PgPool,
633 ) -> ::core::result::Result<(), ::rustango::sql::ExecError> {
634 let query = ::rustango::core::InsertQuery {
635 model: <Self as ::rustango::core::Model>::SCHEMA,
636 columns: ::std::vec![ #( #insert_columns ),* ],
637 values: ::std::vec![ #( #insert_values ),* ],
638 returning: ::std::vec::Vec::new(),
639 };
640 ::rustango::sql::insert(pool, &query).await
641 }
642 }
643 };
644
645 let bulk_insert_method = if fields.has_auto {
646 let cols_no_auto = &fields.bulk_columns_no_auto;
647 let cols_all = &fields.bulk_columns_all;
648 let pushes_no_auto = &fields.bulk_pushes_no_auto;
649 let pushes_all = &fields.bulk_pushes_all;
650 let returning_cols = &fields.returning_cols;
651 let auto_assigns_for_row = bulk_auto_assigns_for_row(fields);
652 let uniformity = &fields.bulk_auto_uniformity;
653 let first_auto_ident = fields
654 .first_auto_ident
655 .as_ref()
656 .expect("has_auto implies first_auto_ident is Some");
657 quote! {
658 pub async fn bulk_insert(
672 rows: &mut [Self],
673 pool: &::rustango::sql::sqlx::PgPool,
674 ) -> ::core::result::Result<(), ::rustango::sql::ExecError> {
675 if rows.is_empty() {
676 return ::core::result::Result::Ok(());
677 }
678 let _first_unset = matches!(
679 rows[0].#first_auto_ident,
680 ::rustango::sql::Auto::Unset
681 );
682 #( #uniformity )*
683
684 let mut _all_rows: ::std::vec::Vec<
685 ::std::vec::Vec<::rustango::core::SqlValue>,
686 > = ::std::vec::Vec::with_capacity(rows.len());
687 let _columns: ::std::vec::Vec<&'static str> = if _first_unset {
688 for _row in rows.iter() {
689 let mut _row_vals: ::std::vec::Vec<::rustango::core::SqlValue> =
690 ::std::vec::Vec::new();
691 #( #pushes_no_auto )*
692 _all_rows.push(_row_vals);
693 }
694 ::std::vec![ #( #cols_no_auto ),* ]
695 } else {
696 for _row in rows.iter() {
697 let mut _row_vals: ::std::vec::Vec<::rustango::core::SqlValue> =
698 ::std::vec::Vec::new();
699 #( #pushes_all )*
700 _all_rows.push(_row_vals);
701 }
702 ::std::vec![ #( #cols_all ),* ]
703 };
704
705 let _query = ::rustango::core::BulkInsertQuery {
706 model: <Self as ::rustango::core::Model>::SCHEMA,
707 columns: _columns,
708 rows: _all_rows,
709 returning: ::std::vec![ #( #returning_cols ),* ],
710 };
711 let _returned = ::rustango::sql::bulk_insert(pool, &_query).await?;
712 if _returned.len() != rows.len() {
713 return ::core::result::Result::Err(
714 ::rustango::sql::ExecError::Sql(
715 ::rustango::sql::SqlError::BulkInsertReturningMismatch {
716 expected: rows.len(),
717 actual: _returned.len(),
718 }
719 )
720 );
721 }
722 for (_returning_row, _row_mut) in _returned.iter().zip(rows.iter_mut()) {
723 #auto_assigns_for_row
724 }
725 ::core::result::Result::Ok(())
726 }
727 }
728 } else {
729 let cols_all = &fields.bulk_columns_all;
730 let pushes_all = &fields.bulk_pushes_all;
731 quote! {
732 pub async fn bulk_insert(
742 rows: &[Self],
743 pool: &::rustango::sql::sqlx::PgPool,
744 ) -> ::core::result::Result<(), ::rustango::sql::ExecError> {
745 if rows.is_empty() {
746 return ::core::result::Result::Ok(());
747 }
748 let mut _all_rows: ::std::vec::Vec<
749 ::std::vec::Vec<::rustango::core::SqlValue>,
750 > = ::std::vec::Vec::with_capacity(rows.len());
751 for _row in rows.iter() {
752 let mut _row_vals: ::std::vec::Vec<::rustango::core::SqlValue> =
753 ::std::vec::Vec::new();
754 #( #pushes_all )*
755 _all_rows.push(_row_vals);
756 }
757 let _query = ::rustango::core::BulkInsertQuery {
758 model: <Self as ::rustango::core::Model>::SCHEMA,
759 columns: ::std::vec![ #( #cols_all ),* ],
760 rows: _all_rows,
761 returning: ::std::vec::Vec::new(),
762 };
763 let _ = ::rustango::sql::bulk_insert(pool, &_query).await?;
764 ::core::result::Result::Ok(())
765 }
766 }
767 };
768
769 quote! {
770 impl #struct_name {
771 #[must_use]
773 pub fn objects() -> ::rustango::query::QuerySet<#struct_name> {
774 ::rustango::query::QuerySet::new()
775 }
776
777 #insert_method
778
779 #bulk_insert_method
780
781 #save_method
782
783 #pk_methods
784
785 #column_consts
786 }
787 }
788}
789
790fn bulk_auto_assigns_for_row(fields: &CollectedFields) -> TokenStream2 {
794 let lines = fields.auto_field_idents.iter().map(|(ident, column)| {
795 let col_lit = column.as_str();
796 quote! {
797 _row_mut.#ident = ::rustango::sql::sqlx::Row::try_get(
798 _returning_row,
799 #col_lit,
800 )?;
801 }
802 });
803 quote! { #( #lines )* }
804}
805
806fn column_const_tokens(module_ident: &syn::Ident, entries: &[ColumnEntry]) -> TokenStream2 {
808 let lines = entries.iter().map(|e| {
809 let ident = &e.ident;
810 let col_ty = column_type_ident(ident);
811 quote! {
812 #[allow(non_upper_case_globals)]
813 pub const #ident: #module_ident::#col_ty = #module_ident::#col_ty;
814 }
815 });
816 quote! { #(#lines)* }
817}
818
819fn column_module_tokens(
822 module_ident: &syn::Ident,
823 struct_name: &syn::Ident,
824 entries: &[ColumnEntry],
825) -> TokenStream2 {
826 let items = entries.iter().map(|e| {
827 let col_ty = column_type_ident(&e.ident);
828 let value_ty = &e.value_ty;
829 let name = &e.name;
830 let column = &e.column;
831 let field_type_tokens = &e.field_type_tokens;
832 quote! {
833 #[derive(::core::clone::Clone, ::core::marker::Copy)]
834 pub struct #col_ty;
835
836 impl ::rustango::core::Column for #col_ty {
837 type Model = super::#struct_name;
838 type Value = #value_ty;
839 const NAME: &'static str = #name;
840 const COLUMN: &'static str = #column;
841 const FIELD_TYPE: ::rustango::core::FieldType = #field_type_tokens;
842 }
843 }
844 });
845 quote! {
846 #[doc(hidden)]
847 #[allow(non_camel_case_types, non_snake_case)]
848 pub mod #module_ident {
849 #[allow(unused_imports)]
854 use super::*;
855 #(#items)*
856 }
857 }
858}
859
860fn column_type_ident(field_ident: &syn::Ident) -> syn::Ident {
861 syn::Ident::new(&format!("{field_ident}_col"), field_ident.span())
862}
863
864fn column_module_ident(struct_name: &syn::Ident) -> syn::Ident {
865 syn::Ident::new(
866 &format!("__rustango_cols_{struct_name}"),
867 struct_name.span(),
868 )
869}
870
871fn from_row_impl_tokens(struct_name: &syn::Ident, from_row_inits: &[TokenStream2]) -> TokenStream2 {
872 quote! {
873 impl<'r> ::rustango::sql::sqlx::FromRow<'r, ::rustango::sql::sqlx::postgres::PgRow>
874 for #struct_name
875 {
876 fn from_row(
877 row: &'r ::rustango::sql::sqlx::postgres::PgRow,
878 ) -> ::core::result::Result<Self, ::rustango::sql::sqlx::Error> {
879 ::core::result::Result::Ok(Self {
880 #( #from_row_inits ),*
881 })
882 }
883 }
884 }
885}
886
/// Values parsed from `#[rustango(...)]` attributes placed on the struct itself.
struct ContainerAttrs {
    /// Explicit table name from `table = "..."`, if given.
    table: Option<String>,
    /// Field name from `display = "..."` together with the literal's span,
    /// kept so a later validation error can point at the attribute value.
    display: Option<(String, proc_macro2::Span)>,
}
891
892fn parse_container_attrs(input: &DeriveInput) -> syn::Result<ContainerAttrs> {
893 let mut out = ContainerAttrs {
894 table: None,
895 display: None,
896 };
897 for attr in &input.attrs {
898 if !attr.path().is_ident("rustango") {
899 continue;
900 }
901 attr.parse_nested_meta(|meta| {
902 if meta.path.is_ident("table") {
903 let s: LitStr = meta.value()?.parse()?;
904 out.table = Some(s.value());
905 return Ok(());
906 }
907 if meta.path.is_ident("display") {
908 let s: LitStr = meta.value()?.parse()?;
909 out.display = Some((s.value(), s.span()));
910 return Ok(());
911 }
912 Err(meta.error("unknown rustango container attribute"))
913 })?;
914 }
915 Ok(out)
916}
917
/// Values parsed from `#[rustango(...)]` attributes placed on a single field.
struct FieldAttrs {
    /// Explicit column name from `column = "..."`.
    column: Option<String>,
    /// Set when the bare `primary_key` flag is present.
    primary_key: bool,
    /// Relation target from `fk = "..."`.
    fk: Option<String>,
    /// Relation target from `o2o = "..."`.
    o2o: Option<String>,
    /// Referenced column from `on = "..."`; relation generation falls back to
    /// `"id"` when this is absent.
    on: Option<String>,
    /// Value of `max_length = N`.
    max_length: Option<u32>,
    /// Value of `min = N` (may be negative).
    min: Option<i64>,
    /// Value of `max = N` (may be negative).
    max: Option<i64>,
    /// Value of `default = "..."`, stored verbatim as a string.
    default: Option<String>,
}
929
930fn parse_field_attrs(field: &syn::Field) -> syn::Result<FieldAttrs> {
931 let mut out = FieldAttrs {
932 column: None,
933 primary_key: false,
934 fk: None,
935 o2o: None,
936 on: None,
937 max_length: None,
938 min: None,
939 max: None,
940 default: None,
941 };
942 for attr in &field.attrs {
943 if !attr.path().is_ident("rustango") {
944 continue;
945 }
946 attr.parse_nested_meta(|meta| {
947 if meta.path.is_ident("column") {
948 let s: LitStr = meta.value()?.parse()?;
949 out.column = Some(s.value());
950 return Ok(());
951 }
952 if meta.path.is_ident("primary_key") {
953 out.primary_key = true;
954 return Ok(());
955 }
956 if meta.path.is_ident("fk") {
957 let s: LitStr = meta.value()?.parse()?;
958 out.fk = Some(s.value());
959 return Ok(());
960 }
961 if meta.path.is_ident("o2o") {
962 let s: LitStr = meta.value()?.parse()?;
963 out.o2o = Some(s.value());
964 return Ok(());
965 }
966 if meta.path.is_ident("on") {
967 let s: LitStr = meta.value()?.parse()?;
968 out.on = Some(s.value());
969 return Ok(());
970 }
971 if meta.path.is_ident("max_length") {
972 let lit: syn::LitInt = meta.value()?.parse()?;
973 out.max_length = Some(lit.base10_parse::<u32>()?);
974 return Ok(());
975 }
976 if meta.path.is_ident("min") {
977 out.min = Some(parse_signed_i64(&meta)?);
978 return Ok(());
979 }
980 if meta.path.is_ident("max") {
981 out.max = Some(parse_signed_i64(&meta)?);
982 return Ok(());
983 }
984 if meta.path.is_ident("default") {
985 let s: LitStr = meta.value()?.parse()?;
986 out.default = Some(s.value());
987 return Ok(());
988 }
989 Err(meta.error("unknown rustango field attribute"))
990 })?;
991 }
992 Ok(out)
993}
994
995fn parse_signed_i64(meta: &syn::meta::ParseNestedMeta<'_>) -> syn::Result<i64> {
997 let expr: syn::Expr = meta.value()?.parse()?;
998 match expr {
999 syn::Expr::Lit(syn::ExprLit {
1000 lit: syn::Lit::Int(lit),
1001 ..
1002 }) => lit.base10_parse::<i64>(),
1003 syn::Expr::Unary(syn::ExprUnary {
1004 op: syn::UnOp::Neg(_),
1005 expr,
1006 ..
1007 }) => {
1008 if let syn::Expr::Lit(syn::ExprLit {
1009 lit: syn::Lit::Int(lit),
1010 ..
1011 }) = *expr
1012 {
1013 let v: i64 = lit.base10_parse()?;
1014 Ok(-v)
1015 } else {
1016 Err(syn::Error::new_spanned(expr, "expected integer literal"))
1017 }
1018 }
1019 other => Err(syn::Error::new_spanned(
1020 other,
1021 "expected integer literal (signed)",
1022 )),
1023 }
1024}
1025
/// Everything `process_field` derives from one struct field; borrows from the
/// field it was built from.
struct FieldInfo<'a> {
    /// The field's identifier.
    ident: &'a syn::Ident,
    /// Database column name: the `column = "..."` override or the field name.
    column: String,
    /// Whether the field is marked `#[rustango(primary_key)]`.
    primary_key: bool,
    /// Whether the field's type is wrapped in `Auto<..>`.
    auto: bool,
    /// The field's declared type, exactly as written.
    value_ty: &'a Type,
    /// Tokens naming the `FieldType` variant detected for this field.
    field_type_tokens: TokenStream2,
    /// The complete `FieldSchema { .. }` literal for this field.
    schema: TokenStream2,
    /// The `field: try_get(row, "column")?` initializer for `FromRow`.
    from_row_init: TokenStream2,
}
1042
1043fn process_field(field: &syn::Field) -> syn::Result<FieldInfo<'_>> {
1044 let attrs = parse_field_attrs(field)?;
1045 let ident = field
1046 .ident
1047 .as_ref()
1048 .ok_or_else(|| syn::Error::new(field.span(), "tuple structs are not supported"))?;
1049 let name = ident.to_string();
1050 let column = attrs.column.clone().unwrap_or_else(|| name.clone());
1051 let primary_key = attrs.primary_key;
1052 let DetectedType {
1053 kind,
1054 nullable,
1055 auto,
1056 fk_inner,
1057 } = detect_type(&field.ty)?;
1058 check_bound_compatibility(field, &attrs, kind)?;
1059 if auto && !primary_key {
1060 return Err(syn::Error::new_spanned(
1061 field,
1062 "`Auto<T>` is only valid on a `#[rustango(primary_key)]` field",
1063 ));
1064 }
1065 if auto && attrs.default.is_some() {
1066 return Err(syn::Error::new_spanned(
1067 field,
1068 "`#[rustango(default = \"…\")]` is redundant on an `Auto<T>` field — \
1069 SERIAL / BIGSERIAL already supplies a default sequence.",
1070 ));
1071 }
1072 if fk_inner.is_some() && primary_key {
1073 return Err(syn::Error::new_spanned(
1074 field,
1075 "`ForeignKey<T>` is not allowed on a primary-key field — \
1076 a row's PK is its own identity, not a reference to a parent.",
1077 ));
1078 }
1079 let relation = relation_tokens(field, &attrs, fk_inner)?;
1080 let column_lit = column.as_str();
1081 let field_type_tokens = kind.variant_tokens();
1082 let max_length = optional_u32(attrs.max_length);
1083 let min = optional_i64(attrs.min);
1084 let max = optional_i64(attrs.max);
1085 let default = optional_str(attrs.default.as_deref());
1086
1087 let schema = quote! {
1088 ::rustango::core::FieldSchema {
1089 name: #name,
1090 column: #column_lit,
1091 ty: #field_type_tokens,
1092 nullable: #nullable,
1093 primary_key: #primary_key,
1094 relation: #relation,
1095 max_length: #max_length,
1096 min: #min,
1097 max: #max,
1098 default: #default,
1099 auto: #auto,
1100 }
1101 };
1102
1103 let from_row_init = quote! {
1104 #ident: ::rustango::sql::sqlx::Row::try_get(row, #column_lit)?
1105 };
1106
1107 Ok(FieldInfo {
1108 ident,
1109 column,
1110 primary_key,
1111 auto,
1112 value_ty: &field.ty,
1113 field_type_tokens,
1114 schema,
1115 from_row_init,
1116 })
1117}
1118
1119fn check_bound_compatibility(
1120 field: &syn::Field,
1121 attrs: &FieldAttrs,
1122 kind: DetectedKind,
1123) -> syn::Result<()> {
1124 if attrs.max_length.is_some() && kind != DetectedKind::String {
1125 return Err(syn::Error::new_spanned(
1126 field,
1127 "`max_length` is only valid on `String` fields (or `Option<String>`)",
1128 ));
1129 }
1130 if (attrs.min.is_some() || attrs.max.is_some()) && !kind.is_integer() {
1131 return Err(syn::Error::new_spanned(
1132 field,
1133 "`min` / `max` are only valid on integer fields (`i32`, `i64`, optionally Option-wrapped)",
1134 ));
1135 }
1136 if let (Some(min), Some(max)) = (attrs.min, attrs.max) {
1137 if min > max {
1138 return Err(syn::Error::new_spanned(
1139 field,
1140 format!("`min` ({min}) is greater than `max` ({max})"),
1141 ));
1142 }
1143 }
1144 Ok(())
1145}
1146
1147fn optional_u32(value: Option<u32>) -> TokenStream2 {
1148 if let Some(v) = value {
1149 quote!(::core::option::Option::Some(#v))
1150 } else {
1151 quote!(::core::option::Option::None)
1152 }
1153}
1154
1155fn optional_i64(value: Option<i64>) -> TokenStream2 {
1156 if let Some(v) = value {
1157 quote!(::core::option::Option::Some(#v))
1158 } else {
1159 quote!(::core::option::Option::None)
1160 }
1161}
1162
1163fn optional_str(value: Option<&str>) -> TokenStream2 {
1164 if let Some(v) = value {
1165 quote!(::core::option::Option::Some(#v))
1166 } else {
1167 quote!(::core::option::Option::None)
1168 }
1169}
1170
1171fn relation_tokens(
1172 field: &syn::Field,
1173 attrs: &FieldAttrs,
1174 fk_inner: Option<&syn::Type>,
1175) -> syn::Result<TokenStream2> {
1176 if let Some(inner) = fk_inner {
1177 if attrs.fk.is_some() || attrs.o2o.is_some() {
1178 return Err(syn::Error::new_spanned(
1179 field,
1180 "`ForeignKey<T>` already declares the FK target via the type parameter — \
1181 remove the `fk = \"…\"` / `o2o = \"…\"` attribute.",
1182 ));
1183 }
1184 let on = attrs.on.as_deref().unwrap_or("id");
1185 return Ok(quote! {
1186 ::core::option::Option::Some(::rustango::core::Relation::Fk {
1187 to: <#inner as ::rustango::core::Model>::SCHEMA.table,
1188 on: #on,
1189 })
1190 });
1191 }
1192 match (&attrs.fk, &attrs.o2o) {
1193 (Some(_), Some(_)) => Err(syn::Error::new_spanned(
1194 field,
1195 "`fk` and `o2o` are mutually exclusive",
1196 )),
1197 (Some(to), None) => {
1198 let on = attrs.on.as_deref().unwrap_or("id");
1199 Ok(quote! {
1200 ::core::option::Option::Some(::rustango::core::Relation::Fk { to: #to, on: #on })
1201 })
1202 }
1203 (None, Some(to)) => {
1204 let on = attrs.on.as_deref().unwrap_or("id");
1205 Ok(quote! {
1206 ::core::option::Option::Some(::rustango::core::Relation::O2O { to: #to, on: #on })
1207 })
1208 }
1209 (None, None) => {
1210 if attrs.on.is_some() {
1211 return Err(syn::Error::new_spanned(
1212 field,
1213 "`on` requires `fk` or `o2o`",
1214 ));
1215 }
1216 Ok(quote!(::core::option::Option::None))
1217 }
1218 }
1219}
1220
/// The scalar kind detected for a field's type once any `Option<..>`,
/// `Auto<..>` or `ForeignKey<..>` wrapper has been looked through.
///
/// Each variant maps to the `::rustango::core::FieldType` variant of the same
/// name.
#[derive(Clone, Copy, PartialEq, Eq)]
enum DetectedKind {
    /// Detected from `i32`.
    I32,
    /// Detected from `i64`; also the kind assigned to `ForeignKey<T>` fields.
    I64,
    /// Detected from `f32`.
    F32,
    /// Detected from `f64`.
    F64,
    /// Detected from `bool`.
    Bool,
    /// Detected from `String`.
    String,
    /// Detected from a type whose last path segment is `DateTime`.
    DateTime,
    /// Detected from a type whose last path segment is `NaiveDate`.
    Date,
    /// Detected from a type whose last path segment is `Uuid`.
    Uuid,
    /// Detected from a type whose last path segment is `Value`.
    Json,
}
1237
1238impl DetectedKind {
1239 fn variant_tokens(self) -> TokenStream2 {
1240 match self {
1241 Self::I32 => quote!(::rustango::core::FieldType::I32),
1242 Self::I64 => quote!(::rustango::core::FieldType::I64),
1243 Self::F32 => quote!(::rustango::core::FieldType::F32),
1244 Self::F64 => quote!(::rustango::core::FieldType::F64),
1245 Self::Bool => quote!(::rustango::core::FieldType::Bool),
1246 Self::String => quote!(::rustango::core::FieldType::String),
1247 Self::DateTime => quote!(::rustango::core::FieldType::DateTime),
1248 Self::Date => quote!(::rustango::core::FieldType::Date),
1249 Self::Uuid => quote!(::rustango::core::FieldType::Uuid),
1250 Self::Json => quote!(::rustango::core::FieldType::Json),
1251 }
1252 }
1253
1254 fn is_integer(self) -> bool {
1255 matches!(self, Self::I32 | Self::I64)
1256 }
1257}
1258
/// Result of analysing a field's type with `detect_type`.
#[derive(Clone, Copy)]
struct DetectedType<'a> {
    /// The underlying scalar kind.
    kind: DetectedKind,
    /// Set when the type is wrapped in `Option<..>`.
    nullable: bool,
    /// Set when the type is wrapped in `Auto<..>`.
    auto: bool,
    /// The `T` of a `ForeignKey<T>`, when the type is (or wraps) one.
    fk_inner: Option<&'a syn::Type>,
}
1271
1272fn detect_type(ty: &syn::Type) -> syn::Result<DetectedType<'_>> {
1273 let Type::Path(TypePath { path, qself: None }) = ty else {
1274 return Err(syn::Error::new_spanned(ty, "unsupported field type"));
1275 };
1276 let last = path
1277 .segments
1278 .last()
1279 .ok_or_else(|| syn::Error::new_spanned(ty, "empty type path"))?;
1280
1281 if last.ident == "Option" {
1282 let inner = generic_inner(ty, &last.arguments, "Option")?;
1283 let inner_det = detect_type(inner)?;
1284 if inner_det.nullable {
1285 return Err(syn::Error::new_spanned(
1286 ty,
1287 "nested Option is not supported",
1288 ));
1289 }
1290 if inner_det.auto {
1291 return Err(syn::Error::new_spanned(
1292 ty,
1293 "`Option<Auto<T>>` is not supported — Auto fields are server-assigned and cannot be NULL",
1294 ));
1295 }
1296 return Ok(DetectedType {
1297 nullable: true,
1298 ..inner_det
1299 });
1300 }
1301
1302 if last.ident == "Auto" {
1303 let inner = generic_inner(ty, &last.arguments, "Auto")?;
1304 let inner_det = detect_type(inner)?;
1305 if inner_det.auto {
1306 return Err(syn::Error::new_spanned(
1307 ty,
1308 "nested Auto is not supported",
1309 ));
1310 }
1311 if inner_det.nullable {
1312 return Err(syn::Error::new_spanned(
1313 ty,
1314 "`Auto<Option<T>>` is not supported — Auto fields are server-assigned and cannot be NULL",
1315 ));
1316 }
1317 if inner_det.fk_inner.is_some() {
1318 return Err(syn::Error::new_spanned(
1319 ty,
1320 "`Auto<ForeignKey<T>>` is not supported — Auto is for server-assigned PKs, ForeignKey is for parent references",
1321 ));
1322 }
1323 if !matches!(inner_det.kind, DetectedKind::I32 | DetectedKind::I64) {
1324 return Err(syn::Error::new_spanned(
1325 ty,
1326 "`Auto<T>` only supports integer types (`i32` → SERIAL, `i64` → BIGSERIAL)",
1327 ));
1328 }
1329 return Ok(DetectedType {
1330 auto: true,
1331 ..inner_det
1332 });
1333 }
1334
1335 if last.ident == "ForeignKey" {
1336 let inner = generic_inner(ty, &last.arguments, "ForeignKey")?;
1337 return Ok(DetectedType {
1342 kind: DetectedKind::I64,
1343 nullable: false,
1344 auto: false,
1345 fk_inner: Some(inner),
1346 });
1347 }
1348
1349 let kind = match last.ident.to_string().as_str() {
1350 "i32" => DetectedKind::I32,
1351 "i64" => DetectedKind::I64,
1352 "f32" => DetectedKind::F32,
1353 "f64" => DetectedKind::F64,
1354 "bool" => DetectedKind::Bool,
1355 "String" => DetectedKind::String,
1356 "DateTime" => DetectedKind::DateTime,
1357 "NaiveDate" => DetectedKind::Date,
1358 "Uuid" => DetectedKind::Uuid,
1359 "Value" => DetectedKind::Json,
1360 other => {
1361 return Err(syn::Error::new_spanned(
1362 ty,
1363 format!("unsupported field type `{other}`; v0.1 supports i32/i64/f32/f64/bool/String/DateTime/NaiveDate/Uuid/serde_json::Value, optionally wrapped in Option or Auto (Auto only on integers)"),
1364 ));
1365 }
1366 };
1367 Ok(DetectedType {
1368 kind,
1369 nullable: false,
1370 auto: false,
1371 fk_inner: None,
1372 })
1373}
1374
1375fn generic_inner<'a>(
1376 ty: &'a Type,
1377 arguments: &'a PathArguments,
1378 wrapper: &str,
1379) -> syn::Result<&'a Type> {
1380 let PathArguments::AngleBracketed(args) = arguments else {
1381 return Err(syn::Error::new_spanned(
1382 ty,
1383 format!("{wrapper} requires a generic argument"),
1384 ));
1385 };
1386 args.args
1387 .iter()
1388 .find_map(|a| match a {
1389 GenericArgument::Type(t) => Some(t),
1390 _ => None,
1391 })
1392 .ok_or_else(|| {
1393 syn::Error::new_spanned(ty, format!("{wrapper}<T> requires a type argument"))
1394 })
1395}
1396
/// Converts a type name to snake case: each ASCII uppercase letter is
/// lowercased and, unless it is the first character, preceded by `_`.
///
/// Every uppercase letter starts a new word, so consecutive capitals are
/// split individually (`HTTPServer` becomes `h_t_t_p_server`). Characters
/// that are not ASCII uppercase are copied unchanged.
fn to_snake_case(s: &str) -> String {
    let mut snake = String::with_capacity(s.len() + 4);
    let mut at_start = true;
    for ch in s.chars() {
        if ch.is_ascii_uppercase() && !at_start {
            snake.push('_');
        }
        // A no-op for anything that is not an ASCII uppercase letter.
        snake.push(ch.to_ascii_lowercase());
        at_start = false;
    }
    snake
}