Skip to main content

sqlmodel_macros/
lib.rs

1//! Procedural macros for SQLModel Rust.
2//!
3//! `sqlmodel-macros` is the **compile-time codegen layer**. It turns Rust structs into
4//! fully described SQL models by generating static metadata and trait implementations.
5//!
6//! # Role In The Architecture
7//!
8//! - **Model metadata**: `#[derive(Model)]` produces a `Model` implementation with
9//!   table/column metadata consumed by query, schema, and session layers.
10//! - **Validation**: `#[derive(Validate)]` generates field validation glue.
11//! - **Schema export**: `#[derive(JsonSchema)]` enables JSON schema generation for
12//!   API documentation or tooling.
13//!
14//! These macros are used by application crates via the `sqlmodel` facade.
15
16use proc_macro::TokenStream;
17use syn::ext::IdentExt;
18
19mod infer;
20mod parse;
21mod validate;
22mod validate_derive;
23
24use parse::{ModelDef, RelationshipKindAttr, parse_model};
25
26/// Derive macro for the `Model` trait.
27///
28/// This macro generates implementations for:
29/// - Table name and primary key metadata
30/// - Field information
31/// - Row conversion (to_row, from_row)
32/// - Primary key access
33///
34/// # Attributes
35///
36/// - `#[sqlmodel(table = "name")]` - Override table name (defaults to snake_case struct name)
37/// - `#[sqlmodel(primary_key)]` - Mark field as primary key
38/// - `#[sqlmodel(auto_increment)]` - Mark field as auto-incrementing
39/// - `#[sqlmodel(column = "name")]` - Override column name
40/// - `#[sqlmodel(nullable)]` - Mark field as nullable
41/// - `#[sqlmodel(unique)]` - Add unique constraint
42/// - `#[sqlmodel(default = "expr")]` - Set default SQL expression
43/// - `#[sqlmodel(foreign_key = "table.column")]` - Add foreign key reference
44/// - `#[sqlmodel(index = "name")]` - Add to named index
45/// - `#[sqlmodel(skip)]` - Skip this field in database operations
46///
47/// # Example
48///
49/// ```ignore
50/// use sqlmodel::Model;
51///
52/// #[derive(Model)]
53/// #[sqlmodel(table = "heroes")]
54/// struct Hero {
55///     #[sqlmodel(primary_key, auto_increment)]
56///     id: Option<i64>,
57///
58///     #[sqlmodel(unique)]
59///     name: String,
60///
61///     secret_name: String,
62///
63///     #[sqlmodel(nullable)]
64///     age: Option<i32>,
65///
66///     #[sqlmodel(foreign_key = "teams.id")]
67///     team_id: Option<i64>,
68/// }
69/// ```
70#[proc_macro_derive(Model, attributes(sqlmodel))]
71pub fn derive_model(input: TokenStream) -> TokenStream {
72    let input = syn::parse_macro_input!(input as syn::DeriveInput);
73
74    // Parse the struct and its attributes
75    let model = match parse_model(&input) {
76        Ok(m) => m,
77        Err(e) => return e.to_compile_error().into(),
78    };
79
80    // Validate the parsed model
81    if let Err(e) = validate::validate_model(&model) {
82        return e.to_compile_error().into();
83    }
84
85    // Generate the Model implementation
86    generate_model_impl(&model).into()
87}
88
89/// Generate the Model trait implementation from parsed model definition.
90fn generate_model_impl(model: &ModelDef) -> proc_macro2::TokenStream {
91    let name = &model.name;
92    let table_name = &model.table_name;
93    let (impl_generics, ty_generics, where_clause) = model.generics.split_for_impl();
94
95    // Collect primary key field names
96    let pk_fields: Vec<&str> = model
97        .primary_key_fields()
98        .iter()
99        .map(|f| f.column_name.as_str())
100        .collect();
101    let pk_field_names: Vec<_> = pk_fields.clone();
102
103    // If no explicit primary key, default to "id" if present
104    let pk_slice = if pk_field_names.is_empty() {
105        // Only default to "id" if an "id" field actually exists
106        let has_id_field = model.fields.iter().any(|f| f.name == "id" && !f.skip);
107        if has_id_field {
108            quote::quote! { &["id"] }
109        } else {
110            quote::quote! { &[] }
111        }
112    } else {
113        quote::quote! { &[#(#pk_field_names),*] }
114    };
115
116    // Generate static FieldInfo array for fields()
117    let field_infos = generate_field_infos(model);
118
119    // Generate RELATIONSHIPS constant
120    let relationships = generate_relationships(model);
121
122    // Generate to_row implementation
123    let to_row_body = generate_to_row(model);
124
125    // Generate from_row implementation
126    let from_row_body = generate_from_row(model);
127
128    // Generate primary_key_value implementation
129    let pk_value_body = generate_primary_key_value(model);
130
131    // Generate is_new implementation
132    let is_new_body = generate_is_new(model);
133
134    // Generate model_config implementation
135    let model_config_body = generate_model_config(model);
136
137    // Generate inheritance implementation
138    let inheritance_body = generate_inheritance(model);
139
140    // Generate shard_key implementation
141    let (shard_key_const, shard_key_value_body) = generate_shard_key(model);
142
143    // Generate Debug impl only if any field has repr=false
144    let debug_impl = generate_debug_impl(model);
145
146    // Generate hybrid property expr methods
147    let hybrid_impl = generate_hybrid_methods(model);
148
149    quote::quote! {
150        impl #impl_generics sqlmodel_core::Model for #name #ty_generics #where_clause {
151            const TABLE_NAME: &'static str = #table_name;
152            const PRIMARY_KEY: &'static [&'static str] = #pk_slice;
153            const RELATIONSHIPS: &'static [sqlmodel_core::RelationshipInfo] = #relationships;
154            const SHARD_KEY: Option<&'static str> = #shard_key_const;
155
156            fn fields() -> &'static [sqlmodel_core::FieldInfo] {
157                static FIELDS: &[sqlmodel_core::FieldInfo] = &[
158                    #field_infos
159                ];
160                FIELDS
161            }
162
163            fn to_row(&self) -> Vec<(&'static str, sqlmodel_core::Value)> {
164                #to_row_body
165            }
166
167            fn from_row(row: &sqlmodel_core::Row) -> sqlmodel_core::Result<Self> {
168                #from_row_body
169            }
170
171            fn primary_key_value(&self) -> Vec<sqlmodel_core::Value> {
172                #pk_value_body
173            }
174
175            fn is_new(&self) -> bool {
176                #is_new_body
177            }
178
179            fn model_config() -> sqlmodel_core::ModelConfig {
180                #model_config_body
181            }
182
183            fn inheritance() -> sqlmodel_core::InheritanceInfo {
184                #inheritance_body
185            }
186
187            fn shard_key_value(&self) -> Option<sqlmodel_core::Value> {
188                #shard_key_value_body
189            }
190        }
191
192        #debug_impl
193
194        #hybrid_impl
195    }
196}
197
198/// Generate associated functions for hybrid properties.
199///
200/// For each field with `#[sqlmodel(hybrid, sql = "...")]`, generates
201/// a `pub fn {field}_expr() -> sqlmodel_query::Expr` method that returns
202/// `Expr::raw(sql)`.
203fn generate_hybrid_methods(model: &ModelDef) -> proc_macro2::TokenStream {
204    let hybrid_fields: Vec<_> = model
205        .fields
206        .iter()
207        .filter(|f| f.hybrid && f.hybrid_sql.is_some())
208        .collect();
209
210    if hybrid_fields.is_empty() {
211        return quote::quote! {};
212    }
213
214    let name = &model.name;
215    let (impl_generics, ty_generics, where_clause) = model.generics.split_for_impl();
216
217    let methods: Vec<_> = hybrid_fields
218        .iter()
219        .map(|field| {
220            let sql = field.hybrid_sql.as_ref().unwrap();
221            let method_name = quote::format_ident!("{}_expr", field.name);
222            let doc = format!(
223                "SQL expression for the `{}` hybrid property.\n\nGenerates: `{}`",
224                field.name, sql
225            );
226            quote::quote! {
227                #[doc = #doc]
228                pub fn #method_name() -> sqlmodel_query::Expr {
229                    sqlmodel_query::Expr::raw(#sql)
230                }
231            }
232        })
233        .collect();
234
235    quote::quote! {
236        impl #impl_generics #name #ty_generics #where_clause {
237            #(#methods)*
238        }
239    }
240}
241
242/// Convert a referential action string to the corresponding token.
243fn referential_action_token(action: &str) -> proc_macro2::TokenStream {
244    match action.to_uppercase().as_str() {
245        "NO ACTION" | "NOACTION" | "NO_ACTION" => {
246            quote::quote! { sqlmodel_core::ReferentialAction::NoAction }
247        }
248        "RESTRICT" => quote::quote! { sqlmodel_core::ReferentialAction::Restrict },
249        "CASCADE" => quote::quote! { sqlmodel_core::ReferentialAction::Cascade },
250        "SET NULL" | "SETNULL" | "SET_NULL" => {
251            quote::quote! { sqlmodel_core::ReferentialAction::SetNull }
252        }
253        "SET DEFAULT" | "SETDEFAULT" | "SET_DEFAULT" => {
254            quote::quote! { sqlmodel_core::ReferentialAction::SetDefault }
255        }
256        _ => quote::quote! { sqlmodel_core::ReferentialAction::NoAction },
257    }
258}
259
260/// Generate the static FieldInfo array contents.
261fn generate_field_infos(model: &ModelDef) -> proc_macro2::TokenStream {
262    let mut field_tokens = Vec::new();
263
264    // Use data_fields() to include computed fields in metadata (needed for serialization)
265    for field in model.data_fields() {
266        let field_ident = field.name.unraw();
267        let column_name = &field.column_name;
268        let primary_key = field.primary_key;
269        let auto_increment = field.auto_increment;
270
271        // Check if sa_column override is present
272        let sa_col = field.sa_column.as_ref();
273
274        // Nullable: sa_column.nullable takes precedence over field.nullable
275        let nullable = sa_col.and_then(|sc| sc.nullable).unwrap_or(field.nullable);
276
277        // Unique: sa_column.unique takes precedence over field.unique
278        let unique = sa_col.and_then(|sc| sc.unique).unwrap_or(field.unique);
279
280        // Determine SQL type: sa_column.sql_type > field.sql_type > inferred
281        let effective_sql_type = sa_col
282            .and_then(|sc| sc.sql_type.as_ref())
283            .or(field.sql_type.as_ref());
284        let sql_type_token = if let Some(sql_type_str) = effective_sql_type {
285            // Parse the explicit SQL type attribute string
286            infer::parse_sql_type_attr(sql_type_str)
287        } else {
288            // Infer from Rust type (handles primitives, Option<T>, common library types)
289            infer::infer_sql_type(&field.ty)
290        };
291
292        // If sql_type attribute was provided, also store the raw string as an override for DDL.
293        let sql_type_override_token = if let Some(sql_type_str) = effective_sql_type {
294            quote::quote! { Some(#sql_type_str) }
295        } else {
296            quote::quote! { None }
297        };
298
299        // Default value: sa_column.server_default takes precedence over field.default
300        let effective_default = sa_col
301            .and_then(|sc| sc.server_default.as_ref())
302            .or(field.default.as_ref());
303        let default_token = if let Some(d) = effective_default {
304            quote::quote! { Some(#d) }
305        } else {
306            quote::quote! { None }
307        };
308
309        // Foreign key (validation prevents use with sa_column, so field value is always used)
310        let fk_token = if let Some(fk) = &field.foreign_key {
311            quote::quote! { Some(#fk) }
312        } else {
313            quote::quote! { None }
314        };
315
316        // Index: sa_column.index takes precedence over field.index
317        let effective_index = sa_col
318            .and_then(|sc| sc.index.as_ref())
319            .or(field.index.as_ref());
320        let index_token = if let Some(idx) = effective_index {
321            quote::quote! { Some(#idx) }
322        } else {
323            quote::quote! { None }
324        };
325
326        // ON DELETE action
327        let on_delete_token = if let Some(ref action) = field.on_delete {
328            let action_token = referential_action_token(action);
329            quote::quote! { Some(#action_token) }
330        } else {
331            quote::quote! { None }
332        };
333
334        // ON UPDATE action
335        let on_update_token = if let Some(ref action) = field.on_update {
336            let action_token = referential_action_token(action);
337            quote::quote! { Some(#action_token) }
338        } else {
339            quote::quote! { None }
340        };
341
342        // Alias tokens
343        let alias_token = if let Some(ref alias) = field.alias {
344            quote::quote! { Some(#alias) }
345        } else {
346            quote::quote! { None }
347        };
348
349        let validation_alias_token = if let Some(ref val_alias) = field.validation_alias {
350            quote::quote! { Some(#val_alias) }
351        } else {
352            quote::quote! { None }
353        };
354
355        let serialization_alias_token = if let Some(ref ser_alias) = field.serialization_alias {
356            quote::quote! { Some(#ser_alias) }
357        } else {
358            quote::quote! { None }
359        };
360
361        let computed = field.computed;
362        let exclude = field.exclude;
363
364        // Schema metadata tokens
365        let title_token = if let Some(ref title) = field.title {
366            quote::quote! { Some(#title) }
367        } else {
368            quote::quote! { None }
369        };
370
371        let description_token = if let Some(ref desc) = field.description {
372            quote::quote! { Some(#desc) }
373        } else {
374            quote::quote! { None }
375        };
376
377        let schema_extra_token = if let Some(ref extra) = field.schema_extra {
378            quote::quote! { Some(#extra) }
379        } else {
380            quote::quote! { None }
381        };
382
383        // Default JSON for exclude_defaults support
384        let default_json_token = if let Some(ref dj) = field.default_json {
385            quote::quote! { Some(#dj) }
386        } else {
387            quote::quote! { None }
388        };
389
390        // Const field
391        let const_field = field.const_field;
392
393        // Column constraints: sa_column.check is used if sa_column is present,
394        // otherwise field.column_constraints (validation prevents both being set)
395        let effective_constraints: Vec<&String> = if let Some(sc) = sa_col {
396            sc.check.iter().collect()
397        } else {
398            field.column_constraints.iter().collect()
399        };
400        let column_constraints_token = if effective_constraints.is_empty() {
401            quote::quote! { &[] }
402        } else {
403            quote::quote! { &[#(#effective_constraints),*] }
404        };
405
406        // Column comment: sa_column.comment is used if sa_column is present,
407        // otherwise field.column_comment (validation prevents both being set)
408        let effective_comment = sa_col
409            .and_then(|sc| sc.comment.as_ref())
410            .or(field.column_comment.as_ref());
411        let column_comment_token = if let Some(comment) = effective_comment {
412            quote::quote! { Some(#comment) }
413        } else {
414            quote::quote! { None }
415        };
416
417        // Column info
418        let column_info_token = if let Some(ref info) = field.column_info {
419            quote::quote! { Some(#info) }
420        } else {
421            quote::quote! { None }
422        };
423
424        // Hybrid SQL expression
425        let hybrid_sql_token = if let Some(ref sql) = field.hybrid_sql {
426            quote::quote! { Some(#sql) }
427        } else {
428            quote::quote! { None }
429        };
430
431        // Discriminator for union types
432        let discriminator_token = if let Some(ref disc) = field.discriminator {
433            quote::quote! { Some(#disc) }
434        } else {
435            quote::quote! { None }
436        };
437
438        // Decimal precision (max_digits -> precision, decimal_places -> scale)
439        let precision_token = if let Some(p) = field.max_digits {
440            quote::quote! { Some(#p) }
441        } else {
442            quote::quote! { None }
443        };
444
445        let scale_token = if let Some(s) = field.decimal_places {
446            quote::quote! { Some(#s) }
447        } else {
448            quote::quote! { None }
449        };
450
451        field_tokens.push(quote::quote! {
452            sqlmodel_core::FieldInfo::new(stringify!(#field_ident), #column_name, #sql_type_token)
453                .sql_type_override_opt(#sql_type_override_token)
454                .precision_opt(#precision_token)
455                .scale_opt(#scale_token)
456                .nullable(#nullable)
457                .primary_key(#primary_key)
458                .auto_increment(#auto_increment)
459                .unique(#unique)
460                .default_opt(#default_token)
461                .foreign_key_opt(#fk_token)
462                .on_delete_opt(#on_delete_token)
463                .on_update_opt(#on_update_token)
464                .index_opt(#index_token)
465                .alias_opt(#alias_token)
466                .validation_alias_opt(#validation_alias_token)
467                .serialization_alias_opt(#serialization_alias_token)
468                .computed(#computed)
469                .exclude(#exclude)
470                .title_opt(#title_token)
471                .description_opt(#description_token)
472                .schema_extra_opt(#schema_extra_token)
473                .default_json_opt(#default_json_token)
474                .const_field(#const_field)
475                .column_constraints(#column_constraints_token)
476                .column_comment_opt(#column_comment_token)
477                .column_info_opt(#column_info_token)
478                .hybrid_sql_opt(#hybrid_sql_token)
479                .discriminator_opt(#discriminator_token)
480        });
481    }
482
483    quote::quote! { #(#field_tokens),* }
484}
485
486/// Generate the to_row method body.
487fn generate_to_row(model: &ModelDef) -> proc_macro2::TokenStream {
488    let mut conversions = Vec::new();
489
490    for field in model.select_fields() {
491        let field_name = &field.name;
492        let column_name = &field.column_name;
493
494        // Convert field to Value
495        if parse::is_option_type(&field.ty) {
496            conversions.push(quote::quote! {
497                (#column_name, match &self.#field_name {
498                    Some(v) => sqlmodel_core::Value::from(v.clone()),
499                    None => sqlmodel_core::Value::Null,
500                })
501            });
502        } else {
503            conversions.push(quote::quote! {
504                (#column_name, sqlmodel_core::Value::from(self.#field_name.clone()))
505            });
506        }
507    }
508
509    quote::quote! {
510        vec![#(#conversions),*]
511    }
512}
513
514/// Generate the from_row method body.
515fn generate_from_row(model: &ModelDef) -> proc_macro2::TokenStream {
516    let name = &model.name;
517    let mut field_extractions = Vec::new();
518
519    for field in model.select_fields() {
520        let field_name = &field.name;
521        let column_name = &field.column_name;
522
523        if parse::is_option_type(&field.ty) {
524            // For Option<T> fields, handle NULL gracefully
525            field_extractions.push(quote::quote! {
526                #field_name: row.get_named(#column_name).ok()
527            });
528        } else {
529            // For required fields, propagate errors
530            field_extractions.push(quote::quote! {
531                #field_name: row.get_named(#column_name)?
532            });
533        }
534    }
535
536    // Handle skipped fields with Default
537    let skipped_fields: Vec<_> = model
538        .fields
539        .iter()
540        .filter(|f| f.skip)
541        .map(|f| {
542            let field_name = &f.name;
543            quote::quote! { #field_name: Default::default() }
544        })
545        .collect();
546
547    // Handle relationship fields with Default (they're not in the DB row)
548    let relationship_fields: Vec<_> = model
549        .fields
550        .iter()
551        .filter(|f| f.relationship.is_some())
552        .map(|f| {
553            let field_name = &f.name;
554            quote::quote! { #field_name: Default::default() }
555        })
556        .collect();
557
558    // Handle computed fields with Default (they're not in the DB row)
559    let computed_fields: Vec<_> = model
560        .computed_fields()
561        .iter()
562        .map(|f| {
563            let field_name = &f.name;
564            quote::quote! { #field_name: Default::default() }
565        })
566        .collect();
567
568    quote::quote! {
569        Ok(#name {
570            #(#field_extractions,)*
571            #(#skipped_fields,)*
572            #(#relationship_fields,)*
573            #(#computed_fields,)*
574        })
575    }
576}
577
578/// Generate the primary_key_value method body.
579fn generate_primary_key_value(model: &ModelDef) -> proc_macro2::TokenStream {
580    let pk_fields = model.primary_key_fields();
581
582    if pk_fields.is_empty() {
583        // Try to use "id" field if it exists
584        let id_field = model.fields.iter().find(|f| f.name == "id");
585        if let Some(field) = id_field {
586            let field_name = &field.name;
587            if parse::is_option_type(&field.ty) {
588                return quote::quote! {
589                    match &self.#field_name {
590                        Some(v) => vec![sqlmodel_core::Value::from(v.clone())],
591                        None => vec![sqlmodel_core::Value::Null],
592                    }
593                };
594            }
595            return quote::quote! {
596                vec![sqlmodel_core::Value::from(self.#field_name.clone())]
597            };
598        }
599        return quote::quote! { vec![] };
600    }
601
602    let mut value_exprs = Vec::new();
603    for field in pk_fields {
604        let field_name = &field.name;
605        if parse::is_option_type(&field.ty) {
606            value_exprs.push(quote::quote! {
607                match &self.#field_name {
608                    Some(v) => sqlmodel_core::Value::from(v.clone()),
609                    None => sqlmodel_core::Value::Null,
610                }
611            });
612        } else {
613            value_exprs.push(quote::quote! {
614                sqlmodel_core::Value::from(self.#field_name.clone())
615            });
616        }
617    }
618
619    quote::quote! {
620        vec![#(#value_exprs),*]
621    }
622}
623
624/// Generate the is_new method body.
625fn generate_is_new(model: &ModelDef) -> proc_macro2::TokenStream {
626    let pk_fields = model.primary_key_fields();
627
628    // If there's an auto_increment primary key field that is Option<T>,
629    // check if it's None
630    for field in &pk_fields {
631        if field.auto_increment && parse::is_option_type(&field.ty) {
632            let field_name = &field.name;
633            return quote::quote! {
634                self.#field_name.is_none()
635            };
636        }
637    }
638
639    // Otherwise, try "id" field if it exists and is Option<T>
640    if let Some(id_field) = model.fields.iter().find(|f| f.name == "id") {
641        if parse::is_option_type(&id_field.ty) {
642            return quote::quote! {
643                self.id.is_none()
644            };
645        }
646    }
647
648    // Default: cannot determine, always return true
649    quote::quote! { true }
650}
651
652/// Generate the model_config method body.
653fn generate_model_config(model: &ModelDef) -> proc_macro2::TokenStream {
654    let config = &model.config;
655
656    let table = config.table;
657    let from_attributes = config.from_attributes;
658    let validate_assignment = config.validate_assignment;
659    let strict = config.strict;
660    let populate_by_name = config.populate_by_name;
661    let use_enum_values = config.use_enum_values;
662    let arbitrary_types_allowed = config.arbitrary_types_allowed;
663    let defer_build = config.defer_build;
664    let revalidate_instances = config.revalidate_instances;
665
666    // Handle extra field behavior
667    let extra_token = match config.extra.as_str() {
668        "forbid" => quote::quote! { sqlmodel_core::ExtraFieldsBehavior::Forbid },
669        "allow" => quote::quote! { sqlmodel_core::ExtraFieldsBehavior::Allow },
670        _ => quote::quote! { sqlmodel_core::ExtraFieldsBehavior::Ignore },
671    };
672
673    // Handle optional string fields
674    let json_schema_extra_token = if let Some(ref extra) = config.json_schema_extra {
675        quote::quote! { Some(#extra) }
676    } else {
677        quote::quote! { None }
678    };
679
680    let title_token = if let Some(ref title) = config.title {
681        quote::quote! { Some(#title) }
682    } else {
683        quote::quote! { None }
684    };
685
686    quote::quote! {
687        sqlmodel_core::ModelConfig {
688            table: #table,
689            from_attributes: #from_attributes,
690            validate_assignment: #validate_assignment,
691            extra: #extra_token,
692            strict: #strict,
693            populate_by_name: #populate_by_name,
694            use_enum_values: #use_enum_values,
695            arbitrary_types_allowed: #arbitrary_types_allowed,
696            defer_build: #defer_build,
697            revalidate_instances: #revalidate_instances,
698            json_schema_extra: #json_schema_extra_token,
699            title: #title_token,
700        }
701    }
702}
703
704/// Generate the inheritance method body.
705fn generate_inheritance(model: &ModelDef) -> proc_macro2::TokenStream {
706    use crate::parse::InheritanceStrategy;
707
708    let config = &model.config;
709
710    // Determine the inheritance strategy token
711    let strategy_token = match config.inheritance {
712        InheritanceStrategy::None => {
713            quote::quote! { sqlmodel_core::InheritanceStrategy::None }
714        }
715        InheritanceStrategy::Single => {
716            quote::quote! { sqlmodel_core::InheritanceStrategy::Single }
717        }
718        InheritanceStrategy::Joined => {
719            quote::quote! { sqlmodel_core::InheritanceStrategy::Joined }
720        }
721        InheritanceStrategy::Concrete => {
722            quote::quote! { sqlmodel_core::InheritanceStrategy::Concrete }
723        }
724    };
725
726    // Handle parent model name
727    let parent_token = if let Some(ref parent) = config.inherits {
728        quote::quote! { Some(#parent) }
729    } else {
730        quote::quote! { None }
731    };
732
733    // Handle discriminator column (for base models)
734    let discriminator_column_token = if let Some(ref column) = config.discriminator_column {
735        quote::quote! { Some(#column) }
736    } else {
737        quote::quote! { None }
738    };
739
740    // Handle discriminator value (for child models)
741    let discriminator_value_token = if let Some(ref value) = config.discriminator_value {
742        quote::quote! { Some(#value) }
743    } else {
744        quote::quote! { None }
745    };
746
747    quote::quote! {
748        sqlmodel_core::InheritanceInfo {
749            strategy: #strategy_token,
750            parent: #parent_token,
751            discriminator_column: #discriminator_column_token,
752            discriminator_value: #discriminator_value_token,
753        }
754    }
755}
756
757/// Generate the shard_key constant and shard_key_value method body.
758///
759/// Returns a tuple of (const token, method body token) for:
760/// - `const SHARD_KEY: Option<&'static str>`
761/// - `fn shard_key_value(&self) -> Option<Value>`
762fn generate_shard_key(model: &ModelDef) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
763    let config = &model.config;
764
765    if let Some(ref shard_key_name) = config.shard_key {
766        // Find the shard key field to get its type info
767        let shard_field = model.fields.iter().find(|f| f.name == shard_key_name);
768
769        let const_token = quote::quote! { Some(#shard_key_name) };
770
771        // Generate the method body based on whether the field exists and its type
772        let value_body = if let Some(field) = shard_field {
773            let field_ident = &field.name;
774            if parse::is_option_type(&field.ty) {
775                // Option<T> field: return Some(value) if Some, None if None
776                quote::quote! {
777                    match &self.#field_ident {
778                        Some(v) => Some(sqlmodel_core::Value::from(v.clone())),
779                        None => None,
780                    }
781                }
782            } else {
783                // Non-optional field: always has a value
784                quote::quote! {
785                    Some(sqlmodel_core::Value::from(self.#field_ident.clone()))
786                }
787            }
788        } else {
789            // Field not found - this is a compile error in validation,
790            // but generate safe fallback code
791            quote::quote! { None }
792        };
793
794        (const_token, value_body)
795    } else {
796        // No shard key configured
797        let const_token = quote::quote! { None };
798        let value_body = quote::quote! { None };
799        (const_token, value_body)
800    }
801}
802
803/// Generate a custom Debug implementation if any field has repr=false.
804///
805/// This generates a Debug impl that excludes fields marked with `repr = false`,
806/// which is useful for hiding sensitive data like passwords from debug output.
807fn generate_debug_impl(model: &ModelDef) -> proc_macro2::TokenStream {
808    // Check if any field has repr=false
809    let has_hidden_fields = model.fields.iter().any(|f| !f.repr);
810
811    // Only generate custom Debug if there are hidden fields
812    if !has_hidden_fields {
813        return quote::quote! {};
814    }
815
816    let name = &model.name;
817    let (impl_generics, ty_generics, where_clause) = model.generics.split_for_impl();
818
819    // Generate field entries for Debug, excluding fields with repr=false
820    let debug_fields: Vec<_> = model
821        .fields
822        .iter()
823        .filter(|f| f.repr) // Only include fields with repr=true
824        .map(|f| {
825            let field_name = &f.name;
826            let field_name_str = field_name.to_string();
827            quote::quote! {
828                .field(#field_name_str, &self.#field_name)
829            }
830        })
831        .collect();
832
833    let struct_name_str = name.to_string();
834
835    quote::quote! {
836        impl #impl_generics ::core::fmt::Debug for #name #ty_generics #where_clause {
837            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
838                f.debug_struct(#struct_name_str)
839                    #(#debug_fields)*
840                    .finish()
841            }
842        }
843    }
844}
845
846/// Generate the RELATIONSHIPS constant from relationship fields.
847fn generate_relationships(model: &ModelDef) -> proc_macro2::TokenStream {
848    let relationship_fields = model.relationship_fields();
849
850    if relationship_fields.is_empty() {
851        return quote::quote! { &[] };
852    }
853
854    let mut relationship_tokens = Vec::new();
855
856    for field in relationship_fields {
857        let rel = field.relationship.as_ref().unwrap();
858        let field_name = &field.name;
859        let related_table = &rel.model;
860
861        // Determine RelationshipKind token
862        let kind_token = match rel.kind {
863            RelationshipKindAttr::OneToOne => {
864                quote::quote! { sqlmodel_core::RelationshipKind::OneToOne }
865            }
866            RelationshipKindAttr::ManyToOne => {
867                quote::quote! { sqlmodel_core::RelationshipKind::ManyToOne }
868            }
869            RelationshipKindAttr::OneToMany => {
870                quote::quote! { sqlmodel_core::RelationshipKind::OneToMany }
871            }
872            RelationshipKindAttr::ManyToMany => {
873                quote::quote! { sqlmodel_core::RelationshipKind::ManyToMany }
874            }
875        };
876
877        // Build optional method calls
878        let local_key_call = if let Some(ref fk) = rel.foreign_key {
879            quote::quote! { .local_key(#fk) }
880        } else {
881            quote::quote! {}
882        };
883
884        let remote_key_call = if let Some(ref rk) = rel.remote_key {
885            quote::quote! { .remote_key(#rk) }
886        } else {
887            quote::quote! {}
888        };
889
890        let back_populates_call = if let Some(ref bp) = rel.back_populates {
891            quote::quote! { .back_populates(#bp) }
892        } else {
893            quote::quote! {}
894        };
895
896        let link_table_call = if let Some(ref lt) = rel.link_table {
897            let table = &lt.table;
898            let local_col = &lt.local_column;
899            let remote_col = &lt.remote_column;
900            quote::quote! {
901                .link_table(sqlmodel_core::LinkTableInfo::new(#table, #local_col, #remote_col))
902            }
903        } else {
904            quote::quote! {}
905        };
906
907        let lazy_val = rel.lazy;
908        let cascade_val = rel.cascade_delete;
909        let passive_deletes_token = match rel.passive_deletes {
910            crate::parse::PassiveDeletesAttr::Active => {
911                quote::quote! { sqlmodel_core::PassiveDeletes::Active }
912            }
913            crate::parse::PassiveDeletesAttr::Passive => {
914                quote::quote! { sqlmodel_core::PassiveDeletes::Passive }
915            }
916            crate::parse::PassiveDeletesAttr::All => {
917                quote::quote! { sqlmodel_core::PassiveDeletes::All }
918            }
919        };
920
921        // New sa_relationship fields
922        let order_by_call = if let Some(ref ob) = rel.order_by {
923            quote::quote! { .order_by(#ob) }
924        } else {
925            quote::quote! {}
926        };
927
928        let lazy_strategy_call = if let Some(ref strategy) = rel.lazy_strategy {
929            let strategy_token = match strategy {
930                crate::parse::LazyLoadStrategyAttr::Select => {
931                    quote::quote! { sqlmodel_core::LazyLoadStrategy::Select }
932                }
933                crate::parse::LazyLoadStrategyAttr::Joined => {
934                    quote::quote! { sqlmodel_core::LazyLoadStrategy::Joined }
935                }
936                crate::parse::LazyLoadStrategyAttr::Subquery => {
937                    quote::quote! { sqlmodel_core::LazyLoadStrategy::Subquery }
938                }
939                crate::parse::LazyLoadStrategyAttr::Selectin => {
940                    quote::quote! { sqlmodel_core::LazyLoadStrategy::Selectin }
941                }
942                crate::parse::LazyLoadStrategyAttr::Dynamic => {
943                    quote::quote! { sqlmodel_core::LazyLoadStrategy::Dynamic }
944                }
945                crate::parse::LazyLoadStrategyAttr::NoLoad => {
946                    quote::quote! { sqlmodel_core::LazyLoadStrategy::NoLoad }
947                }
948                crate::parse::LazyLoadStrategyAttr::RaiseOnSql => {
949                    quote::quote! { sqlmodel_core::LazyLoadStrategy::RaiseOnSql }
950                }
951                crate::parse::LazyLoadStrategyAttr::WriteOnly => {
952                    quote::quote! { sqlmodel_core::LazyLoadStrategy::WriteOnly }
953                }
954            };
955            quote::quote! { .lazy_strategy(#strategy_token) }
956        } else {
957            quote::quote! {}
958        };
959
960        let cascade_call = if let Some(ref c) = rel.cascade {
961            quote::quote! { .cascade(#c) }
962        } else {
963            quote::quote! {}
964        };
965
966        let uselist_call = if let Some(ul) = rel.uselist {
967            quote::quote! { .uselist(#ul) }
968        } else {
969            quote::quote! {}
970        };
971
972        relationship_tokens.push(quote::quote! {
973            sqlmodel_core::RelationshipInfo::new(
974                stringify!(#field_name),
975                #related_table,
976                #kind_token
977            )
978            #local_key_call
979            #remote_key_call
980            #back_populates_call
981            #link_table_call
982            .lazy(#lazy_val)
983            .cascade_delete(#cascade_val)
984            .passive_deletes(#passive_deletes_token)
985            #order_by_call
986            #lazy_strategy_call
987            #cascade_call
988            #uselist_call
989        });
990    }
991
992    quote::quote! {
993        &[#(#relationship_tokens),*]
994    }
995}
996
997/// Derive macro for field validation.
998///
999/// Generates a `validate()` method that checks field constraints at runtime.
1000///
1001/// # Attributes
1002///
1003/// - `#[validate(min = N)]` - Minimum value for numbers
1004/// - `#[validate(max = N)]` - Maximum value for numbers
1005/// - `#[validate(min_length = N)]` - Minimum length for strings
1006/// - `#[validate(max_length = N)]` - Maximum length for strings
1007/// - `#[validate(pattern = "regex")]` - Regex pattern for strings
1008/// - `#[validate(email)]` - Email format validation
1009/// - `#[validate(url)]` - URL format validation
1010/// - `#[validate(required)]` - Mark an Option<T> field as required
1011/// - `#[validate(custom = "fn_name")]` - Custom validation function
1012///
1013/// # Example
1014///
1015/// ```ignore
1016/// use sqlmodel::Validate;
1017///
1018/// #[derive(Validate)]
1019/// struct User {
1020///     #[validate(min_length = 1, max_length = 100)]
1021///     name: String,
1022///
1023///     #[validate(min = 0, max = 150)]
1024///     age: i32,
1025///
1026///     #[validate(email)]
1027///     email: String,
1028///
1029///     #[validate(required)]
1030///     team_id: Option<i64>,
1031/// }
1032///
1033/// let user = User {
1034///     name: "".to_string(),
1035///     age: 200,
1036///     email: "invalid".to_string(),
1037///     team_id: None,
1038/// };
1039///
1040/// // Returns Err with all validation failures
1041/// let result = user.validate();
1042/// assert!(result.is_err());
1043/// ```
1044#[proc_macro_derive(Validate, attributes(validate))]
1045pub fn derive_validate(input: TokenStream) -> TokenStream {
1046    let input = syn::parse_macro_input!(input as syn::DeriveInput);
1047
1048    // Parse the struct and its validation attributes
1049    let def = match validate_derive::parse_validate(&input) {
1050        Ok(d) => d,
1051        Err(e) => return e.to_compile_error().into(),
1052    };
1053
1054    // Generate the validation implementation
1055    validate_derive::generate_validate_impl(&def).into()
1056}
1057
1058/// Derive macro for SQL enum types.
1059///
1060/// Generates `SqlEnum` trait implementation, `From<EnumType> for Value`,
1061/// `TryFrom<Value> for EnumType`, and `Display`/`FromStr` implementations.
1062///
1063/// Enum variants are mapped to their snake_case string representations by default.
1064/// Use `#[sqlmodel(rename = "custom_name")]` on variants to override.
1065///
1066/// # Example
1067///
1068/// ```ignore
1069/// #[derive(SqlEnum, Debug, Clone, PartialEq)]
1070/// enum Status {
1071///     Active,
1072///     Inactive,
1073///     #[sqlmodel(rename = "on_hold")]
1074///     OnHold,
1075/// }
1076/// ```
1077#[proc_macro_derive(SqlEnum, attributes(sqlmodel))]
1078pub fn derive_sql_enum(input: TokenStream) -> TokenStream {
1079    let input = syn::parse_macro_input!(input as syn::DeriveInput);
1080    match generate_sql_enum_impl(&input) {
1081        Ok(tokens) => tokens.into(),
1082        Err(e) => e.to_compile_error().into(),
1083    }
1084}
1085
1086fn generate_sql_enum_impl(input: &syn::DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
1087    let name = &input.ident;
1088    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1089
1090    let syn::Data::Enum(data) = &input.data else {
1091        return Err(syn::Error::new_spanned(
1092            input,
1093            "SqlEnum can only be derived for enums",
1094        ));
1095    };
1096
1097    // Collect variant info
1098    let mut variant_names = Vec::new();
1099    let mut variant_strings = Vec::new();
1100
1101    for variant in &data.variants {
1102        if !variant.fields.is_empty() {
1103            return Err(syn::Error::new_spanned(
1104                variant,
1105                "SqlEnum variants must be unit variants (no fields)",
1106            ));
1107        }
1108
1109        let ident = &variant.ident;
1110        variant_names.push(ident.clone());
1111
1112        // Check for #[sqlmodel(rename = "...")] attribute
1113        let mut custom_name = None;
1114        for attr in &variant.attrs {
1115            if attr.path().is_ident("sqlmodel") {
1116                attr.parse_nested_meta(|meta| {
1117                    if meta.path.is_ident("rename") {
1118                        let value = meta.value()?;
1119                        let s: syn::LitStr = value.parse()?;
1120                        custom_name = Some(s.value());
1121                    }
1122                    Ok(())
1123                })?;
1124            }
1125        }
1126
1127        let sql_str = custom_name.unwrap_or_else(|| to_snake_case(&ident.to_string()));
1128        variant_strings.push(sql_str);
1129    }
1130
1131    let type_name = to_snake_case(&name.to_string());
1132
1133    // Generate static VARIANTS array
1134    let variant_str_refs: Vec<_> = variant_strings.iter().map(|s| s.as_str()).collect();
1135
1136    let to_sql_arms: Vec<_> = variant_names
1137        .iter()
1138        .zip(variant_strings.iter())
1139        .map(|(ident, s)| {
1140            quote::quote! { #name::#ident => #s }
1141        })
1142        .collect();
1143
1144    let from_sql_arms: Vec<_> = variant_names
1145        .iter()
1146        .zip(variant_strings.iter())
1147        .map(|(ident, s)| {
1148            quote::quote! { #s => Ok(#name::#ident) }
1149        })
1150        .collect();
1151
1152    // Build the error message listing valid values
1153    let valid_values: String = variant_strings
1154        .iter()
1155        .map(|s| format!("'{}'", s))
1156        .collect::<Vec<_>>()
1157        .join(", ");
1158    let error_msg = format!(
1159        "invalid value for {}: expected one of {}",
1160        name, valid_values
1161    );
1162
1163    Ok(quote::quote! {
1164        impl #impl_generics sqlmodel_core::SqlEnum for #name #ty_generics #where_clause {
1165            const VARIANTS: &'static [&'static str] = &[#(#variant_str_refs),*];
1166            const TYPE_NAME: &'static str = #type_name;
1167
1168            fn to_sql_str(&self) -> &'static str {
1169                match self {
1170                    #(#to_sql_arms,)*
1171                }
1172            }
1173
1174            fn from_sql_str(s: &str) -> Result<Self, String> {
1175                match s {
1176                    #(#from_sql_arms,)*
1177                    _ => Err(format!("{}, got '{}'", #error_msg, s)),
1178                }
1179            }
1180        }
1181
1182        impl #impl_generics From<#name #ty_generics> for sqlmodel_core::Value #where_clause {
1183            fn from(v: #name #ty_generics) -> Self {
1184                sqlmodel_core::Value::Text(
1185                    sqlmodel_core::SqlEnum::to_sql_str(&v).to_string()
1186                )
1187            }
1188        }
1189
1190        impl #impl_generics From<&#name #ty_generics> for sqlmodel_core::Value #where_clause {
1191            fn from(v: &#name #ty_generics) -> Self {
1192                sqlmodel_core::Value::Text(
1193                    sqlmodel_core::SqlEnum::to_sql_str(v).to_string()
1194                )
1195            }
1196        }
1197
1198        impl #impl_generics TryFrom<sqlmodel_core::Value> for #name #ty_generics #where_clause {
1199            type Error = sqlmodel_core::Error;
1200
1201            fn try_from(value: sqlmodel_core::Value) -> Result<Self, Self::Error> {
1202                match value {
1203                    sqlmodel_core::Value::Text(ref s) => {
1204                        sqlmodel_core::SqlEnum::from_sql_str(s.as_str()).map_err(|e| {
1205                            sqlmodel_core::Error::Type(sqlmodel_core::error::TypeError {
1206                                expected: <#name as sqlmodel_core::SqlEnum>::TYPE_NAME,
1207                                actual: e,
1208                                column: None,
1209                                rust_type: None,
1210                            })
1211                        })
1212                    }
1213                    other => Err(sqlmodel_core::Error::Type(sqlmodel_core::error::TypeError {
1214                        expected: <#name as sqlmodel_core::SqlEnum>::TYPE_NAME,
1215                        actual: other.type_name().to_string(),
1216                        column: None,
1217                        rust_type: None,
1218                    })),
1219                }
1220            }
1221        }
1222
1223        impl #impl_generics ::core::fmt::Display for #name #ty_generics #where_clause {
1224            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
1225                f.write_str(sqlmodel_core::SqlEnum::to_sql_str(self))
1226            }
1227        }
1228
1229        impl #impl_generics ::core::str::FromStr for #name #ty_generics #where_clause {
1230            type Err = String;
1231
1232            fn from_str(s: &str) -> Result<Self, Self::Err> {
1233                sqlmodel_core::SqlEnum::from_sql_str(s)
1234            }
1235        }
1236    })
1237}
1238
/// Convert a CamelCase / PascalCase identifier into snake_case.
///
/// An underscore is inserted before an uppercase character (never at index 0) when:
/// - the previous character is lowercase (`FooBar` -> `foo_bar`), or
/// - the previous character is uppercase and the next is lowercase, which marks
///   the end of an acronym (`HTTPStatus` -> `http_status`).
///
/// All other characters are copied unchanged, so digits do not start a new word
/// (`Http2Status` -> `http2status`). Uppercase detection is Unicode-aware, but
/// only ASCII letters are lowercased; non-ASCII uppercase letters are kept as is.
fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut out = String::with_capacity(s.len() + 4);

    for (idx, &c) in chars.iter().enumerate() {
        if !c.is_uppercase() {
            out.push(c);
            continue;
        }

        let prev = idx.checked_sub(1).map(|p| chars[p]);
        let next = chars.get(idx + 1).copied();
        let needs_separator = match prev {
            None => false,
            Some(p) => {
                p.is_lowercase() || (p.is_uppercase() && next.is_some_and(char::is_lowercase))
            }
        };

        if needs_separator {
            out.push('_');
        }
        out.push(c.to_ascii_lowercase());
    }

    out
}
1260
1261/// Attribute macro for defining SQL functions in handlers.
1262///
1263/// # Example
1264///
1265/// ```ignore
1266/// #[sqlmodel::query]
1267/// async fn get_heroes(cx: &Cx, conn: &impl Connection) -> Vec<Hero> {
1268///     sqlmodel::select!(Hero).all(cx, conn).await
1269/// }
1270/// ```
1271#[proc_macro_attribute]
1272pub fn query(_attr: TokenStream, item: TokenStream) -> TokenStream {
1273    // Stub: query attribute macro is a pass-through placeholder for future SQL validation.
1274    // When implemented, it will provide compile-time SQL validation and query optimization hints.
1275    item
1276}