1use proc_macro::TokenStream;
17use syn::ext::IdentExt;
18
19mod infer;
20mod parse;
21mod validate;
22mod validate_derive;
23
24use parse::{ModelDef, RelationshipKindAttr, parse_model};
25
26#[proc_macro_derive(Model, attributes(sqlmodel))]
71pub fn derive_model(input: TokenStream) -> TokenStream {
72 let input = syn::parse_macro_input!(input as syn::DeriveInput);
73
74 let model = match parse_model(&input) {
76 Ok(m) => m,
77 Err(e) => return e.to_compile_error().into(),
78 };
79
80 if let Err(e) = validate::validate_model(&model) {
82 return e.to_compile_error().into();
83 }
84
85 generate_model_impl(&model).into()
87}
88
89fn generate_model_impl(model: &ModelDef) -> proc_macro2::TokenStream {
91 let name = &model.name;
92 let table_name = &model.table_name;
93 let (impl_generics, ty_generics, where_clause) = model.generics.split_for_impl();
94
95 let pk_fields: Vec<&str> = model
97 .primary_key_fields()
98 .iter()
99 .map(|f| f.column_name.as_str())
100 .collect();
101 let pk_field_names: Vec<_> = pk_fields.clone();
102
103 let pk_slice = if pk_field_names.is_empty() {
105 let has_id_field = model.fields.iter().any(|f| f.name == "id" && !f.skip);
107 if has_id_field {
108 quote::quote! { &["id"] }
109 } else {
110 quote::quote! { &[] }
111 }
112 } else {
113 quote::quote! { &[#(#pk_field_names),*] }
114 };
115
116 let field_infos = generate_field_infos(model);
118
119 let relationships = generate_relationships(model);
121
122 let to_row_body = generate_to_row(model);
124
125 let from_row_body = generate_from_row(model);
127
128 let pk_value_body = generate_primary_key_value(model);
130
131 let is_new_body = generate_is_new(model);
133
134 let model_config_body = generate_model_config(model);
136
137 let inheritance_body = generate_inheritance(model);
139
140 let (shard_key_const, shard_key_value_body) = generate_shard_key(model);
142
143 let debug_impl = generate_debug_impl(model);
145
146 let hybrid_impl = generate_hybrid_methods(model);
148
149 quote::quote! {
150 impl #impl_generics sqlmodel_core::Model for #name #ty_generics #where_clause {
151 const TABLE_NAME: &'static str = #table_name;
152 const PRIMARY_KEY: &'static [&'static str] = #pk_slice;
153 const RELATIONSHIPS: &'static [sqlmodel_core::RelationshipInfo] = #relationships;
154 const SHARD_KEY: Option<&'static str> = #shard_key_const;
155
156 fn fields() -> &'static [sqlmodel_core::FieldInfo] {
157 static FIELDS: &[sqlmodel_core::FieldInfo] = &[
158 #field_infos
159 ];
160 FIELDS
161 }
162
163 fn to_row(&self) -> Vec<(&'static str, sqlmodel_core::Value)> {
164 #to_row_body
165 }
166
167 fn from_row(row: &sqlmodel_core::Row) -> sqlmodel_core::Result<Self> {
168 #from_row_body
169 }
170
171 fn primary_key_value(&self) -> Vec<sqlmodel_core::Value> {
172 #pk_value_body
173 }
174
175 fn is_new(&self) -> bool {
176 #is_new_body
177 }
178
179 fn model_config() -> sqlmodel_core::ModelConfig {
180 #model_config_body
181 }
182
183 fn inheritance() -> sqlmodel_core::InheritanceInfo {
184 #inheritance_body
185 }
186
187 fn shard_key_value(&self) -> Option<sqlmodel_core::Value> {
188 #shard_key_value_body
189 }
190 }
191
192 #debug_impl
193
194 #hybrid_impl
195 }
196}
197
198fn generate_hybrid_methods(model: &ModelDef) -> proc_macro2::TokenStream {
204 let hybrid_fields: Vec<_> = model
205 .fields
206 .iter()
207 .filter(|f| f.hybrid && f.hybrid_sql.is_some())
208 .collect();
209
210 if hybrid_fields.is_empty() {
211 return quote::quote! {};
212 }
213
214 let name = &model.name;
215 let (impl_generics, ty_generics, where_clause) = model.generics.split_for_impl();
216
217 let methods: Vec<_> = hybrid_fields
218 .iter()
219 .map(|field| {
220 let sql = field.hybrid_sql.as_ref().unwrap();
221 let method_name = quote::format_ident!("{}_expr", field.name);
222 let doc = format!(
223 "SQL expression for the `{}` hybrid property.\n\nGenerates: `{}`",
224 field.name, sql
225 );
226 quote::quote! {
227 #[doc = #doc]
228 pub fn #method_name() -> sqlmodel_query::Expr {
229 sqlmodel_query::Expr::raw(#sql)
230 }
231 }
232 })
233 .collect();
234
235 quote::quote! {
236 impl #impl_generics #name #ty_generics #where_clause {
237 #(#methods)*
238 }
239 }
240}
241
242fn referential_action_token(action: &str) -> proc_macro2::TokenStream {
244 match action.to_uppercase().as_str() {
245 "NO ACTION" | "NOACTION" | "NO_ACTION" => {
246 quote::quote! { sqlmodel_core::ReferentialAction::NoAction }
247 }
248 "RESTRICT" => quote::quote! { sqlmodel_core::ReferentialAction::Restrict },
249 "CASCADE" => quote::quote! { sqlmodel_core::ReferentialAction::Cascade },
250 "SET NULL" | "SETNULL" | "SET_NULL" => {
251 quote::quote! { sqlmodel_core::ReferentialAction::SetNull }
252 }
253 "SET DEFAULT" | "SETDEFAULT" | "SET_DEFAULT" => {
254 quote::quote! { sqlmodel_core::ReferentialAction::SetDefault }
255 }
256 _ => quote::quote! { sqlmodel_core::ReferentialAction::NoAction },
257 }
258}
259
260fn generate_field_infos(model: &ModelDef) -> proc_macro2::TokenStream {
262 let mut field_tokens = Vec::new();
263
264 for field in model.data_fields() {
266 let field_ident = field.name.unraw();
267 let column_name = &field.column_name;
268 let primary_key = field.primary_key;
269 let auto_increment = field.auto_increment;
270
271 let sa_col = field.sa_column.as_ref();
273
274 let nullable = sa_col.and_then(|sc| sc.nullable).unwrap_or(field.nullable);
276
277 let unique = sa_col.and_then(|sc| sc.unique).unwrap_or(field.unique);
279
280 let effective_sql_type = sa_col
282 .and_then(|sc| sc.sql_type.as_ref())
283 .or(field.sql_type.as_ref());
284 let sql_type_token = if let Some(sql_type_str) = effective_sql_type {
285 infer::parse_sql_type_attr(sql_type_str)
287 } else {
288 infer::infer_sql_type(&field.ty)
290 };
291
292 let sql_type_override_token = if let Some(sql_type_str) = effective_sql_type {
294 quote::quote! { Some(#sql_type_str) }
295 } else {
296 quote::quote! { None }
297 };
298
299 let effective_default = sa_col
301 .and_then(|sc| sc.server_default.as_ref())
302 .or(field.default.as_ref());
303 let default_token = if let Some(d) = effective_default {
304 quote::quote! { Some(#d) }
305 } else {
306 quote::quote! { None }
307 };
308
309 let fk_token = if let Some(fk) = &field.foreign_key {
311 quote::quote! { Some(#fk) }
312 } else {
313 quote::quote! { None }
314 };
315
316 let effective_index = sa_col
318 .and_then(|sc| sc.index.as_ref())
319 .or(field.index.as_ref());
320 let index_token = if let Some(idx) = effective_index {
321 quote::quote! { Some(#idx) }
322 } else {
323 quote::quote! { None }
324 };
325
326 let on_delete_token = if let Some(ref action) = field.on_delete {
328 let action_token = referential_action_token(action);
329 quote::quote! { Some(#action_token) }
330 } else {
331 quote::quote! { None }
332 };
333
334 let on_update_token = if let Some(ref action) = field.on_update {
336 let action_token = referential_action_token(action);
337 quote::quote! { Some(#action_token) }
338 } else {
339 quote::quote! { None }
340 };
341
342 let alias_token = if let Some(ref alias) = field.alias {
344 quote::quote! { Some(#alias) }
345 } else {
346 quote::quote! { None }
347 };
348
349 let validation_alias_token = if let Some(ref val_alias) = field.validation_alias {
350 quote::quote! { Some(#val_alias) }
351 } else {
352 quote::quote! { None }
353 };
354
355 let serialization_alias_token = if let Some(ref ser_alias) = field.serialization_alias {
356 quote::quote! { Some(#ser_alias) }
357 } else {
358 quote::quote! { None }
359 };
360
361 let computed = field.computed;
362 let exclude = field.exclude;
363
364 let title_token = if let Some(ref title) = field.title {
366 quote::quote! { Some(#title) }
367 } else {
368 quote::quote! { None }
369 };
370
371 let description_token = if let Some(ref desc) = field.description {
372 quote::quote! { Some(#desc) }
373 } else {
374 quote::quote! { None }
375 };
376
377 let schema_extra_token = if let Some(ref extra) = field.schema_extra {
378 quote::quote! { Some(#extra) }
379 } else {
380 quote::quote! { None }
381 };
382
383 let default_json_token = if let Some(ref dj) = field.default_json {
385 quote::quote! { Some(#dj) }
386 } else {
387 quote::quote! { None }
388 };
389
390 let const_field = field.const_field;
392
393 let effective_constraints: Vec<&String> = if let Some(sc) = sa_col {
396 sc.check.iter().collect()
397 } else {
398 field.column_constraints.iter().collect()
399 };
400 let column_constraints_token = if effective_constraints.is_empty() {
401 quote::quote! { &[] }
402 } else {
403 quote::quote! { &[#(#effective_constraints),*] }
404 };
405
406 let effective_comment = sa_col
409 .and_then(|sc| sc.comment.as_ref())
410 .or(field.column_comment.as_ref());
411 let column_comment_token = if let Some(comment) = effective_comment {
412 quote::quote! { Some(#comment) }
413 } else {
414 quote::quote! { None }
415 };
416
417 let column_info_token = if let Some(ref info) = field.column_info {
419 quote::quote! { Some(#info) }
420 } else {
421 quote::quote! { None }
422 };
423
424 let hybrid_sql_token = if let Some(ref sql) = field.hybrid_sql {
426 quote::quote! { Some(#sql) }
427 } else {
428 quote::quote! { None }
429 };
430
431 let discriminator_token = if let Some(ref disc) = field.discriminator {
433 quote::quote! { Some(#disc) }
434 } else {
435 quote::quote! { None }
436 };
437
438 let precision_token = if let Some(p) = field.max_digits {
440 quote::quote! { Some(#p) }
441 } else {
442 quote::quote! { None }
443 };
444
445 let scale_token = if let Some(s) = field.decimal_places {
446 quote::quote! { Some(#s) }
447 } else {
448 quote::quote! { None }
449 };
450
451 field_tokens.push(quote::quote! {
452 sqlmodel_core::FieldInfo::new(stringify!(#field_ident), #column_name, #sql_type_token)
453 .sql_type_override_opt(#sql_type_override_token)
454 .precision_opt(#precision_token)
455 .scale_opt(#scale_token)
456 .nullable(#nullable)
457 .primary_key(#primary_key)
458 .auto_increment(#auto_increment)
459 .unique(#unique)
460 .default_opt(#default_token)
461 .foreign_key_opt(#fk_token)
462 .on_delete_opt(#on_delete_token)
463 .on_update_opt(#on_update_token)
464 .index_opt(#index_token)
465 .alias_opt(#alias_token)
466 .validation_alias_opt(#validation_alias_token)
467 .serialization_alias_opt(#serialization_alias_token)
468 .computed(#computed)
469 .exclude(#exclude)
470 .title_opt(#title_token)
471 .description_opt(#description_token)
472 .schema_extra_opt(#schema_extra_token)
473 .default_json_opt(#default_json_token)
474 .const_field(#const_field)
475 .column_constraints(#column_constraints_token)
476 .column_comment_opt(#column_comment_token)
477 .column_info_opt(#column_info_token)
478 .hybrid_sql_opt(#hybrid_sql_token)
479 .discriminator_opt(#discriminator_token)
480 });
481 }
482
483 quote::quote! { #(#field_tokens),* }
484}
485
486fn generate_to_row(model: &ModelDef) -> proc_macro2::TokenStream {
488 let mut conversions = Vec::new();
489
490 for field in model.select_fields() {
491 let field_name = &field.name;
492 let column_name = &field.column_name;
493
494 if parse::is_option_type(&field.ty) {
496 conversions.push(quote::quote! {
497 (#column_name, match &self.#field_name {
498 Some(v) => sqlmodel_core::Value::from(v.clone()),
499 None => sqlmodel_core::Value::Null,
500 })
501 });
502 } else {
503 conversions.push(quote::quote! {
504 (#column_name, sqlmodel_core::Value::from(self.#field_name.clone()))
505 });
506 }
507 }
508
509 quote::quote! {
510 vec![#(#conversions),*]
511 }
512}
513
514fn generate_from_row(model: &ModelDef) -> proc_macro2::TokenStream {
516 let name = &model.name;
517 let mut field_extractions = Vec::new();
518
519 for field in model.select_fields() {
520 let field_name = &field.name;
521 let column_name = &field.column_name;
522
523 if parse::is_option_type(&field.ty) {
524 field_extractions.push(quote::quote! {
526 #field_name: row.get_named(#column_name).ok()
527 });
528 } else {
529 field_extractions.push(quote::quote! {
531 #field_name: row.get_named(#column_name)?
532 });
533 }
534 }
535
536 let skipped_fields: Vec<_> = model
538 .fields
539 .iter()
540 .filter(|f| f.skip)
541 .map(|f| {
542 let field_name = &f.name;
543 quote::quote! { #field_name: Default::default() }
544 })
545 .collect();
546
547 let relationship_fields: Vec<_> = model
549 .fields
550 .iter()
551 .filter(|f| f.relationship.is_some())
552 .map(|f| {
553 let field_name = &f.name;
554 quote::quote! { #field_name: Default::default() }
555 })
556 .collect();
557
558 let computed_fields: Vec<_> = model
560 .computed_fields()
561 .iter()
562 .map(|f| {
563 let field_name = &f.name;
564 quote::quote! { #field_name: Default::default() }
565 })
566 .collect();
567
568 quote::quote! {
569 Ok(#name {
570 #(#field_extractions,)*
571 #(#skipped_fields,)*
572 #(#relationship_fields,)*
573 #(#computed_fields,)*
574 })
575 }
576}
577
578fn generate_primary_key_value(model: &ModelDef) -> proc_macro2::TokenStream {
580 let pk_fields = model.primary_key_fields();
581
582 if pk_fields.is_empty() {
583 let id_field = model.fields.iter().find(|f| f.name == "id");
585 if let Some(field) = id_field {
586 let field_name = &field.name;
587 if parse::is_option_type(&field.ty) {
588 return quote::quote! {
589 match &self.#field_name {
590 Some(v) => vec![sqlmodel_core::Value::from(v.clone())],
591 None => vec![sqlmodel_core::Value::Null],
592 }
593 };
594 }
595 return quote::quote! {
596 vec![sqlmodel_core::Value::from(self.#field_name.clone())]
597 };
598 }
599 return quote::quote! { vec![] };
600 }
601
602 let mut value_exprs = Vec::new();
603 for field in pk_fields {
604 let field_name = &field.name;
605 if parse::is_option_type(&field.ty) {
606 value_exprs.push(quote::quote! {
607 match &self.#field_name {
608 Some(v) => sqlmodel_core::Value::from(v.clone()),
609 None => sqlmodel_core::Value::Null,
610 }
611 });
612 } else {
613 value_exprs.push(quote::quote! {
614 sqlmodel_core::Value::from(self.#field_name.clone())
615 });
616 }
617 }
618
619 quote::quote! {
620 vec![#(#value_exprs),*]
621 }
622}
623
624fn generate_is_new(model: &ModelDef) -> proc_macro2::TokenStream {
626 let pk_fields = model.primary_key_fields();
627
628 for field in &pk_fields {
631 if field.auto_increment && parse::is_option_type(&field.ty) {
632 let field_name = &field.name;
633 return quote::quote! {
634 self.#field_name.is_none()
635 };
636 }
637 }
638
639 if let Some(id_field) = model.fields.iter().find(|f| f.name == "id") {
641 if parse::is_option_type(&id_field.ty) {
642 return quote::quote! {
643 self.id.is_none()
644 };
645 }
646 }
647
648 quote::quote! { true }
650}
651
652fn generate_model_config(model: &ModelDef) -> proc_macro2::TokenStream {
654 let config = &model.config;
655
656 let table = config.table;
657 let from_attributes = config.from_attributes;
658 let validate_assignment = config.validate_assignment;
659 let strict = config.strict;
660 let populate_by_name = config.populate_by_name;
661 let use_enum_values = config.use_enum_values;
662 let arbitrary_types_allowed = config.arbitrary_types_allowed;
663 let defer_build = config.defer_build;
664 let revalidate_instances = config.revalidate_instances;
665
666 let extra_token = match config.extra.as_str() {
668 "forbid" => quote::quote! { sqlmodel_core::ExtraFieldsBehavior::Forbid },
669 "allow" => quote::quote! { sqlmodel_core::ExtraFieldsBehavior::Allow },
670 _ => quote::quote! { sqlmodel_core::ExtraFieldsBehavior::Ignore },
671 };
672
673 let json_schema_extra_token = if let Some(ref extra) = config.json_schema_extra {
675 quote::quote! { Some(#extra) }
676 } else {
677 quote::quote! { None }
678 };
679
680 let title_token = if let Some(ref title) = config.title {
681 quote::quote! { Some(#title) }
682 } else {
683 quote::quote! { None }
684 };
685
686 quote::quote! {
687 sqlmodel_core::ModelConfig {
688 table: #table,
689 from_attributes: #from_attributes,
690 validate_assignment: #validate_assignment,
691 extra: #extra_token,
692 strict: #strict,
693 populate_by_name: #populate_by_name,
694 use_enum_values: #use_enum_values,
695 arbitrary_types_allowed: #arbitrary_types_allowed,
696 defer_build: #defer_build,
697 revalidate_instances: #revalidate_instances,
698 json_schema_extra: #json_schema_extra_token,
699 title: #title_token,
700 }
701 }
702}
703
704fn generate_inheritance(model: &ModelDef) -> proc_macro2::TokenStream {
706 use crate::parse::InheritanceStrategy;
707
708 let config = &model.config;
709
710 let strategy_token = match config.inheritance {
712 InheritanceStrategy::None => {
713 quote::quote! { sqlmodel_core::InheritanceStrategy::None }
714 }
715 InheritanceStrategy::Single => {
716 quote::quote! { sqlmodel_core::InheritanceStrategy::Single }
717 }
718 InheritanceStrategy::Joined => {
719 quote::quote! { sqlmodel_core::InheritanceStrategy::Joined }
720 }
721 InheritanceStrategy::Concrete => {
722 quote::quote! { sqlmodel_core::InheritanceStrategy::Concrete }
723 }
724 };
725
726 let parent_token = if let Some(ref parent) = config.inherits {
728 quote::quote! { Some(#parent) }
729 } else {
730 quote::quote! { None }
731 };
732
733 let discriminator_column_token = if let Some(ref column) = config.discriminator_column {
735 quote::quote! { Some(#column) }
736 } else {
737 quote::quote! { None }
738 };
739
740 let discriminator_value_token = if let Some(ref value) = config.discriminator_value {
742 quote::quote! { Some(#value) }
743 } else {
744 quote::quote! { None }
745 };
746
747 quote::quote! {
748 sqlmodel_core::InheritanceInfo {
749 strategy: #strategy_token,
750 parent: #parent_token,
751 discriminator_column: #discriminator_column_token,
752 discriminator_value: #discriminator_value_token,
753 }
754 }
755}
756
757fn generate_shard_key(model: &ModelDef) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
763 let config = &model.config;
764
765 if let Some(ref shard_key_name) = config.shard_key {
766 let shard_field = model.fields.iter().find(|f| f.name == shard_key_name);
768
769 let const_token = quote::quote! { Some(#shard_key_name) };
770
771 let value_body = if let Some(field) = shard_field {
773 let field_ident = &field.name;
774 if parse::is_option_type(&field.ty) {
775 quote::quote! {
777 match &self.#field_ident {
778 Some(v) => Some(sqlmodel_core::Value::from(v.clone())),
779 None => None,
780 }
781 }
782 } else {
783 quote::quote! {
785 Some(sqlmodel_core::Value::from(self.#field_ident.clone()))
786 }
787 }
788 } else {
789 quote::quote! { None }
792 };
793
794 (const_token, value_body)
795 } else {
796 let const_token = quote::quote! { None };
798 let value_body = quote::quote! { None };
799 (const_token, value_body)
800 }
801}
802
803fn generate_debug_impl(model: &ModelDef) -> proc_macro2::TokenStream {
808 let has_hidden_fields = model.fields.iter().any(|f| !f.repr);
810
811 if !has_hidden_fields {
813 return quote::quote! {};
814 }
815
816 let name = &model.name;
817 let (impl_generics, ty_generics, where_clause) = model.generics.split_for_impl();
818
819 let debug_fields: Vec<_> = model
821 .fields
822 .iter()
823 .filter(|f| f.repr) .map(|f| {
825 let field_name = &f.name;
826 let field_name_str = field_name.to_string();
827 quote::quote! {
828 .field(#field_name_str, &self.#field_name)
829 }
830 })
831 .collect();
832
833 let struct_name_str = name.to_string();
834
835 quote::quote! {
836 impl #impl_generics ::core::fmt::Debug for #name #ty_generics #where_clause {
837 fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
838 f.debug_struct(#struct_name_str)
839 #(#debug_fields)*
840 .finish()
841 }
842 }
843 }
844}
845
846fn generate_relationships(model: &ModelDef) -> proc_macro2::TokenStream {
848 let relationship_fields = model.relationship_fields();
849
850 if relationship_fields.is_empty() {
851 return quote::quote! { &[] };
852 }
853
854 let mut relationship_tokens = Vec::new();
855
856 for field in relationship_fields {
857 let rel = field.relationship.as_ref().unwrap();
858 let field_name = &field.name;
859 let related_table = &rel.model;
860
861 let kind_token = match rel.kind {
863 RelationshipKindAttr::OneToOne => {
864 quote::quote! { sqlmodel_core::RelationshipKind::OneToOne }
865 }
866 RelationshipKindAttr::ManyToOne => {
867 quote::quote! { sqlmodel_core::RelationshipKind::ManyToOne }
868 }
869 RelationshipKindAttr::OneToMany => {
870 quote::quote! { sqlmodel_core::RelationshipKind::OneToMany }
871 }
872 RelationshipKindAttr::ManyToMany => {
873 quote::quote! { sqlmodel_core::RelationshipKind::ManyToMany }
874 }
875 };
876
877 let local_key_call = if let Some(ref fk) = rel.foreign_key {
879 quote::quote! { .local_key(#fk) }
880 } else {
881 quote::quote! {}
882 };
883
884 let remote_key_call = if let Some(ref rk) = rel.remote_key {
885 quote::quote! { .remote_key(#rk) }
886 } else {
887 quote::quote! {}
888 };
889
890 let back_populates_call = if let Some(ref bp) = rel.back_populates {
891 quote::quote! { .back_populates(#bp) }
892 } else {
893 quote::quote! {}
894 };
895
896 let link_table_call = if let Some(ref lt) = rel.link_table {
897 let table = <.table;
898 let local_col = <.local_column;
899 let remote_col = <.remote_column;
900 quote::quote! {
901 .link_table(sqlmodel_core::LinkTableInfo::new(#table, #local_col, #remote_col))
902 }
903 } else {
904 quote::quote! {}
905 };
906
907 let lazy_val = rel.lazy;
908 let cascade_val = rel.cascade_delete;
909 let passive_deletes_token = match rel.passive_deletes {
910 crate::parse::PassiveDeletesAttr::Active => {
911 quote::quote! { sqlmodel_core::PassiveDeletes::Active }
912 }
913 crate::parse::PassiveDeletesAttr::Passive => {
914 quote::quote! { sqlmodel_core::PassiveDeletes::Passive }
915 }
916 crate::parse::PassiveDeletesAttr::All => {
917 quote::quote! { sqlmodel_core::PassiveDeletes::All }
918 }
919 };
920
921 let order_by_call = if let Some(ref ob) = rel.order_by {
923 quote::quote! { .order_by(#ob) }
924 } else {
925 quote::quote! {}
926 };
927
928 let lazy_strategy_call = if let Some(ref strategy) = rel.lazy_strategy {
929 let strategy_token = match strategy {
930 crate::parse::LazyLoadStrategyAttr::Select => {
931 quote::quote! { sqlmodel_core::LazyLoadStrategy::Select }
932 }
933 crate::parse::LazyLoadStrategyAttr::Joined => {
934 quote::quote! { sqlmodel_core::LazyLoadStrategy::Joined }
935 }
936 crate::parse::LazyLoadStrategyAttr::Subquery => {
937 quote::quote! { sqlmodel_core::LazyLoadStrategy::Subquery }
938 }
939 crate::parse::LazyLoadStrategyAttr::Selectin => {
940 quote::quote! { sqlmodel_core::LazyLoadStrategy::Selectin }
941 }
942 crate::parse::LazyLoadStrategyAttr::Dynamic => {
943 quote::quote! { sqlmodel_core::LazyLoadStrategy::Dynamic }
944 }
945 crate::parse::LazyLoadStrategyAttr::NoLoad => {
946 quote::quote! { sqlmodel_core::LazyLoadStrategy::NoLoad }
947 }
948 crate::parse::LazyLoadStrategyAttr::RaiseOnSql => {
949 quote::quote! { sqlmodel_core::LazyLoadStrategy::RaiseOnSql }
950 }
951 crate::parse::LazyLoadStrategyAttr::WriteOnly => {
952 quote::quote! { sqlmodel_core::LazyLoadStrategy::WriteOnly }
953 }
954 };
955 quote::quote! { .lazy_strategy(#strategy_token) }
956 } else {
957 quote::quote! {}
958 };
959
960 let cascade_call = if let Some(ref c) = rel.cascade {
961 quote::quote! { .cascade(#c) }
962 } else {
963 quote::quote! {}
964 };
965
966 let uselist_call = if let Some(ul) = rel.uselist {
967 quote::quote! { .uselist(#ul) }
968 } else {
969 quote::quote! {}
970 };
971
972 relationship_tokens.push(quote::quote! {
973 sqlmodel_core::RelationshipInfo::new(
974 stringify!(#field_name),
975 #related_table,
976 #kind_token
977 )
978 #local_key_call
979 #remote_key_call
980 #back_populates_call
981 #link_table_call
982 .lazy(#lazy_val)
983 .cascade_delete(#cascade_val)
984 .passive_deletes(#passive_deletes_token)
985 #order_by_call
986 #lazy_strategy_call
987 #cascade_call
988 #uselist_call
989 });
990 }
991
992 quote::quote! {
993 &[#(#relationship_tokens),*]
994 }
995}
996
997#[proc_macro_derive(Validate, attributes(validate))]
1045pub fn derive_validate(input: TokenStream) -> TokenStream {
1046 let input = syn::parse_macro_input!(input as syn::DeriveInput);
1047
1048 let def = match validate_derive::parse_validate(&input) {
1050 Ok(d) => d,
1051 Err(e) => return e.to_compile_error().into(),
1052 };
1053
1054 validate_derive::generate_validate_impl(&def).into()
1056}
1057
1058#[proc_macro_derive(SqlEnum, attributes(sqlmodel))]
1078pub fn derive_sql_enum(input: TokenStream) -> TokenStream {
1079 let input = syn::parse_macro_input!(input as syn::DeriveInput);
1080 match generate_sql_enum_impl(&input) {
1081 Ok(tokens) => tokens.into(),
1082 Err(e) => e.to_compile_error().into(),
1083 }
1084}
1085
1086fn generate_sql_enum_impl(input: &syn::DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
1087 let name = &input.ident;
1088 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1089
1090 let syn::Data::Enum(data) = &input.data else {
1091 return Err(syn::Error::new_spanned(
1092 input,
1093 "SqlEnum can only be derived for enums",
1094 ));
1095 };
1096
1097 let mut variant_names = Vec::new();
1099 let mut variant_strings = Vec::new();
1100
1101 for variant in &data.variants {
1102 if !variant.fields.is_empty() {
1103 return Err(syn::Error::new_spanned(
1104 variant,
1105 "SqlEnum variants must be unit variants (no fields)",
1106 ));
1107 }
1108
1109 let ident = &variant.ident;
1110 variant_names.push(ident.clone());
1111
1112 let mut custom_name = None;
1114 for attr in &variant.attrs {
1115 if attr.path().is_ident("sqlmodel") {
1116 attr.parse_nested_meta(|meta| {
1117 if meta.path.is_ident("rename") {
1118 let value = meta.value()?;
1119 let s: syn::LitStr = value.parse()?;
1120 custom_name = Some(s.value());
1121 }
1122 Ok(())
1123 })?;
1124 }
1125 }
1126
1127 let sql_str = custom_name.unwrap_or_else(|| to_snake_case(&ident.to_string()));
1128 variant_strings.push(sql_str);
1129 }
1130
1131 let type_name = to_snake_case(&name.to_string());
1132
1133 let variant_str_refs: Vec<_> = variant_strings.iter().map(|s| s.as_str()).collect();
1135
1136 let to_sql_arms: Vec<_> = variant_names
1137 .iter()
1138 .zip(variant_strings.iter())
1139 .map(|(ident, s)| {
1140 quote::quote! { #name::#ident => #s }
1141 })
1142 .collect();
1143
1144 let from_sql_arms: Vec<_> = variant_names
1145 .iter()
1146 .zip(variant_strings.iter())
1147 .map(|(ident, s)| {
1148 quote::quote! { #s => Ok(#name::#ident) }
1149 })
1150 .collect();
1151
1152 let valid_values: String = variant_strings
1154 .iter()
1155 .map(|s| format!("'{}'", s))
1156 .collect::<Vec<_>>()
1157 .join(", ");
1158 let error_msg = format!(
1159 "invalid value for {}: expected one of {}",
1160 name, valid_values
1161 );
1162
1163 Ok(quote::quote! {
1164 impl #impl_generics sqlmodel_core::SqlEnum for #name #ty_generics #where_clause {
1165 const VARIANTS: &'static [&'static str] = &[#(#variant_str_refs),*];
1166 const TYPE_NAME: &'static str = #type_name;
1167
1168 fn to_sql_str(&self) -> &'static str {
1169 match self {
1170 #(#to_sql_arms,)*
1171 }
1172 }
1173
1174 fn from_sql_str(s: &str) -> Result<Self, String> {
1175 match s {
1176 #(#from_sql_arms,)*
1177 _ => Err(format!("{}, got '{}'", #error_msg, s)),
1178 }
1179 }
1180 }
1181
1182 impl #impl_generics From<#name #ty_generics> for sqlmodel_core::Value #where_clause {
1183 fn from(v: #name #ty_generics) -> Self {
1184 sqlmodel_core::Value::Text(
1185 sqlmodel_core::SqlEnum::to_sql_str(&v).to_string()
1186 )
1187 }
1188 }
1189
1190 impl #impl_generics From<&#name #ty_generics> for sqlmodel_core::Value #where_clause {
1191 fn from(v: &#name #ty_generics) -> Self {
1192 sqlmodel_core::Value::Text(
1193 sqlmodel_core::SqlEnum::to_sql_str(v).to_string()
1194 )
1195 }
1196 }
1197
1198 impl #impl_generics TryFrom<sqlmodel_core::Value> for #name #ty_generics #where_clause {
1199 type Error = sqlmodel_core::Error;
1200
1201 fn try_from(value: sqlmodel_core::Value) -> Result<Self, Self::Error> {
1202 match value {
1203 sqlmodel_core::Value::Text(ref s) => {
1204 sqlmodel_core::SqlEnum::from_sql_str(s.as_str()).map_err(|e| {
1205 sqlmodel_core::Error::Type(sqlmodel_core::error::TypeError {
1206 expected: <#name as sqlmodel_core::SqlEnum>::TYPE_NAME,
1207 actual: e,
1208 column: None,
1209 rust_type: None,
1210 })
1211 })
1212 }
1213 other => Err(sqlmodel_core::Error::Type(sqlmodel_core::error::TypeError {
1214 expected: <#name as sqlmodel_core::SqlEnum>::TYPE_NAME,
1215 actual: other.type_name().to_string(),
1216 column: None,
1217 rust_type: None,
1218 })),
1219 }
1220 }
1221 }
1222
1223 impl #impl_generics ::core::fmt::Display for #name #ty_generics #where_clause {
1224 fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
1225 f.write_str(sqlmodel_core::SqlEnum::to_sql_str(self))
1226 }
1227 }
1228
1229 impl #impl_generics ::core::str::FromStr for #name #ty_generics #where_clause {
1230 type Err = String;
1231
1232 fn from_str(s: &str) -> Result<Self, Self::Err> {
1233 sqlmodel_core::SqlEnum::from_sql_str(s)
1234 }
1235 }
1236 })
1237}
1238
/// Converts a CamelCase identifier to snake_case.
///
/// An underscore is inserted before an uppercase character when it follows a lowercase
/// one (`UserRole` becomes `user_role`). It is also inserted when the character ends an
/// uppercase run and starts a new word (`HTTPServer` becomes `http_server`).
///
/// Lowercasing is Unicode-aware, matching the Unicode-aware `is_uppercase` test, so
/// `ÉtatCivil` becomes `état_civil` rather than `État_civil`. Non-uppercase characters,
/// including digits and underscores, are copied unchanged.
fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + 4);

    for (i, &ch) in chars.iter().enumerate() {
        if !ch.is_uppercase() {
            result.push(ch);
            continue;
        }

        if let Some(&prev) = i.checked_sub(1).and_then(|p| chars.get(p)) {
            let next_is_lower = chars.get(i + 1).is_some_and(|c| c.is_lowercase());
            if prev.is_lowercase() || (prev.is_uppercase() && next_is_lower) {
                result.push('_');
            }
        }

        // `to_lowercase` may yield several chars for some code points.
        result.extend(ch.to_lowercase());
    }

    result
}
1260
1261#[proc_macro_attribute]
1272pub fn query(_attr: TokenStream, item: TokenStream) -> TokenStream {
1273 item
1276}