1use proc_macro::TokenStream;
17use syn::ext::IdentExt;
18
19mod infer;
20mod parse;
21mod validate;
22mod validate_derive;
23
24use parse::{InheritanceStrategy, ModelDef, RelationshipKindAttr, parse_model};
25
26#[proc_macro_derive(Model, attributes(sqlmodel))]
71pub fn derive_model(input: TokenStream) -> TokenStream {
72 let input = syn::parse_macro_input!(input as syn::DeriveInput);
73
74 let model = match parse_model(&input) {
76 Ok(m) => m,
77 Err(e) => return e.to_compile_error().into(),
78 };
79
80 if let Err(e) = validate::validate_model(&model) {
82 return e.to_compile_error().into();
83 }
84
85 generate_model_impl(&model).into()
87}
88
89fn generate_model_impl(model: &ModelDef) -> proc_macro2::TokenStream {
91 let name = &model.name;
92 let table_name_lit = &model.table_name;
93 let (impl_generics, ty_generics, where_clause) = model.generics.split_for_impl();
94
95 let table_name_ts =
98 if model.config.inherits.is_some() && model.config.discriminator_value.is_some() {
99 let parent = model
100 .config
101 .inherits
102 .as_deref()
103 .expect("inherits checked above");
104 let parent_ty_ts: proc_macro2::TokenStream =
105 if let Ok(path) = syn::parse_str::<syn::Path>(parent) {
106 quote::quote! { #path }
107 } else {
108 let ident = syn::Ident::new(parent, proc_macro2::Span::call_site());
109 quote::quote! { #ident }
110 };
111 quote::quote! { <#parent_ty_ts as sqlmodel_core::Model>::TABLE_NAME }
112 } else {
113 quote::quote! { #table_name_lit }
114 };
115
116 let pk_fields: Vec<&str> = model
118 .primary_key_fields()
119 .iter()
120 .map(|f| f.column_name.as_str())
121 .collect();
122 let pk_field_names: Vec<_> = pk_fields.clone();
123
124 let pk_slice = if pk_field_names.is_empty() {
126 let has_id_field = model.fields.iter().any(|f| f.name == "id" && !f.skip);
128 if has_id_field {
129 quote::quote! { &["id"] }
130 } else {
131 quote::quote! { &[] }
132 }
133 } else {
134 quote::quote! { &[#(#pk_field_names),*] }
135 };
136
137 let field_infos = generate_field_infos(model);
139
140 let relationships = generate_relationships(model);
142
143 let to_row_body = generate_to_row(model);
145
146 let from_row_body = generate_from_row(model);
148
149 let pk_value_body = generate_primary_key_value(model);
151
152 let is_new_body = generate_is_new(model);
154
155 let model_config_body = generate_model_config(model);
157
158 let inheritance_body = generate_inheritance(model);
160
161 let (shard_key_const, shard_key_value_body) = generate_shard_key(model);
163
164 let joined_parent_row_body = generate_joined_parent_row(model);
166
167 let debug_impl = generate_debug_impl(model);
169
170 let hybrid_impl = generate_hybrid_methods(model);
172
173 quote::quote! {
174 impl #impl_generics sqlmodel_core::Model for #name #ty_generics #where_clause {
175 const TABLE_NAME: &'static str = #table_name_ts;
176 const PRIMARY_KEY: &'static [&'static str] = #pk_slice;
177 const RELATIONSHIPS: &'static [sqlmodel_core::RelationshipInfo] = #relationships;
178 const SHARD_KEY: Option<&'static str> = #shard_key_const;
179
180 fn fields() -> &'static [sqlmodel_core::FieldInfo] {
181 static FIELDS: &[sqlmodel_core::FieldInfo] = &[
182 #field_infos
183 ];
184 FIELDS
185 }
186
187 fn to_row(&self) -> Vec<(&'static str, sqlmodel_core::Value)> {
188 #to_row_body
189 }
190
191 fn from_row(row: &sqlmodel_core::Row) -> sqlmodel_core::Result<Self> {
192 #from_row_body
193 }
194
195 fn primary_key_value(&self) -> Vec<sqlmodel_core::Value> {
196 #pk_value_body
197 }
198
199 fn is_new(&self) -> bool {
200 #is_new_body
201 }
202
203 fn model_config() -> sqlmodel_core::ModelConfig {
204 #model_config_body
205 }
206
207 fn inheritance() -> sqlmodel_core::InheritanceInfo {
208 #inheritance_body
209 }
210
211 fn shard_key_value(&self) -> Option<sqlmodel_core::Value> {
212 #shard_key_value_body
213 }
214
215 #joined_parent_row_body
216 }
217
218 #debug_impl
219
220 #hybrid_impl
221 }
222}
223
224fn generate_hybrid_methods(model: &ModelDef) -> proc_macro2::TokenStream {
230 let hybrid_fields: Vec<_> = model
231 .fields
232 .iter()
233 .filter_map(|f| {
234 if !f.hybrid {
235 return None;
236 }
237 let sql = f.hybrid_sql.as_deref()?;
238 Some((f, sql))
239 })
240 .collect();
241
242 if hybrid_fields.is_empty() {
243 return quote::quote! {};
244 }
245
246 let name = &model.name;
247 let (impl_generics, ty_generics, where_clause) = model.generics.split_for_impl();
248
249 let methods: Vec<_> = hybrid_fields
250 .iter()
251 .map(|(field, sql)| {
252 let method_name = quote::format_ident!("{}_expr", field.name);
253 let doc = format!(
254 "SQL expression for the `{}` hybrid property.\n\nGenerates: `{}`",
255 field.name, sql
256 );
257 quote::quote! {
258 #[doc = #doc]
259 pub fn #method_name() -> sqlmodel_query::Expr {
260 sqlmodel_query::Expr::raw(#sql)
261 }
262 }
263 })
264 .collect();
265
266 quote::quote! {
267 impl #impl_generics #name #ty_generics #where_clause {
268 #(#methods)*
269 }
270 }
271}
272
273fn generate_joined_parent_row(model: &ModelDef) -> proc_macro2::TokenStream {
274 let is_joined_child =
275 model.config.inheritance == InheritanceStrategy::Joined && model.config.inherits.is_some();
276 if !is_joined_child {
277 return quote::quote! {};
278 }
279
280 let Some(parent_field) = model.fields.iter().find(|f| f.parent) else {
281 return quote::quote! {};
282 };
283 let parent_ident = &parent_field.name;
284
285 quote::quote! {
286 fn joined_parent_row(&self) -> Option<Vec<(&'static str, sqlmodel_core::Value)>> {
287 Some(self.#parent_ident.to_row())
288 }
289 }
290}
291
292fn referential_action_ts(action: &str) -> proc_macro2::TokenStream {
294 match action.to_uppercase().as_str() {
295 "NO ACTION" | "NOACTION" | "NO_ACTION" => {
296 quote::quote! { sqlmodel_core::ReferentialAction::NoAction }
297 }
298 "RESTRICT" => quote::quote! { sqlmodel_core::ReferentialAction::Restrict },
299 "CASCADE" => quote::quote! { sqlmodel_core::ReferentialAction::Cascade },
300 "SET NULL" | "SETNULL" | "SET_NULL" => {
301 quote::quote! { sqlmodel_core::ReferentialAction::SetNull }
302 }
303 "SET DEFAULT" | "SETDEFAULT" | "SET_DEFAULT" => {
304 quote::quote! { sqlmodel_core::ReferentialAction::SetDefault }
305 }
306 _ => quote::quote! { sqlmodel_core::ReferentialAction::NoAction },
307 }
308}
309
310fn generate_field_infos(model: &ModelDef) -> proc_macro2::TokenStream {
312 let mut field_ts = Vec::new();
313
314 for field in model.data_fields() {
316 let field_ident = field.name.unraw();
317 let column_name = &field.column_name;
318 let primary_key = field.primary_key;
319 let auto_increment = field.auto_increment;
320
321 let sa_col = field.sa_column.as_ref();
323
324 let nullable = sa_col.and_then(|sc| sc.nullable).unwrap_or(field.nullable);
326
327 let unique = sa_col.and_then(|sc| sc.unique).unwrap_or(field.unique);
329
330 let effective_sql_type = sa_col
332 .and_then(|sc| sc.sql_type.as_ref())
333 .or(field.sql_type.as_ref());
334 let sql_type_ts = if let Some(sql_type_str) = effective_sql_type {
335 infer::parse_sql_type_attr(sql_type_str)
337 } else {
338 infer::infer_sql_type(&field.ty)
340 };
341
342 let sql_type_override_ts = if let Some(sql_type_str) = effective_sql_type {
344 quote::quote! { Some(#sql_type_str) }
345 } else {
346 quote::quote! { None }
347 };
348
349 let effective_default = sa_col
351 .and_then(|sc| sc.server_default.as_ref())
352 .or(field.default.as_ref());
353 let default_ts = if let Some(d) = effective_default {
354 quote::quote! { Some(#d) }
355 } else {
356 quote::quote! { None }
357 };
358
359 let fk_ts = if let Some(fk) = &field.foreign_key {
361 quote::quote! { Some(#fk) }
362 } else {
363 quote::quote! { None }
364 };
365
366 let effective_index = sa_col
368 .and_then(|sc| sc.index.as_ref())
369 .or(field.index.as_ref());
370 let index_ts = if let Some(idx) = effective_index {
371 quote::quote! { Some(#idx) }
372 } else {
373 quote::quote! { None }
374 };
375
376 let on_delete_ts = if let Some(ref action) = field.on_delete {
378 let action_ts = referential_action_ts(action);
379 quote::quote! { Some(#action_ts) }
380 } else {
381 quote::quote! { None }
382 };
383
384 let on_update_ts = if let Some(ref action) = field.on_update {
386 let action_ts = referential_action_ts(action);
387 quote::quote! { Some(#action_ts) }
388 } else {
389 quote::quote! { None }
390 };
391
392 let alias_ts = if let Some(ref alias) = field.alias {
394 quote::quote! { Some(#alias) }
395 } else {
396 quote::quote! { None }
397 };
398
399 let validation_alias_ts = if let Some(ref val_alias) = field.validation_alias {
400 quote::quote! { Some(#val_alias) }
401 } else {
402 quote::quote! { None }
403 };
404
405 let serialization_alias_ts = if let Some(ref ser_alias) = field.serialization_alias {
406 quote::quote! { Some(#ser_alias) }
407 } else {
408 quote::quote! { None }
409 };
410
411 let computed = field.computed;
412 let exclude = field.exclude;
413
414 let title_ts = if let Some(ref title) = field.title {
416 quote::quote! { Some(#title) }
417 } else {
418 quote::quote! { None }
419 };
420
421 let description_ts = if let Some(ref desc) = field.description {
422 quote::quote! { Some(#desc) }
423 } else {
424 quote::quote! { None }
425 };
426
427 let schema_extra_ts = if let Some(ref extra) = field.schema_extra {
428 quote::quote! { Some(#extra) }
429 } else {
430 quote::quote! { None }
431 };
432
433 let default_json_ts = if let Some(ref dj) = field.default_json {
435 quote::quote! { Some(#dj) }
436 } else {
437 quote::quote! { None }
438 };
439
440 let const_field = field.const_field;
442
443 let effective_constraints: Vec<&String> = if let Some(sc) = sa_col {
446 sc.check.iter().collect()
447 } else {
448 field.column_constraints.iter().collect()
449 };
450 let column_constraints_ts = if effective_constraints.is_empty() {
451 quote::quote! { &[] }
452 } else {
453 quote::quote! { &[#(#effective_constraints),*] }
454 };
455
456 let effective_comment = sa_col
459 .and_then(|sc| sc.comment.as_ref())
460 .or(field.column_comment.as_ref());
461 let column_comment_ts = if let Some(comment) = effective_comment {
462 quote::quote! { Some(#comment) }
463 } else {
464 quote::quote! { None }
465 };
466
467 let column_info_ts = if let Some(ref info) = field.column_info {
469 quote::quote! { Some(#info) }
470 } else {
471 quote::quote! { None }
472 };
473
474 let hybrid_sql_ts = if let Some(ref sql) = field.hybrid_sql {
476 quote::quote! { Some(#sql) }
477 } else {
478 quote::quote! { None }
479 };
480
481 let discriminator_ts = if let Some(ref disc) = field.discriminator {
483 quote::quote! { Some(#disc) }
484 } else {
485 quote::quote! { None }
486 };
487
488 let precision_ts = if let Some(p) = field.max_digits {
490 quote::quote! { Some(#p) }
491 } else {
492 quote::quote! { None }
493 };
494
495 let scale_ts = if let Some(s) = field.decimal_places {
496 quote::quote! { Some(#s) }
497 } else {
498 quote::quote! { None }
499 };
500
501 field_ts.push(quote::quote! {
502 sqlmodel_core::FieldInfo::new(stringify!(#field_ident), #column_name, #sql_type_ts)
503 .sql_type_override_opt(#sql_type_override_ts)
504 .precision_opt(#precision_ts)
505 .scale_opt(#scale_ts)
506 .nullable(#nullable)
507 .primary_key(#primary_key)
508 .auto_increment(#auto_increment)
509 .unique(#unique)
510 .default_opt(#default_ts)
511 .foreign_key_opt(#fk_ts)
512 .on_delete_opt(#on_delete_ts)
513 .on_update_opt(#on_update_ts)
514 .index_opt(#index_ts)
515 .alias_opt(#alias_ts)
516 .validation_alias_opt(#validation_alias_ts)
517 .serialization_alias_opt(#serialization_alias_ts)
518 .computed(#computed)
519 .exclude(#exclude)
520 .title_opt(#title_ts)
521 .description_opt(#description_ts)
522 .schema_extra_opt(#schema_extra_ts)
523 .default_json_opt(#default_json_ts)
524 .const_field(#const_field)
525 .column_constraints(#column_constraints_ts)
526 .column_comment_opt(#column_comment_ts)
527 .column_info_opt(#column_info_ts)
528 .hybrid_sql_opt(#hybrid_sql_ts)
529 .discriminator_opt(#discriminator_ts)
530 });
531 }
532
533 quote::quote! { #(#field_ts),* }
534}
535
536fn generate_to_row(model: &ModelDef) -> proc_macro2::TokenStream {
538 let mut conversions = Vec::new();
539
540 for field in model.select_fields() {
541 let field_name = &field.name;
542 let column_name = &field.column_name;
543
544 if parse::is_option_type(&field.ty) {
546 conversions.push(quote::quote! {
547 (#column_name, match &self.#field_name {
548 Some(v) => sqlmodel_core::Value::from(v.clone()),
549 None => sqlmodel_core::Value::Null,
550 })
551 });
552 } else {
553 conversions.push(quote::quote! {
554 (#column_name, sqlmodel_core::Value::from(self.#field_name.clone()))
555 });
556 }
557 }
558
559 quote::quote! {
560 let mut out = vec![#(#conversions),*];
561
562 let inh = <Self as sqlmodel_core::Model>::inheritance();
566 if let (Some(col), Some(val)) = (inh.discriminator_column, inh.discriminator_value) {
567 if !out.iter().any(|(c, _)| *c == col) {
568 out.push((col, sqlmodel_core::Value::from(val)));
569 }
570 }
571
572 out
573 }
574}
575
576fn generate_from_row(model: &ModelDef) -> proc_macro2::TokenStream {
578 let name = &model.name;
579 let mut field_extractions = Vec::new();
580
581 let row_ident = quote::format_ident!("local_row");
584
585 for field in model.select_fields() {
586 let field_name = &field.name;
587 let column_name = &field.column_name;
588
589 if parse::is_option_type(&field.ty) {
590 field_extractions.push(quote::quote! {
592 #field_name: #row_ident.get_named(#column_name).ok()
593 });
594 } else {
595 field_extractions.push(quote::quote! {
597 #field_name: #row_ident.get_named(#column_name)?
598 });
599 }
600 }
601
602 let skipped_fields: Vec<_> = model
604 .fields
605 .iter()
606 .filter(|f| f.skip)
607 .map(|f| {
608 let field_name = &f.name;
609 quote::quote! { #field_name: Default::default() }
610 })
611 .collect();
612
613 let relationship_fields: Vec<_> = model
615 .fields
616 .iter()
617 .filter(|f| f.relationship.is_some())
618 .map(|f| {
619 let field_name = &f.name;
620 quote::quote! { #field_name: Default::default() }
621 })
622 .collect();
623
624 let parent_fields: Vec<_> = model
626 .fields
627 .iter()
628 .filter(|f| f.parent)
629 .map(|f| {
630 let field_name = &f.name;
631 let ty = &f.ty;
632 quote::quote! {
633 #field_name: {
634 let inh = <Self as sqlmodel_core::Model>::inheritance();
635 let parent_table = inh.parent.ok_or_else(|| {
636 sqlmodel_core::Error::Custom(
637 "joined inheritance parent_table missing in inheritance metadata".to_string(),
638 )
639 })?;
640 if !row.has_prefix(parent_table) {
641 return Err(sqlmodel_core::Error::Custom(format!(
642 "expected prefixed parent columns for joined inheritance: {}__*",
643 parent_table
644 )));
645 }
646 let prow = row.subset_by_prefix(parent_table);
647 <#ty as sqlmodel_core::Model>::from_row(&prow)?
648 }
649 }
650 })
651 .collect();
652
653 let computed_fields: Vec<_> = model
655 .computed_fields()
656 .iter()
657 .map(|f| {
658 let field_name = &f.name;
659 quote::quote! { #field_name: Default::default() }
660 })
661 .collect();
662
663 quote::quote! {
664 let #row_ident = if row.has_prefix(<Self as sqlmodel_core::Model>::TABLE_NAME) {
665 row.subset_by_prefix(<Self as sqlmodel_core::Model>::TABLE_NAME)
666 } else {
667 row.clone()
668 };
669
670 Ok(#name {
671 #(#field_extractions,)*
672 #(#skipped_fields,)*
673 #(#relationship_fields,)*
674 #(#parent_fields,)*
675 #(#computed_fields,)*
676 })
677 }
678}
679
680fn generate_primary_key_value(model: &ModelDef) -> proc_macro2::TokenStream {
682 let pk_fields = model.primary_key_fields();
683
684 if pk_fields.is_empty() {
685 let id_field = model.fields.iter().find(|f| f.name == "id");
687 if let Some(field) = id_field {
688 let field_name = &field.name;
689 if parse::is_option_type(&field.ty) {
690 return quote::quote! {
691 match &self.#field_name {
692 Some(v) => vec![sqlmodel_core::Value::from(v.clone())],
693 None => vec![sqlmodel_core::Value::Null],
694 }
695 };
696 }
697 return quote::quote! {
698 vec![sqlmodel_core::Value::from(self.#field_name.clone())]
699 };
700 }
701 return quote::quote! { vec![] };
702 }
703
704 let mut value_exprs = Vec::new();
705 for field in pk_fields {
706 let field_name = &field.name;
707 if parse::is_option_type(&field.ty) {
708 value_exprs.push(quote::quote! {
709 match &self.#field_name {
710 Some(v) => sqlmodel_core::Value::from(v.clone()),
711 None => sqlmodel_core::Value::Null,
712 }
713 });
714 } else {
715 value_exprs.push(quote::quote! {
716 sqlmodel_core::Value::from(self.#field_name.clone())
717 });
718 }
719 }
720
721 quote::quote! {
722 vec![#(#value_exprs),*]
723 }
724}
725
726fn generate_is_new(model: &ModelDef) -> proc_macro2::TokenStream {
728 let pk_fields = model.primary_key_fields();
729
730 for field in &pk_fields {
733 if field.auto_increment && parse::is_option_type(&field.ty) {
734 let field_name = &field.name;
735 return quote::quote! {
736 self.#field_name.is_none()
737 };
738 }
739 }
740
741 if let Some(id_field) = model.fields.iter().find(|f| f.name == "id") {
743 if parse::is_option_type(&id_field.ty) {
744 return quote::quote! {
745 self.id.is_none()
746 };
747 }
748 }
749
750 quote::quote! { true }
752}
753
754fn generate_model_config(model: &ModelDef) -> proc_macro2::TokenStream {
756 let config = &model.config;
757
758 let table = config.table;
759 let from_attributes = config.from_attributes;
760 let validate_assignment = config.validate_assignment;
761 let strict = config.strict;
762 let populate_by_name = config.populate_by_name;
763 let use_enum_values = config.use_enum_values;
764 let arbitrary_types_allowed = config.arbitrary_types_allowed;
765 let defer_build = config.defer_build;
766 let revalidate_instances = config.revalidate_instances;
767
768 let extra_ts = match config.extra.as_str() {
770 "forbid" => quote::quote! { sqlmodel_core::ExtraFieldsBehavior::Forbid },
771 "allow" => quote::quote! { sqlmodel_core::ExtraFieldsBehavior::Allow },
772 _ => quote::quote! { sqlmodel_core::ExtraFieldsBehavior::Ignore },
773 };
774
775 let json_schema_extra_ts = if let Some(ref extra) = config.json_schema_extra {
777 quote::quote! { Some(#extra) }
778 } else {
779 quote::quote! { None }
780 };
781
782 let title_ts = if let Some(ref title) = config.title {
783 quote::quote! { Some(#title) }
784 } else {
785 quote::quote! { None }
786 };
787
788 quote::quote! {
789 sqlmodel_core::ModelConfig {
790 table: #table,
791 from_attributes: #from_attributes,
792 validate_assignment: #validate_assignment,
793 extra: #extra_ts,
794 strict: #strict,
795 populate_by_name: #populate_by_name,
796 use_enum_values: #use_enum_values,
797 arbitrary_types_allowed: #arbitrary_types_allowed,
798 defer_build: #defer_build,
799 revalidate_instances: #revalidate_instances,
800 json_schema_extra: #json_schema_extra_ts,
801 title: #title_ts,
802 }
803 }
804}
805
806fn generate_inheritance(model: &ModelDef) -> proc_macro2::TokenStream {
808 use crate::parse::InheritanceStrategy;
809
810 let config = &model.config;
811
812 let strategy_ts = match config.inheritance {
814 InheritanceStrategy::None => {
815 quote::quote! { sqlmodel_core::InheritanceStrategy::None }
816 }
817 InheritanceStrategy::Single => {
818 quote::quote! { sqlmodel_core::InheritanceStrategy::Single }
819 }
820 InheritanceStrategy::Joined => {
821 quote::quote! { sqlmodel_core::InheritanceStrategy::Joined }
822 }
823 InheritanceStrategy::Concrete => {
824 quote::quote! { sqlmodel_core::InheritanceStrategy::Concrete }
825 }
826 };
827
828 let parent_ty_ts: Option<proc_macro2::TokenStream> = config.inherits.as_deref().map(|p| {
832 if let Ok(path) = syn::parse_str::<syn::Path>(p) {
833 quote::quote! { #path }
834 } else {
835 let ident = syn::Ident::new(p, proc_macro2::Span::call_site());
836 quote::quote! { #ident }
837 }
838 });
839
840 let parent_table_ts = if let Some(ref parent_ty) = parent_ty_ts {
842 quote::quote! { Some(<#parent_ty as sqlmodel_core::Model>::TABLE_NAME) }
843 } else {
844 quote::quote! { None }
845 };
846
847 let parent_fields_fn_ts = if let Some(ref parent_ty) = parent_ty_ts {
848 quote::quote! { Some(<#parent_ty as sqlmodel_core::Model>::fields) }
849 } else {
850 quote::quote! { None }
851 };
852
853 let discriminator_column_ts = if let Some(ref column) = config.discriminator_column {
859 quote::quote! { Some(#column) }
860 } else if config.discriminator_value.is_some() {
861 if let Some(parent_ty) = parent_ty_ts.as_ref() {
862 quote::quote! { <#parent_ty as sqlmodel_core::Model>::inheritance().discriminator_column }
863 } else {
864 quote::quote! { None }
865 }
866 } else {
867 quote::quote! { None }
868 };
869
870 let discriminator_value_ts = if let Some(ref value) = config.discriminator_value {
872 quote::quote! { Some(#value) }
873 } else {
874 quote::quote! { None }
875 };
876
877 quote::quote! {
878 sqlmodel_core::InheritanceInfo {
879 strategy: #strategy_ts,
880 parent: #parent_table_ts,
881 parent_fields_fn: #parent_fields_fn_ts,
882 discriminator_column: #discriminator_column_ts,
883 discriminator_value: #discriminator_value_ts,
884 }
885 }
886}
887
888fn generate_shard_key(model: &ModelDef) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
894 let config = &model.config;
895
896 if let Some(ref shard_key_name) = config.shard_key {
897 let shard_field = model.fields.iter().find(|f| f.name == shard_key_name);
899
900 let const_ts = quote::quote! { Some(#shard_key_name) };
901
902 let value_body = if let Some(field) = shard_field {
904 let field_ident = &field.name;
905 if parse::is_option_type(&field.ty) {
906 quote::quote! {
908 match &self.#field_ident {
909 Some(v) => Some(sqlmodel_core::Value::from(v.clone())),
910 None => None,
911 }
912 }
913 } else {
914 quote::quote! {
916 Some(sqlmodel_core::Value::from(self.#field_ident.clone()))
917 }
918 }
919 } else {
920 quote::quote! { None }
923 };
924
925 (const_ts, value_body)
926 } else {
927 let const_ts = quote::quote! { None };
929 let value_body = quote::quote! { None };
930 (const_ts, value_body)
931 }
932}
933
934fn generate_debug_impl(model: &ModelDef) -> proc_macro2::TokenStream {
939 let has_hidden_fields = model.fields.iter().any(|f| !f.repr);
941
942 if !has_hidden_fields {
944 return quote::quote! {};
945 }
946
947 let name = &model.name;
948 let (impl_generics, ty_generics, where_clause) = model.generics.split_for_impl();
949
950 let debug_fields: Vec<_> = model
952 .fields
953 .iter()
954 .filter(|f| f.repr) .map(|f| {
956 let field_name = &f.name;
957 let field_name_str = field_name.to_string();
958 quote::quote! {
959 .field(#field_name_str, &self.#field_name)
960 }
961 })
962 .collect();
963
964 let struct_name_str = name.to_string();
965
966 quote::quote! {
967 impl #impl_generics ::core::fmt::Debug for #name #ty_generics #where_clause {
968 fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
969 f.debug_struct(#struct_name_str)
970 #(#debug_fields)*
971 .finish()
972 }
973 }
974 }
975}
976
977fn generate_relationships(model: &ModelDef) -> proc_macro2::TokenStream {
979 fn relationship_inner_model_ty(ty: &syn::Type) -> Option<syn::Type> {
980 let syn::Type::Path(tp) = ty else {
981 return None;
982 };
983
984 let last = tp.path.segments.last()?;
985 let ident = last.ident.to_string();
986 if ident != "Related" && ident != "RelatedMany" && ident != "Lazy" {
987 return None;
988 }
989
990 let syn::PathArguments::AngleBracketed(args) = &last.arguments else {
991 return None;
992 };
993
994 args.args.iter().find_map(|arg| match arg {
995 syn::GenericArgument::Type(t) => Some(t.clone()),
996 _ => None,
997 })
998 }
999
1000 let relationship_fields = model.relationship_fields();
1001
1002 if relationship_fields.is_empty() {
1003 return quote::quote! { &[] };
1004 }
1005
1006 let mut relationship_ts = Vec::new();
1007
1008 for field in relationship_fields {
1009 let Some(rel) = field.relationship.as_ref() else {
1010 relationship_ts.push(quote::quote! {
1011 ::core::compile_error!(
1012 "sqlmodel: internal error: relationship field missing parsed relationship metadata"
1013 )
1014 });
1015 continue;
1016 };
1017 let field_name = &field.name;
1018 let related_table = &rel.model;
1019
1020 let Some(related_ty) = relationship_inner_model_ty(&field.ty) else {
1021 relationship_ts.push(quote::quote! {
1022 ::core::compile_error!(
1023 "sqlmodel: relationship field type must be Related<T>, RelatedMany<T>, or Lazy<T>"
1024 )
1025 });
1026 continue;
1027 };
1028
1029 let kind_ts = match rel.kind {
1031 RelationshipKindAttr::OneToOne => {
1032 quote::quote! { sqlmodel_core::RelationshipKind::OneToOne }
1033 }
1034 RelationshipKindAttr::ManyToOne => {
1035 quote::quote! { sqlmodel_core::RelationshipKind::ManyToOne }
1036 }
1037 RelationshipKindAttr::OneToMany => {
1038 quote::quote! { sqlmodel_core::RelationshipKind::OneToMany }
1039 }
1040 RelationshipKindAttr::ManyToMany => {
1041 quote::quote! { sqlmodel_core::RelationshipKind::ManyToMany }
1042 }
1043 };
1044
1045 let local_key_call = if let Some(ref fk) = rel.foreign_key {
1047 quote::quote! { .local_key(#fk) }
1048 } else {
1049 quote::quote! {}
1050 };
1051
1052 let remote_key_call = if let Some(ref rk) = rel.remote_key {
1053 quote::quote! { .remote_key(#rk) }
1054 } else {
1055 quote::quote! {}
1056 };
1057
1058 let back_populates_call = if let Some(ref bp) = rel.back_populates {
1059 quote::quote! { .back_populates(#bp) }
1060 } else {
1061 quote::quote! {}
1062 };
1063
1064 let link_table_call = if let Some(ref lt) = rel.link_table {
1065 let table = <.table;
1066 let local_col = <.local_column;
1067 let remote_col = <.remote_column;
1068 quote::quote! {
1069 .link_table(sqlmodel_core::LinkTableInfo::new(#table, #local_col, #remote_col))
1070 }
1071 } else {
1072 quote::quote! {}
1073 };
1074
1075 let lazy_val = rel.lazy;
1076 let cascade_val = rel.cascade_delete;
1077 let passive_deletes_ts = match rel.passive_deletes {
1078 crate::parse::PassiveDeletesAttr::Active => {
1079 quote::quote! { sqlmodel_core::PassiveDeletes::Active }
1080 }
1081 crate::parse::PassiveDeletesAttr::Passive => {
1082 quote::quote! { sqlmodel_core::PassiveDeletes::Passive }
1083 }
1084 crate::parse::PassiveDeletesAttr::All => {
1085 quote::quote! { sqlmodel_core::PassiveDeletes::All }
1086 }
1087 };
1088
1089 let order_by_call = if let Some(ref ob) = rel.order_by {
1091 quote::quote! { .order_by(#ob) }
1092 } else {
1093 quote::quote! {}
1094 };
1095
1096 let lazy_strategy_call = if let Some(ref strategy) = rel.lazy_strategy {
1097 let strategy_ts = match strategy {
1098 crate::parse::LazyLoadStrategyAttr::Select => {
1099 quote::quote! { sqlmodel_core::LazyLoadStrategy::Select }
1100 }
1101 crate::parse::LazyLoadStrategyAttr::Joined => {
1102 quote::quote! { sqlmodel_core::LazyLoadStrategy::Joined }
1103 }
1104 crate::parse::LazyLoadStrategyAttr::Subquery => {
1105 quote::quote! { sqlmodel_core::LazyLoadStrategy::Subquery }
1106 }
1107 crate::parse::LazyLoadStrategyAttr::Selectin => {
1108 quote::quote! { sqlmodel_core::LazyLoadStrategy::Selectin }
1109 }
1110 crate::parse::LazyLoadStrategyAttr::Dynamic => {
1111 quote::quote! { sqlmodel_core::LazyLoadStrategy::Dynamic }
1112 }
1113 crate::parse::LazyLoadStrategyAttr::NoLoad => {
1114 quote::quote! { sqlmodel_core::LazyLoadStrategy::NoLoad }
1115 }
1116 crate::parse::LazyLoadStrategyAttr::RaiseOnSql => {
1117 quote::quote! { sqlmodel_core::LazyLoadStrategy::RaiseOnSql }
1118 }
1119 crate::parse::LazyLoadStrategyAttr::WriteOnly => {
1120 quote::quote! { sqlmodel_core::LazyLoadStrategy::WriteOnly }
1121 }
1122 };
1123 quote::quote! { .lazy_strategy(#strategy_ts) }
1124 } else {
1125 quote::quote! {}
1126 };
1127
1128 let cascade_call = if let Some(ref c) = rel.cascade {
1129 quote::quote! { .cascade(#c) }
1130 } else {
1131 quote::quote! {}
1132 };
1133
1134 let uselist_call = if let Some(ul) = rel.uselist {
1135 quote::quote! { .uselist(#ul) }
1136 } else {
1137 quote::quote! {}
1138 };
1139
1140 relationship_ts.push(quote::quote! {
1141 sqlmodel_core::RelationshipInfo::new(
1142 stringify!(#field_name),
1143 #related_table,
1144 #kind_ts
1145 )
1146 .related_fields(<#related_ty as sqlmodel_core::Model>::fields)
1147 #local_key_call
1148 #remote_key_call
1149 #back_populates_call
1150 #link_table_call
1151 .lazy(#lazy_val)
1152 .cascade_delete(#cascade_val)
1153 .passive_deletes(#passive_deletes_ts)
1154 #order_by_call
1155 #lazy_strategy_call
1156 #cascade_call
1157 #uselist_call
1158 });
1159 }
1160
1161 quote::quote! {
1162 &[#(#relationship_ts),*]
1163 }
1164}
1165
1166#[proc_macro_derive(Validate, attributes(validate))]
1214pub fn derive_validate(input: TokenStream) -> TokenStream {
1215 let input = syn::parse_macro_input!(input as syn::DeriveInput);
1216
1217 let def = match validate_derive::parse_validate(&input) {
1219 Ok(d) => d,
1220 Err(e) => return e.to_compile_error().into(),
1221 };
1222
1223 validate_derive::generate_validate_impl(&def).into()
1225}
1226
1227#[proc_macro_derive(SqlEnum, attributes(sqlmodel))]
1247pub fn derive_sql_enum(input: TokenStream) -> TokenStream {
1248 let input = syn::parse_macro_input!(input as syn::DeriveInput);
1249 match generate_sql_enum_impl(&input) {
1250 Ok(tokens) => tokens.into(),
1251 Err(e) => e.to_compile_error().into(),
1252 }
1253}
1254
1255fn generate_sql_enum_impl(input: &syn::DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
1256 let name = &input.ident;
1257 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1258
1259 let syn::Data::Enum(data) = &input.data else {
1260 return Err(syn::Error::new_spanned(
1261 input,
1262 "SqlEnum can only be derived for enums",
1263 ));
1264 };
1265
1266 let mut variant_names = Vec::new();
1268 let mut variant_strings = Vec::new();
1269
1270 for variant in &data.variants {
1271 if !variant.fields.is_empty() {
1272 return Err(syn::Error::new_spanned(
1273 variant,
1274 "SqlEnum variants must be unit variants (no fields)",
1275 ));
1276 }
1277
1278 let ident = &variant.ident;
1279 variant_names.push(ident.clone());
1280
1281 let mut custom_name = None;
1283 for attr in &variant.attrs {
1284 if attr.path().is_ident("sqlmodel") {
1285 attr.parse_nested_meta(|meta| {
1286 if meta.path.is_ident("rename") {
1287 let value = meta.value()?;
1288 let s: syn::LitStr = value.parse()?;
1289 custom_name = Some(s.value());
1290 }
1291 Ok(())
1292 })?;
1293 }
1294 }
1295
1296 let sql_str = custom_name.unwrap_or_else(|| to_snake_case(&ident.to_string()));
1297 variant_strings.push(sql_str);
1298 }
1299
1300 let type_name = to_snake_case(&name.to_string());
1301
1302 let variant_str_refs: Vec<_> = variant_strings.iter().map(|s| s.as_str()).collect();
1304
1305 let to_sql_arms: Vec<_> = variant_names
1306 .iter()
1307 .zip(variant_strings.iter())
1308 .map(|(ident, s)| {
1309 quote::quote! { #name::#ident => #s }
1310 })
1311 .collect();
1312
1313 let from_sql_arms: Vec<_> = variant_names
1314 .iter()
1315 .zip(variant_strings.iter())
1316 .map(|(ident, s)| {
1317 quote::quote! { #s => Ok(#name::#ident) }
1318 })
1319 .collect();
1320
1321 let valid_values: String = variant_strings
1323 .iter()
1324 .map(|s| format!("'{}'", s))
1325 .collect::<Vec<_>>()
1326 .join(", ");
1327 let error_msg = format!(
1328 "invalid value for {}: expected one of {}",
1329 name, valid_values
1330 );
1331
1332 Ok(quote::quote! {
1333 impl #impl_generics sqlmodel_core::SqlEnum for #name #ty_generics #where_clause {
1334 const VARIANTS: &'static [&'static str] = &[#(#variant_str_refs),*];
1335 const TYPE_NAME: &'static str = #type_name;
1336
1337 fn to_sql_str(&self) -> &'static str {
1338 match self {
1339 #(#to_sql_arms,)*
1340 }
1341 }
1342
1343 fn from_sql_str(s: &str) -> Result<Self, String> {
1344 match s {
1345 #(#from_sql_arms,)*
1346 _ => Err(format!("{}, got '{}'", #error_msg, s)),
1347 }
1348 }
1349 }
1350
1351 impl #impl_generics From<#name #ty_generics> for sqlmodel_core::Value #where_clause {
1352 fn from(v: #name #ty_generics) -> Self {
1353 sqlmodel_core::Value::Text(
1354 sqlmodel_core::SqlEnum::to_sql_str(&v).to_string()
1355 )
1356 }
1357 }
1358
1359 impl #impl_generics From<&#name #ty_generics> for sqlmodel_core::Value #where_clause {
1360 fn from(v: &#name #ty_generics) -> Self {
1361 sqlmodel_core::Value::Text(
1362 sqlmodel_core::SqlEnum::to_sql_str(v).to_string()
1363 )
1364 }
1365 }
1366
1367 impl #impl_generics TryFrom<sqlmodel_core::Value> for #name #ty_generics #where_clause {
1368 type Error = sqlmodel_core::Error;
1369
1370 fn try_from(value: sqlmodel_core::Value) -> Result<Self, Self::Error> {
1371 match value {
1372 sqlmodel_core::Value::Text(ref s) => {
1373 sqlmodel_core::SqlEnum::from_sql_str(s.as_str()).map_err(|e| {
1374 sqlmodel_core::Error::Type(sqlmodel_core::error::TypeError {
1375 expected: <#name as sqlmodel_core::SqlEnum>::TYPE_NAME,
1376 actual: e,
1377 column: None,
1378 rust_type: None,
1379 })
1380 })
1381 }
1382 other => Err(sqlmodel_core::Error::Type(sqlmodel_core::error::TypeError {
1383 expected: <#name as sqlmodel_core::SqlEnum>::TYPE_NAME,
1384 actual: other.type_name().to_string(),
1385 column: None,
1386 rust_type: None,
1387 })),
1388 }
1389 }
1390 }
1391
1392 impl #impl_generics ::core::fmt::Display for #name #ty_generics #where_clause {
1393 fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
1394 f.write_str(sqlmodel_core::SqlEnum::to_sql_str(self))
1395 }
1396 }
1397
1398 impl #impl_generics ::core::str::FromStr for #name #ty_generics #where_clause {
1399 type Err = String;
1400
1401 fn from_str(s: &str) -> Result<Self, Self::Err> {
1402 sqlmodel_core::SqlEnum::from_sql_str(s)
1403 }
1404 }
1405 })
1406}
1407
/// Converts a CamelCase / PascalCase identifier to snake_case.
///
/// An underscore is inserted before an uppercase character when either:
/// - it follows a lowercase character (`UserRole` -> `user_role`), or
/// - it follows an uppercase character and is followed by a lowercase one,
///   i.e. it ends an acronym run (`HTTPServer` -> `http_server`).
///
/// Characters that are neither uppercase nor lowercase (digits, `_`) are copied
/// unchanged and never trigger a break (`Foo2Bar` -> `foo2bar`). Lowercasing is
/// Unicode-aware, matching the Unicode-aware `is_uppercase` test.
fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + 4);

    for (i, &ch) in chars.iter().enumerate() {
        if !ch.is_uppercase() {
            result.push(ch);
            continue;
        }

        if i > 0 {
            let prev = chars[i - 1];
            let next_is_lower = chars.get(i + 1).is_some_and(|c| c.is_lowercase());
            if prev.is_lowercase() || (prev.is_uppercase() && next_is_lower) {
                result.push('_');
            }
        }

        // `to_ascii_lowercase` would leave non-ASCII capitals such as `Ä`
        // untouched; `to_lowercase` may yield several chars, hence `extend`.
        result.extend(ch.to_lowercase());
    }

    result
}
1429
1430#[proc_macro_attribute]
1441pub fn query(_attr: TokenStream, item: TokenStream) -> TokenStream {
1442 let original = item.clone();
1443
1444 let func: syn::ItemFn = match syn::parse(item) {
1445 Ok(f) => f,
1446 Err(e) => return e.to_compile_error().into(),
1447 };
1448
1449 if func.sig.asyncness.is_none() {
1450 return syn::Error::new_spanned(
1451 func.sig.fn_token,
1452 "#[sqlmodel::query] requires an async fn",
1453 )
1454 .to_compile_error()
1455 .into();
1456 }
1457
1458 let Some(first_arg) = func.sig.inputs.first() else {
1459 return syn::Error::new_spanned(
1460 &func.sig.ident,
1461 "#[sqlmodel::query] requires the first parameter to be `cx: &Cx`",
1462 )
1463 .to_compile_error()
1464 .into();
1465 };
1466
1467 let first_ty = match first_arg {
1468 syn::FnArg::Typed(pat_ty) => &*pat_ty.ty,
1469 syn::FnArg::Receiver(recv) => {
1470 return syn::Error::new_spanned(
1471 recv,
1472 "#[sqlmodel::query] does not support methods; use a free function",
1473 )
1474 .to_compile_error()
1475 .into();
1476 }
1477 };
1478
1479 if !is_ref_to_cx(first_ty) {
1480 return syn::Error::new_spanned(
1481 first_ty,
1482 "#[sqlmodel::query] requires the first parameter to be `cx: &Cx`",
1483 )
1484 .to_compile_error()
1485 .into();
1486 }
1487
1488 original
1489}
1490
1491fn is_ref_to_cx(ty: &syn::Type) -> bool {
1492 let syn::Type::Reference(r) = ty else {
1493 return false;
1494 };
1495 is_cx_path(&r.elem)
1496}
1497
1498fn is_cx_path(ty: &syn::Type) -> bool {
1499 let syn::Type::Path(p) = ty else {
1500 return false;
1501 };
1502 p.path.segments.last().is_some_and(|seg| seg.ident == "Cx")
1503}