Skip to main content

sqlmodel_macros/
lib.rs

1//! Procedural macros for SQLModel Rust.
2//!
3//! `sqlmodel-macros` is the **compile-time codegen layer**. It turns Rust structs into
4//! fully described SQL models by generating static metadata and trait implementations.
5//!
6//! # Role In The Architecture
7//!
8//! - **Model metadata**: `#[derive(Model)]` produces a `Model` implementation with
9//!   table/column metadata consumed by query, schema, and session layers.
10//! - **Validation**: `#[derive(Validate)]` generates field validation glue.
11//! - **Schema export**: `#[derive(JsonSchema)]` enables JSON schema generation for
12//!   API documentation or tooling.
13//!
14//! These macros are used by application crates via the `sqlmodel` facade.
15
16use proc_macro::TokenStream;
17use syn::ext::IdentExt;
18
19mod infer;
20mod parse;
21mod validate;
22mod validate_derive;
23
24use parse::{InheritanceStrategy, ModelDef, RelationshipKindAttr, parse_model};
25
26/// Derive macro for the `Model` trait.
27///
28/// This macro generates implementations for:
29/// - Table name and primary key metadata
30/// - Field information
31/// - Row conversion (to_row, from_row)
32/// - Primary key access
33///
34/// # Attributes
35///
36/// - `#[sqlmodel(table = "name")]` - Override table name (defaults to snake_case struct name)
37/// - `#[sqlmodel(primary_key)]` - Mark field as primary key
38/// - `#[sqlmodel(auto_increment)]` - Mark field as auto-incrementing
39/// - `#[sqlmodel(column = "name")]` - Override column name
40/// - `#[sqlmodel(nullable)]` - Mark field as nullable
41/// - `#[sqlmodel(unique)]` - Add unique constraint
42/// - `#[sqlmodel(default = "expr")]` - Set default SQL expression
43/// - `#[sqlmodel(foreign_key = "table.column")]` - Add foreign key reference
44/// - `#[sqlmodel(index = "name")]` - Add to named index
45/// - `#[sqlmodel(skip)]` - Skip this field in database operations
46///
47/// # Example
48///
49/// ```ignore
50/// use sqlmodel::Model;
51///
52/// #[derive(Model)]
53/// #[sqlmodel(table = "heroes")]
54/// struct Hero {
55///     #[sqlmodel(primary_key, auto_increment)]
56///     id: Option<i64>,
57///
58///     #[sqlmodel(unique)]
59///     name: String,
60///
61///     secret_name: String,
62///
63///     #[sqlmodel(nullable)]
64///     age: Option<i32>,
65///
66///     #[sqlmodel(foreign_key = "teams.id")]
67///     team_id: Option<i64>,
68/// }
69/// ```
70#[proc_macro_derive(Model, attributes(sqlmodel))]
71pub fn derive_model(input: TokenStream) -> TokenStream {
72    let input = syn::parse_macro_input!(input as syn::DeriveInput);
73
74    // Parse the struct and its attributes
75    let model = match parse_model(&input) {
76        Ok(m) => m,
77        Err(e) => return e.to_compile_error().into(),
78    };
79
80    // Validate the parsed model
81    if let Err(e) = validate::validate_model(&model) {
82        return e.to_compile_error().into();
83    }
84
85    // Generate the Model implementation
86    generate_model_impl(&model).into()
87}
88
89/// Generate the Model trait implementation from parsed model definition.
90fn generate_model_impl(model: &ModelDef) -> proc_macro2::TokenStream {
91    let name = &model.name;
92    let table_name_lit = &model.table_name;
93    let (impl_generics, ty_generics, where_clause) = model.generics.split_for_impl();
94
95    // If this is a single-table-inheritance child (inherits + discriminator_value),
96    // its effective table is the parent table.
97    let table_name_ts =
98        if model.config.inherits.is_some() && model.config.discriminator_value.is_some() {
99            let parent = model
100                .config
101                .inherits
102                .as_deref()
103                .expect("inherits checked above");
104            let parent_ty_ts: proc_macro2::TokenStream =
105                if let Ok(path) = syn::parse_str::<syn::Path>(parent) {
106                    quote::quote! { #path }
107                } else {
108                    let ident = syn::Ident::new(parent, proc_macro2::Span::call_site());
109                    quote::quote! { #ident }
110                };
111            quote::quote! { <#parent_ty_ts as sqlmodel_core::Model>::TABLE_NAME }
112        } else {
113            quote::quote! { #table_name_lit }
114        };
115
116    // Collect primary key field names
117    let pk_fields: Vec<&str> = model
118        .primary_key_fields()
119        .iter()
120        .map(|f| f.column_name.as_str())
121        .collect();
122    let pk_field_names: Vec<_> = pk_fields.clone();
123
124    // If no explicit primary key, default to "id" if present
125    let pk_slice = if pk_field_names.is_empty() {
126        // Only default to "id" if an "id" field actually exists
127        let has_id_field = model.fields.iter().any(|f| f.name == "id" && !f.skip);
128        if has_id_field {
129            quote::quote! { &["id"] }
130        } else {
131            quote::quote! { &[] }
132        }
133    } else {
134        quote::quote! { &[#(#pk_field_names),*] }
135    };
136
137    // Generate static FieldInfo array for fields()
138    let field_infos = generate_field_infos(model);
139
140    // Generate RELATIONSHIPS constant
141    let relationships = generate_relationships(model);
142
143    // Generate to_row implementation
144    let to_row_body = generate_to_row(model);
145
146    // Generate from_row implementation
147    let from_row_body = generate_from_row(model);
148
149    // Generate primary_key_value implementation
150    let pk_value_body = generate_primary_key_value(model);
151
152    // Generate is_new implementation
153    let is_new_body = generate_is_new(model);
154
155    // Generate model_config implementation
156    let model_config_body = generate_model_config(model);
157
158    // Generate inheritance implementation
159    let inheritance_body = generate_inheritance(model);
160
161    // Generate shard_key implementation
162    let (shard_key_const, shard_key_value_body) = generate_shard_key(model);
163
164    // Generate joined-parent extraction for joined-table inheritance child models.
165    let joined_parent_row_body = generate_joined_parent_row(model);
166
167    // Generate Debug impl only if any field has repr=false
168    let debug_impl = generate_debug_impl(model);
169
170    // Generate hybrid property expr methods
171    let hybrid_impl = generate_hybrid_methods(model);
172
173    quote::quote! {
174        impl #impl_generics sqlmodel_core::Model for #name #ty_generics #where_clause {
175            const TABLE_NAME: &'static str = #table_name_ts;
176            const PRIMARY_KEY: &'static [&'static str] = #pk_slice;
177            const RELATIONSHIPS: &'static [sqlmodel_core::RelationshipInfo] = #relationships;
178            const SHARD_KEY: Option<&'static str> = #shard_key_const;
179
180            fn fields() -> &'static [sqlmodel_core::FieldInfo] {
181                static FIELDS: &[sqlmodel_core::FieldInfo] = &[
182                    #field_infos
183                ];
184                FIELDS
185            }
186
187            fn to_row(&self) -> Vec<(&'static str, sqlmodel_core::Value)> {
188                #to_row_body
189            }
190
191            fn from_row(row: &sqlmodel_core::Row) -> sqlmodel_core::Result<Self> {
192                #from_row_body
193            }
194
195            fn primary_key_value(&self) -> Vec<sqlmodel_core::Value> {
196                #pk_value_body
197            }
198
199            fn is_new(&self) -> bool {
200                #is_new_body
201            }
202
203            fn model_config() -> sqlmodel_core::ModelConfig {
204                #model_config_body
205            }
206
207            fn inheritance() -> sqlmodel_core::InheritanceInfo {
208                #inheritance_body
209            }
210
211            fn shard_key_value(&self) -> Option<sqlmodel_core::Value> {
212                #shard_key_value_body
213            }
214
215            #joined_parent_row_body
216        }
217
218        #debug_impl
219
220        #hybrid_impl
221    }
222}
223
224/// Generate associated functions for hybrid properties.
225///
226/// For each field with `#[sqlmodel(hybrid, sql = "...")]`, generates
227/// a `pub fn {field}_expr() -> sqlmodel_query::Expr` method that returns
228/// `Expr::raw(sql)`.
229fn generate_hybrid_methods(model: &ModelDef) -> proc_macro2::TokenStream {
230    let hybrid_fields: Vec<_> = model
231        .fields
232        .iter()
233        .filter_map(|f| {
234            if !f.hybrid {
235                return None;
236            }
237            let sql = f.hybrid_sql.as_deref()?;
238            Some((f, sql))
239        })
240        .collect();
241
242    if hybrid_fields.is_empty() {
243        return quote::quote! {};
244    }
245
246    let name = &model.name;
247    let (impl_generics, ty_generics, where_clause) = model.generics.split_for_impl();
248
249    let methods: Vec<_> = hybrid_fields
250        .iter()
251        .map(|(field, sql)| {
252            let method_name = quote::format_ident!("{}_expr", field.name);
253            let doc = format!(
254                "SQL expression for the `{}` hybrid property.\n\nGenerates: `{}`",
255                field.name, sql
256            );
257            quote::quote! {
258                #[doc = #doc]
259                pub fn #method_name() -> sqlmodel_query::Expr {
260                    sqlmodel_query::Expr::raw(#sql)
261                }
262            }
263        })
264        .collect();
265
266    quote::quote! {
267        impl #impl_generics #name #ty_generics #where_clause {
268            #(#methods)*
269        }
270    }
271}
272
273fn generate_joined_parent_row(model: &ModelDef) -> proc_macro2::TokenStream {
274    let is_joined_child =
275        model.config.inheritance == InheritanceStrategy::Joined && model.config.inherits.is_some();
276    if !is_joined_child {
277        return quote::quote! {};
278    }
279
280    let Some(parent_field) = model.fields.iter().find(|f| f.parent) else {
281        return quote::quote! {};
282    };
283    let parent_ident = &parent_field.name;
284
285    quote::quote! {
286        fn joined_parent_row(&self) -> Option<Vec<(&'static str, sqlmodel_core::Value)>> {
287            Some(self.#parent_ident.to_row())
288        }
289    }
290}
291
292/// Convert a referential action string to the corresponding token.
293fn referential_action_ts(action: &str) -> proc_macro2::TokenStream {
294    match action.to_uppercase().as_str() {
295        "NO ACTION" | "NOACTION" | "NO_ACTION" => {
296            quote::quote! { sqlmodel_core::ReferentialAction::NoAction }
297        }
298        "RESTRICT" => quote::quote! { sqlmodel_core::ReferentialAction::Restrict },
299        "CASCADE" => quote::quote! { sqlmodel_core::ReferentialAction::Cascade },
300        "SET NULL" | "SETNULL" | "SET_NULL" => {
301            quote::quote! { sqlmodel_core::ReferentialAction::SetNull }
302        }
303        "SET DEFAULT" | "SETDEFAULT" | "SET_DEFAULT" => {
304            quote::quote! { sqlmodel_core::ReferentialAction::SetDefault }
305        }
306        _ => quote::quote! { sqlmodel_core::ReferentialAction::NoAction },
307    }
308}
309
310/// Generate the static FieldInfo array contents.
311fn generate_field_infos(model: &ModelDef) -> proc_macro2::TokenStream {
312    let mut field_ts = Vec::new();
313
314    // Use data_fields() to include computed fields in metadata (needed for serialization)
315    for field in model.data_fields() {
316        let field_ident = field.name.unraw();
317        let column_name = &field.column_name;
318        let primary_key = field.primary_key;
319        let auto_increment = field.auto_increment;
320
321        // Check if sa_column override is present
322        let sa_col = field.sa_column.as_ref();
323
324        // Nullable: sa_column.nullable takes precedence over field.nullable
325        let nullable = sa_col.and_then(|sc| sc.nullable).unwrap_or(field.nullable);
326
327        // Unique: sa_column.unique takes precedence over field.unique
328        let unique = sa_col.and_then(|sc| sc.unique).unwrap_or(field.unique);
329
330        // Determine SQL type: sa_column.sql_type > field.sql_type > inferred
331        let effective_sql_type = sa_col
332            .and_then(|sc| sc.sql_type.as_ref())
333            .or(field.sql_type.as_ref());
334        let sql_type_ts = if let Some(sql_type_str) = effective_sql_type {
335            // Parse the explicit SQL type attribute string
336            infer::parse_sql_type_attr(sql_type_str)
337        } else {
338            // Infer from Rust type (handles primitives, Option<T>, common library types)
339            infer::infer_sql_type(&field.ty)
340        };
341
342        // If sql_type attribute was provided, also store the raw string as an override for DDL.
343        let sql_type_override_ts = if let Some(sql_type_str) = effective_sql_type {
344            quote::quote! { Some(#sql_type_str) }
345        } else {
346            quote::quote! { None }
347        };
348
349        // Default value: sa_column.server_default takes precedence over field.default
350        let effective_default = sa_col
351            .and_then(|sc| sc.server_default.as_ref())
352            .or(field.default.as_ref());
353        let default_ts = if let Some(d) = effective_default {
354            quote::quote! { Some(#d) }
355        } else {
356            quote::quote! { None }
357        };
358
359        // Foreign key (validation prevents use with sa_column, so field value is always used)
360        let fk_ts = if let Some(fk) = &field.foreign_key {
361            quote::quote! { Some(#fk) }
362        } else {
363            quote::quote! { None }
364        };
365
366        // Index: sa_column.index takes precedence over field.index
367        let effective_index = sa_col
368            .and_then(|sc| sc.index.as_ref())
369            .or(field.index.as_ref());
370        let index_ts = if let Some(idx) = effective_index {
371            quote::quote! { Some(#idx) }
372        } else {
373            quote::quote! { None }
374        };
375
376        // ON DELETE action
377        let on_delete_ts = if let Some(ref action) = field.on_delete {
378            let action_ts = referential_action_ts(action);
379            quote::quote! { Some(#action_ts) }
380        } else {
381            quote::quote! { None }
382        };
383
384        // ON UPDATE action
385        let on_update_ts = if let Some(ref action) = field.on_update {
386            let action_ts = referential_action_ts(action);
387            quote::quote! { Some(#action_ts) }
388        } else {
389            quote::quote! { None }
390        };
391
392        // Alias tokens
393        let alias_ts = if let Some(ref alias) = field.alias {
394            quote::quote! { Some(#alias) }
395        } else {
396            quote::quote! { None }
397        };
398
399        let validation_alias_ts = if let Some(ref val_alias) = field.validation_alias {
400            quote::quote! { Some(#val_alias) }
401        } else {
402            quote::quote! { None }
403        };
404
405        let serialization_alias_ts = if let Some(ref ser_alias) = field.serialization_alias {
406            quote::quote! { Some(#ser_alias) }
407        } else {
408            quote::quote! { None }
409        };
410
411        let computed = field.computed;
412        let exclude = field.exclude;
413
414        // Schema metadata tokens
415        let title_ts = if let Some(ref title) = field.title {
416            quote::quote! { Some(#title) }
417        } else {
418            quote::quote! { None }
419        };
420
421        let description_ts = if let Some(ref desc) = field.description {
422            quote::quote! { Some(#desc) }
423        } else {
424            quote::quote! { None }
425        };
426
427        let schema_extra_ts = if let Some(ref extra) = field.schema_extra {
428            quote::quote! { Some(#extra) }
429        } else {
430            quote::quote! { None }
431        };
432
433        // Default JSON for exclude_defaults support
434        let default_json_ts = if let Some(ref dj) = field.default_json {
435            quote::quote! { Some(#dj) }
436        } else {
437            quote::quote! { None }
438        };
439
440        // Const field
441        let const_field = field.const_field;
442
443        // Column constraints: sa_column.check is used if sa_column is present,
444        // otherwise field.column_constraints (validation prevents both being set)
445        let effective_constraints: Vec<&String> = if let Some(sc) = sa_col {
446            sc.check.iter().collect()
447        } else {
448            field.column_constraints.iter().collect()
449        };
450        let column_constraints_ts = if effective_constraints.is_empty() {
451            quote::quote! { &[] }
452        } else {
453            quote::quote! { &[#(#effective_constraints),*] }
454        };
455
456        // Column comment: sa_column.comment is used if sa_column is present,
457        // otherwise field.column_comment (validation prevents both being set)
458        let effective_comment = sa_col
459            .and_then(|sc| sc.comment.as_ref())
460            .or(field.column_comment.as_ref());
461        let column_comment_ts = if let Some(comment) = effective_comment {
462            quote::quote! { Some(#comment) }
463        } else {
464            quote::quote! { None }
465        };
466
467        // Column info
468        let column_info_ts = if let Some(ref info) = field.column_info {
469            quote::quote! { Some(#info) }
470        } else {
471            quote::quote! { None }
472        };
473
474        // Hybrid SQL expression
475        let hybrid_sql_ts = if let Some(ref sql) = field.hybrid_sql {
476            quote::quote! { Some(#sql) }
477        } else {
478            quote::quote! { None }
479        };
480
481        // Discriminator for union types
482        let discriminator_ts = if let Some(ref disc) = field.discriminator {
483            quote::quote! { Some(#disc) }
484        } else {
485            quote::quote! { None }
486        };
487
488        // Decimal precision (max_digits -> precision, decimal_places -> scale)
489        let precision_ts = if let Some(p) = field.max_digits {
490            quote::quote! { Some(#p) }
491        } else {
492            quote::quote! { None }
493        };
494
495        let scale_ts = if let Some(s) = field.decimal_places {
496            quote::quote! { Some(#s) }
497        } else {
498            quote::quote! { None }
499        };
500
501        field_ts.push(quote::quote! {
502            sqlmodel_core::FieldInfo::new(stringify!(#field_ident), #column_name, #sql_type_ts)
503                .sql_type_override_opt(#sql_type_override_ts)
504                .precision_opt(#precision_ts)
505                .scale_opt(#scale_ts)
506                .nullable(#nullable)
507                .primary_key(#primary_key)
508                .auto_increment(#auto_increment)
509                .unique(#unique)
510                .default_opt(#default_ts)
511                .foreign_key_opt(#fk_ts)
512                .on_delete_opt(#on_delete_ts)
513                .on_update_opt(#on_update_ts)
514                .index_opt(#index_ts)
515                .alias_opt(#alias_ts)
516                .validation_alias_opt(#validation_alias_ts)
517                .serialization_alias_opt(#serialization_alias_ts)
518                .computed(#computed)
519                .exclude(#exclude)
520                .title_opt(#title_ts)
521                .description_opt(#description_ts)
522                .schema_extra_opt(#schema_extra_ts)
523                .default_json_opt(#default_json_ts)
524                .const_field(#const_field)
525                .column_constraints(#column_constraints_ts)
526                .column_comment_opt(#column_comment_ts)
527                .column_info_opt(#column_info_ts)
528                .hybrid_sql_opt(#hybrid_sql_ts)
529                .discriminator_opt(#discriminator_ts)
530        });
531    }
532
533    quote::quote! { #(#field_ts),* }
534}
535
536/// Generate the to_row method body.
537fn generate_to_row(model: &ModelDef) -> proc_macro2::TokenStream {
538    let mut conversions = Vec::new();
539
540    for field in model.select_fields() {
541        let field_name = &field.name;
542        let column_name = &field.column_name;
543
544        // Convert field to Value
545        if parse::is_option_type(&field.ty) {
546            conversions.push(quote::quote! {
547                (#column_name, match &self.#field_name {
548                    Some(v) => sqlmodel_core::Value::from(v.clone()),
549                    None => sqlmodel_core::Value::Null,
550                })
551            });
552        } else {
553            conversions.push(quote::quote! {
554                (#column_name, sqlmodel_core::Value::from(self.#field_name.clone()))
555            });
556        }
557    }
558
559    quote::quote! {
560        let mut out = vec![#(#conversions),*];
561
562        // Single-table inheritance child models should always emit their discriminator
563        // so inserts/updates can round-trip correctly even if the struct doesn't have a
564        // dedicated discriminator field.
565        let inh = <Self as sqlmodel_core::Model>::inheritance();
566        if let (Some(col), Some(val)) = (inh.discriminator_column, inh.discriminator_value) {
567            if !out.iter().any(|(c, _)| *c == col) {
568                out.push((col, sqlmodel_core::Value::from(val)));
569            }
570        }
571
572        out
573    }
574}
575
576/// Generate the from_row method body.
577fn generate_from_row(model: &ModelDef) -> proc_macro2::TokenStream {
578    let name = &model.name;
579    let mut field_extractions = Vec::new();
580
581    // Support both "plain" rows (SELECT *) and prefixed/aliased rows (e.g. eager loading,
582    // joined inheritance) by looking for `table__col` prefixes.
583    let row_ident = quote::format_ident!("local_row");
584
585    for field in model.select_fields() {
586        let field_name = &field.name;
587        let column_name = &field.column_name;
588
589        if parse::is_option_type(&field.ty) {
590            // For Option<T> fields, handle NULL gracefully
591            field_extractions.push(quote::quote! {
592                #field_name: #row_ident.get_named(#column_name).ok()
593            });
594        } else {
595            // For required fields, propagate errors
596            field_extractions.push(quote::quote! {
597                #field_name: #row_ident.get_named(#column_name)?
598            });
599        }
600    }
601
602    // Handle skipped fields with Default
603    let skipped_fields: Vec<_> = model
604        .fields
605        .iter()
606        .filter(|f| f.skip)
607        .map(|f| {
608            let field_name = &f.name;
609            quote::quote! { #field_name: Default::default() }
610        })
611        .collect();
612
613    // Handle relationship fields with Default (they're not in the DB row)
614    let relationship_fields: Vec<_> = model
615        .fields
616        .iter()
617        .filter(|f| f.relationship.is_some())
618        .map(|f| {
619            let field_name = &f.name;
620            quote::quote! { #field_name: Default::default() }
621        })
622        .collect();
623
624    // Joined-table inheritance parent field hydration.
625    let parent_fields: Vec<_> = model
626        .fields
627        .iter()
628        .filter(|f| f.parent)
629        .map(|f| {
630            let field_name = &f.name;
631            let ty = &f.ty;
632            quote::quote! {
633                #field_name: {
634                    let inh = <Self as sqlmodel_core::Model>::inheritance();
635                    let parent_table = inh.parent.ok_or_else(|| {
636                        sqlmodel_core::Error::Custom(
637                            "joined inheritance parent_table missing in inheritance metadata".to_string(),
638                        )
639                    })?;
640                    if !row.has_prefix(parent_table) {
641                        return Err(sqlmodel_core::Error::Custom(format!(
642                            "expected prefixed parent columns for joined inheritance: {}__*",
643                            parent_table
644                        )));
645                    }
646                    let prow = row.subset_by_prefix(parent_table);
647                    <#ty as sqlmodel_core::Model>::from_row(&prow)?
648                }
649            }
650        })
651        .collect();
652
653    // Handle computed fields with Default (they're not in the DB row)
654    let computed_fields: Vec<_> = model
655        .computed_fields()
656        .iter()
657        .map(|f| {
658            let field_name = &f.name;
659            quote::quote! { #field_name: Default::default() }
660        })
661        .collect();
662
663    quote::quote! {
664        let #row_ident = if row.has_prefix(<Self as sqlmodel_core::Model>::TABLE_NAME) {
665            row.subset_by_prefix(<Self as sqlmodel_core::Model>::TABLE_NAME)
666        } else {
667            row.clone()
668        };
669
670        Ok(#name {
671            #(#field_extractions,)*
672            #(#skipped_fields,)*
673            #(#relationship_fields,)*
674            #(#parent_fields,)*
675            #(#computed_fields,)*
676        })
677    }
678}
679
680/// Generate the primary_key_value method body.
681fn generate_primary_key_value(model: &ModelDef) -> proc_macro2::TokenStream {
682    let pk_fields = model.primary_key_fields();
683
684    if pk_fields.is_empty() {
685        // Try to use "id" field if it exists
686        let id_field = model.fields.iter().find(|f| f.name == "id");
687        if let Some(field) = id_field {
688            let field_name = &field.name;
689            if parse::is_option_type(&field.ty) {
690                return quote::quote! {
691                    match &self.#field_name {
692                        Some(v) => vec![sqlmodel_core::Value::from(v.clone())],
693                        None => vec![sqlmodel_core::Value::Null],
694                    }
695                };
696            }
697            return quote::quote! {
698                vec![sqlmodel_core::Value::from(self.#field_name.clone())]
699            };
700        }
701        return quote::quote! { vec![] };
702    }
703
704    let mut value_exprs = Vec::new();
705    for field in pk_fields {
706        let field_name = &field.name;
707        if parse::is_option_type(&field.ty) {
708            value_exprs.push(quote::quote! {
709                match &self.#field_name {
710                    Some(v) => sqlmodel_core::Value::from(v.clone()),
711                    None => sqlmodel_core::Value::Null,
712                }
713            });
714        } else {
715            value_exprs.push(quote::quote! {
716                sqlmodel_core::Value::from(self.#field_name.clone())
717            });
718        }
719    }
720
721    quote::quote! {
722        vec![#(#value_exprs),*]
723    }
724}
725
726/// Generate the is_new method body.
727fn generate_is_new(model: &ModelDef) -> proc_macro2::TokenStream {
728    let pk_fields = model.primary_key_fields();
729
730    // If there's an auto_increment primary key field that is Option<T>,
731    // check if it's None
732    for field in &pk_fields {
733        if field.auto_increment && parse::is_option_type(&field.ty) {
734            let field_name = &field.name;
735            return quote::quote! {
736                self.#field_name.is_none()
737            };
738        }
739    }
740
741    // Otherwise, try "id" field if it exists and is Option<T>
742    if let Some(id_field) = model.fields.iter().find(|f| f.name == "id") {
743        if parse::is_option_type(&id_field.ty) {
744            return quote::quote! {
745                self.id.is_none()
746            };
747        }
748    }
749
750    // Default: cannot determine, always return true
751    quote::quote! { true }
752}
753
754/// Generate the model_config method body.
755fn generate_model_config(model: &ModelDef) -> proc_macro2::TokenStream {
756    let config = &model.config;
757
758    let table = config.table;
759    let from_attributes = config.from_attributes;
760    let validate_assignment = config.validate_assignment;
761    let strict = config.strict;
762    let populate_by_name = config.populate_by_name;
763    let use_enum_values = config.use_enum_values;
764    let arbitrary_types_allowed = config.arbitrary_types_allowed;
765    let defer_build = config.defer_build;
766    let revalidate_instances = config.revalidate_instances;
767
768    // Handle extra field behavior
769    let extra_ts = match config.extra.as_str() {
770        "forbid" => quote::quote! { sqlmodel_core::ExtraFieldsBehavior::Forbid },
771        "allow" => quote::quote! { sqlmodel_core::ExtraFieldsBehavior::Allow },
772        _ => quote::quote! { sqlmodel_core::ExtraFieldsBehavior::Ignore },
773    };
774
775    // Handle optional string fields
776    let json_schema_extra_ts = if let Some(ref extra) = config.json_schema_extra {
777        quote::quote! { Some(#extra) }
778    } else {
779        quote::quote! { None }
780    };
781
782    let title_ts = if let Some(ref title) = config.title {
783        quote::quote! { Some(#title) }
784    } else {
785        quote::quote! { None }
786    };
787
788    quote::quote! {
789        sqlmodel_core::ModelConfig {
790            table: #table,
791            from_attributes: #from_attributes,
792            validate_assignment: #validate_assignment,
793            extra: #extra_ts,
794            strict: #strict,
795            populate_by_name: #populate_by_name,
796            use_enum_values: #use_enum_values,
797            arbitrary_types_allowed: #arbitrary_types_allowed,
798            defer_build: #defer_build,
799            revalidate_instances: #revalidate_instances,
800            json_schema_extra: #json_schema_extra_ts,
801            title: #title_ts,
802        }
803    }
804}
805
806/// Generate the inheritance method body.
807fn generate_inheritance(model: &ModelDef) -> proc_macro2::TokenStream {
808    use crate::parse::InheritanceStrategy;
809
810    let config = &model.config;
811
812    // Determine the inheritance strategy token
813    let strategy_ts = match config.inheritance {
814        InheritanceStrategy::None => {
815            quote::quote! { sqlmodel_core::InheritanceStrategy::None }
816        }
817        InheritanceStrategy::Single => {
818            quote::quote! { sqlmodel_core::InheritanceStrategy::Single }
819        }
820        InheritanceStrategy::Joined => {
821            quote::quote! { sqlmodel_core::InheritanceStrategy::Joined }
822        }
823        InheritanceStrategy::Concrete => {
824            quote::quote! { sqlmodel_core::InheritanceStrategy::Concrete }
825        }
826    };
827
828    // Helper: interpret `inherits = "..."` as a Rust type path in the current scope.
829    // We keep it as a string in parsing so attribute syntax stays simple; here we
830    // translate it into type tokens for codegen.
831    let parent_ty_ts: Option<proc_macro2::TokenStream> = config.inherits.as_deref().map(|p| {
832        if let Ok(path) = syn::parse_str::<syn::Path>(p) {
833            quote::quote! { #path }
834        } else {
835            let ident = syn::Ident::new(p, proc_macro2::Span::call_site());
836            quote::quote! { #ident }
837        }
838    });
839
840    // Store parent table name (not parent Rust type name) in metadata so schema/DDL can be correct.
841    let parent_table_ts = if let Some(ref parent_ty) = parent_ty_ts {
842        quote::quote! { Some(<#parent_ty as sqlmodel_core::Model>::TABLE_NAME) }
843    } else {
844        quote::quote! { None }
845    };
846
847    let parent_fields_fn_ts = if let Some(ref parent_ty) = parent_ty_ts {
848        quote::quote! { Some(<#parent_ty as sqlmodel_core::Model>::fields) }
849    } else {
850        quote::quote! { None }
851    };
852
853    // Handle discriminator column.
854    //
855    // - Base STI models specify it explicitly: `#[sqlmodel(inheritance="single", discriminator="...")]`
856    // - Child STI models inherit it from the parent so query/schema tooling has the column name
857    //   available without requiring duplicate annotation.
858    let discriminator_column_ts = if let Some(ref column) = config.discriminator_column {
859        quote::quote! { Some(#column) }
860    } else if config.discriminator_value.is_some() {
861        if let Some(parent_ty) = parent_ty_ts.as_ref() {
862            quote::quote! { <#parent_ty as sqlmodel_core::Model>::inheritance().discriminator_column }
863        } else {
864            quote::quote! { None }
865        }
866    } else {
867        quote::quote! { None }
868    };
869
870    // Handle discriminator value (for child models)
871    let discriminator_value_ts = if let Some(ref value) = config.discriminator_value {
872        quote::quote! { Some(#value) }
873    } else {
874        quote::quote! { None }
875    };
876
877    quote::quote! {
878        sqlmodel_core::InheritanceInfo {
879            strategy: #strategy_ts,
880            parent: #parent_table_ts,
881            parent_fields_fn: #parent_fields_fn_ts,
882            discriminator_column: #discriminator_column_ts,
883            discriminator_value: #discriminator_value_ts,
884        }
885    }
886}
887
888/// Generate the shard_key constant and shard_key_value method body.
889///
890/// Returns a tuple of (const token, method body token) for:
891/// - `const SHARD_KEY: Option<&'static str>`
892/// - `fn shard_key_value(&self) -> Option<Value>`
893fn generate_shard_key(model: &ModelDef) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
894    let config = &model.config;
895
896    if let Some(ref shard_key_name) = config.shard_key {
897        // Find the shard key field to get its type info
898        let shard_field = model.fields.iter().find(|f| f.name == shard_key_name);
899
900        let const_ts = quote::quote! { Some(#shard_key_name) };
901
902        // Generate the method body based on whether the field exists and its type
903        let value_body = if let Some(field) = shard_field {
904            let field_ident = &field.name;
905            if parse::is_option_type(&field.ty) {
906                // Option<T> field: return Some(value) if Some, None if None
907                quote::quote! {
908                    match &self.#field_ident {
909                        Some(v) => Some(sqlmodel_core::Value::from(v.clone())),
910                        None => None,
911                    }
912                }
913            } else {
914                // Non-optional field: always has a value
915                quote::quote! {
916                    Some(sqlmodel_core::Value::from(self.#field_ident.clone()))
917                }
918            }
919        } else {
920            // Field not found - this is a compile error in validation,
921            // but generate safe fallback code
922            quote::quote! { None }
923        };
924
925        (const_ts, value_body)
926    } else {
927        // No shard key configured
928        let const_ts = quote::quote! { None };
929        let value_body = quote::quote! { None };
930        (const_ts, value_body)
931    }
932}
933
934/// Generate a custom Debug implementation if any field has repr=false.
935///
936/// This generates a Debug impl that excludes fields marked with `repr = false`,
937/// which is useful for hiding sensitive data like passwords from debug output.
938fn generate_debug_impl(model: &ModelDef) -> proc_macro2::TokenStream {
939    // Check if any field has repr=false
940    let has_hidden_fields = model.fields.iter().any(|f| !f.repr);
941
942    // Only generate custom Debug if there are hidden fields
943    if !has_hidden_fields {
944        return quote::quote! {};
945    }
946
947    let name = &model.name;
948    let (impl_generics, ty_generics, where_clause) = model.generics.split_for_impl();
949
950    // Generate field entries for Debug, excluding fields with repr=false
951    let debug_fields: Vec<_> = model
952        .fields
953        .iter()
954        .filter(|f| f.repr) // Only include fields with repr=true
955        .map(|f| {
956            let field_name = &f.name;
957            let field_name_str = field_name.to_string();
958            quote::quote! {
959                .field(#field_name_str, &self.#field_name)
960            }
961        })
962        .collect();
963
964    let struct_name_str = name.to_string();
965
966    quote::quote! {
967        impl #impl_generics ::core::fmt::Debug for #name #ty_generics #where_clause {
968            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
969                f.debug_struct(#struct_name_str)
970                    #(#debug_fields)*
971                    .finish()
972            }
973        }
974    }
975}
976
977/// Generate the RELATIONSHIPS constant from relationship fields.
978fn generate_relationships(model: &ModelDef) -> proc_macro2::TokenStream {
979    fn relationship_inner_model_ty(ty: &syn::Type) -> Option<syn::Type> {
980        let syn::Type::Path(tp) = ty else {
981            return None;
982        };
983
984        let last = tp.path.segments.last()?;
985        let ident = last.ident.to_string();
986        if ident != "Related" && ident != "RelatedMany" && ident != "Lazy" {
987            return None;
988        }
989
990        let syn::PathArguments::AngleBracketed(args) = &last.arguments else {
991            return None;
992        };
993
994        args.args.iter().find_map(|arg| match arg {
995            syn::GenericArgument::Type(t) => Some(t.clone()),
996            _ => None,
997        })
998    }
999
1000    let relationship_fields = model.relationship_fields();
1001
1002    if relationship_fields.is_empty() {
1003        return quote::quote! { &[] };
1004    }
1005
1006    let mut relationship_ts = Vec::new();
1007
1008    for field in relationship_fields {
1009        let Some(rel) = field.relationship.as_ref() else {
1010            relationship_ts.push(quote::quote! {
1011                ::core::compile_error!(
1012                    "sqlmodel: internal error: relationship field missing parsed relationship metadata"
1013                )
1014            });
1015            continue;
1016        };
1017        let field_name = &field.name;
1018        let related_table = &rel.model;
1019
1020        let Some(related_ty) = relationship_inner_model_ty(&field.ty) else {
1021            relationship_ts.push(quote::quote! {
1022                ::core::compile_error!(
1023                    "sqlmodel: relationship field type must be Related<T>, RelatedMany<T>, or Lazy<T>"
1024                )
1025            });
1026            continue;
1027        };
1028
1029        // Determine RelationshipKind token
1030        let kind_ts = match rel.kind {
1031            RelationshipKindAttr::OneToOne => {
1032                quote::quote! { sqlmodel_core::RelationshipKind::OneToOne }
1033            }
1034            RelationshipKindAttr::ManyToOne => {
1035                quote::quote! { sqlmodel_core::RelationshipKind::ManyToOne }
1036            }
1037            RelationshipKindAttr::OneToMany => {
1038                quote::quote! { sqlmodel_core::RelationshipKind::OneToMany }
1039            }
1040            RelationshipKindAttr::ManyToMany => {
1041                quote::quote! { sqlmodel_core::RelationshipKind::ManyToMany }
1042            }
1043        };
1044
1045        // Build optional method calls
1046        let local_key_call = if let Some(ref fk) = rel.foreign_key {
1047            quote::quote! { .local_key(#fk) }
1048        } else {
1049            quote::quote! {}
1050        };
1051
1052        let remote_key_call = if let Some(ref rk) = rel.remote_key {
1053            quote::quote! { .remote_key(#rk) }
1054        } else {
1055            quote::quote! {}
1056        };
1057
1058        let back_populates_call = if let Some(ref bp) = rel.back_populates {
1059            quote::quote! { .back_populates(#bp) }
1060        } else {
1061            quote::quote! {}
1062        };
1063
1064        let link_table_call = if let Some(ref lt) = rel.link_table {
1065            let table = &lt.table;
1066            let local_col = &lt.local_column;
1067            let remote_col = &lt.remote_column;
1068            quote::quote! {
1069                .link_table(sqlmodel_core::LinkTableInfo::new(#table, #local_col, #remote_col))
1070            }
1071        } else {
1072            quote::quote! {}
1073        };
1074
1075        let lazy_val = rel.lazy;
1076        let cascade_val = rel.cascade_delete;
1077        let passive_deletes_ts = match rel.passive_deletes {
1078            crate::parse::PassiveDeletesAttr::Active => {
1079                quote::quote! { sqlmodel_core::PassiveDeletes::Active }
1080            }
1081            crate::parse::PassiveDeletesAttr::Passive => {
1082                quote::quote! { sqlmodel_core::PassiveDeletes::Passive }
1083            }
1084            crate::parse::PassiveDeletesAttr::All => {
1085                quote::quote! { sqlmodel_core::PassiveDeletes::All }
1086            }
1087        };
1088
1089        // New sa_relationship fields
1090        let order_by_call = if let Some(ref ob) = rel.order_by {
1091            quote::quote! { .order_by(#ob) }
1092        } else {
1093            quote::quote! {}
1094        };
1095
1096        let lazy_strategy_call = if let Some(ref strategy) = rel.lazy_strategy {
1097            let strategy_ts = match strategy {
1098                crate::parse::LazyLoadStrategyAttr::Select => {
1099                    quote::quote! { sqlmodel_core::LazyLoadStrategy::Select }
1100                }
1101                crate::parse::LazyLoadStrategyAttr::Joined => {
1102                    quote::quote! { sqlmodel_core::LazyLoadStrategy::Joined }
1103                }
1104                crate::parse::LazyLoadStrategyAttr::Subquery => {
1105                    quote::quote! { sqlmodel_core::LazyLoadStrategy::Subquery }
1106                }
1107                crate::parse::LazyLoadStrategyAttr::Selectin => {
1108                    quote::quote! { sqlmodel_core::LazyLoadStrategy::Selectin }
1109                }
1110                crate::parse::LazyLoadStrategyAttr::Dynamic => {
1111                    quote::quote! { sqlmodel_core::LazyLoadStrategy::Dynamic }
1112                }
1113                crate::parse::LazyLoadStrategyAttr::NoLoad => {
1114                    quote::quote! { sqlmodel_core::LazyLoadStrategy::NoLoad }
1115                }
1116                crate::parse::LazyLoadStrategyAttr::RaiseOnSql => {
1117                    quote::quote! { sqlmodel_core::LazyLoadStrategy::RaiseOnSql }
1118                }
1119                crate::parse::LazyLoadStrategyAttr::WriteOnly => {
1120                    quote::quote! { sqlmodel_core::LazyLoadStrategy::WriteOnly }
1121                }
1122            };
1123            quote::quote! { .lazy_strategy(#strategy_ts) }
1124        } else {
1125            quote::quote! {}
1126        };
1127
1128        let cascade_call = if let Some(ref c) = rel.cascade {
1129            quote::quote! { .cascade(#c) }
1130        } else {
1131            quote::quote! {}
1132        };
1133
1134        let uselist_call = if let Some(ul) = rel.uselist {
1135            quote::quote! { .uselist(#ul) }
1136        } else {
1137            quote::quote! {}
1138        };
1139
1140        relationship_ts.push(quote::quote! {
1141            sqlmodel_core::RelationshipInfo::new(
1142                stringify!(#field_name),
1143                #related_table,
1144                #kind_ts
1145            )
1146            .related_fields(<#related_ty as sqlmodel_core::Model>::fields)
1147            #local_key_call
1148            #remote_key_call
1149            #back_populates_call
1150            #link_table_call
1151            .lazy(#lazy_val)
1152            .cascade_delete(#cascade_val)
1153            .passive_deletes(#passive_deletes_ts)
1154            #order_by_call
1155            #lazy_strategy_call
1156            #cascade_call
1157            #uselist_call
1158        });
1159    }
1160
1161    quote::quote! {
1162        &[#(#relationship_ts),*]
1163    }
1164}
1165
1166/// Derive macro for field validation.
1167///
1168/// Generates a `validate()` method that checks field constraints at runtime.
1169///
1170/// # Attributes
1171///
1172/// - `#[validate(min = N)]` - Minimum value for numbers
1173/// - `#[validate(max = N)]` - Maximum value for numbers
1174/// - `#[validate(min_length = N)]` - Minimum length for strings
1175/// - `#[validate(max_length = N)]` - Maximum length for strings
1176/// - `#[validate(pattern = "regex")]` - Regex pattern for strings
1177/// - `#[validate(email)]` - Email format validation
1178/// - `#[validate(url)]` - URL format validation
1179/// - `#[validate(required)]` - Mark an Option<T> field as required
1180/// - `#[validate(custom = "fn_name")]` - Custom validation function
1181///
1182/// # Example
1183///
1184/// ```ignore
1185/// use sqlmodel::Validate;
1186///
1187/// #[derive(Validate)]
1188/// struct User {
1189///     #[validate(min_length = 1, max_length = 100)]
1190///     name: String,
1191///
1192///     #[validate(min = 0, max = 150)]
1193///     age: i32,
1194///
1195///     #[validate(email)]
1196///     email: String,
1197///
1198///     #[validate(required)]
1199///     team_id: Option<i64>,
1200/// }
1201///
1202/// let user = User {
1203///     name: "".to_string(),
1204///     age: 200,
1205///     email: "invalid".to_string(),
1206///     team_id: None,
1207/// };
1208///
1209/// // Returns Err with all validation failures
1210/// let result = user.validate();
1211/// assert!(result.is_err());
1212/// ```
1213#[proc_macro_derive(Validate, attributes(validate))]
1214pub fn derive_validate(input: TokenStream) -> TokenStream {
1215    let input = syn::parse_macro_input!(input as syn::DeriveInput);
1216
1217    // Parse the struct and its validation attributes
1218    let def = match validate_derive::parse_validate(&input) {
1219        Ok(d) => d,
1220        Err(e) => return e.to_compile_error().into(),
1221    };
1222
1223    // Generate the validation implementation
1224    validate_derive::generate_validate_impl(&def).into()
1225}
1226
1227/// Derive macro for SQL enum types.
1228///
1229/// Generates `SqlEnum` trait implementation, `From<EnumType> for Value`,
1230/// `TryFrom<Value> for EnumType`, and `Display`/`FromStr` implementations.
1231///
1232/// Enum variants are mapped to their snake_case string representations by default.
1233/// Use `#[sqlmodel(rename = "custom_name")]` on variants to override.
1234///
1235/// # Example
1236///
1237/// ```ignore
1238/// #[derive(SqlEnum, Debug, Clone, PartialEq)]
1239/// enum Status {
1240///     Active,
1241///     Inactive,
1242///     #[sqlmodel(rename = "on_hold")]
1243///     OnHold,
1244/// }
1245/// ```
1246#[proc_macro_derive(SqlEnum, attributes(sqlmodel))]
1247pub fn derive_sql_enum(input: TokenStream) -> TokenStream {
1248    let input = syn::parse_macro_input!(input as syn::DeriveInput);
1249    match generate_sql_enum_impl(&input) {
1250        Ok(tokens) => tokens.into(),
1251        Err(e) => e.to_compile_error().into(),
1252    }
1253}
1254
1255fn generate_sql_enum_impl(input: &syn::DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
1256    let name = &input.ident;
1257    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
1258
1259    let syn::Data::Enum(data) = &input.data else {
1260        return Err(syn::Error::new_spanned(
1261            input,
1262            "SqlEnum can only be derived for enums",
1263        ));
1264    };
1265
1266    // Collect variant info
1267    let mut variant_names = Vec::new();
1268    let mut variant_strings = Vec::new();
1269
1270    for variant in &data.variants {
1271        if !variant.fields.is_empty() {
1272            return Err(syn::Error::new_spanned(
1273                variant,
1274                "SqlEnum variants must be unit variants (no fields)",
1275            ));
1276        }
1277
1278        let ident = &variant.ident;
1279        variant_names.push(ident.clone());
1280
1281        // Check for #[sqlmodel(rename = "...")] attribute
1282        let mut custom_name = None;
1283        for attr in &variant.attrs {
1284            if attr.path().is_ident("sqlmodel") {
1285                attr.parse_nested_meta(|meta| {
1286                    if meta.path.is_ident("rename") {
1287                        let value = meta.value()?;
1288                        let s: syn::LitStr = value.parse()?;
1289                        custom_name = Some(s.value());
1290                    }
1291                    Ok(())
1292                })?;
1293            }
1294        }
1295
1296        let sql_str = custom_name.unwrap_or_else(|| to_snake_case(&ident.to_string()));
1297        variant_strings.push(sql_str);
1298    }
1299
1300    let type_name = to_snake_case(&name.to_string());
1301
1302    // Generate static VARIANTS array
1303    let variant_str_refs: Vec<_> = variant_strings.iter().map(|s| s.as_str()).collect();
1304
1305    let to_sql_arms: Vec<_> = variant_names
1306        .iter()
1307        .zip(variant_strings.iter())
1308        .map(|(ident, s)| {
1309            quote::quote! { #name::#ident => #s }
1310        })
1311        .collect();
1312
1313    let from_sql_arms: Vec<_> = variant_names
1314        .iter()
1315        .zip(variant_strings.iter())
1316        .map(|(ident, s)| {
1317            quote::quote! { #s => Ok(#name::#ident) }
1318        })
1319        .collect();
1320
1321    // Build the error message listing valid values
1322    let valid_values: String = variant_strings
1323        .iter()
1324        .map(|s| format!("'{}'", s))
1325        .collect::<Vec<_>>()
1326        .join(", ");
1327    let error_msg = format!(
1328        "invalid value for {}: expected one of {}",
1329        name, valid_values
1330    );
1331
1332    Ok(quote::quote! {
1333        impl #impl_generics sqlmodel_core::SqlEnum for #name #ty_generics #where_clause {
1334            const VARIANTS: &'static [&'static str] = &[#(#variant_str_refs),*];
1335            const TYPE_NAME: &'static str = #type_name;
1336
1337            fn to_sql_str(&self) -> &'static str {
1338                match self {
1339                    #(#to_sql_arms,)*
1340                }
1341            }
1342
1343            fn from_sql_str(s: &str) -> Result<Self, String> {
1344                match s {
1345                    #(#from_sql_arms,)*
1346                    _ => Err(format!("{}, got '{}'", #error_msg, s)),
1347                }
1348            }
1349        }
1350
1351        impl #impl_generics From<#name #ty_generics> for sqlmodel_core::Value #where_clause {
1352            fn from(v: #name #ty_generics) -> Self {
1353                sqlmodel_core::Value::Text(
1354                    sqlmodel_core::SqlEnum::to_sql_str(&v).to_string()
1355                )
1356            }
1357        }
1358
1359        impl #impl_generics From<&#name #ty_generics> for sqlmodel_core::Value #where_clause {
1360            fn from(v: &#name #ty_generics) -> Self {
1361                sqlmodel_core::Value::Text(
1362                    sqlmodel_core::SqlEnum::to_sql_str(v).to_string()
1363                )
1364            }
1365        }
1366
1367        impl #impl_generics TryFrom<sqlmodel_core::Value> for #name #ty_generics #where_clause {
1368            type Error = sqlmodel_core::Error;
1369
1370            fn try_from(value: sqlmodel_core::Value) -> Result<Self, Self::Error> {
1371                match value {
1372                    sqlmodel_core::Value::Text(ref s) => {
1373                        sqlmodel_core::SqlEnum::from_sql_str(s.as_str()).map_err(|e| {
1374                            sqlmodel_core::Error::Type(sqlmodel_core::error::TypeError {
1375                                expected: <#name as sqlmodel_core::SqlEnum>::TYPE_NAME,
1376                                actual: e,
1377                                column: None,
1378                                rust_type: None,
1379                            })
1380                        })
1381                    }
1382                    other => Err(sqlmodel_core::Error::Type(sqlmodel_core::error::TypeError {
1383                        expected: <#name as sqlmodel_core::SqlEnum>::TYPE_NAME,
1384                        actual: other.type_name().to_string(),
1385                        column: None,
1386                        rust_type: None,
1387                    })),
1388                }
1389            }
1390        }
1391
1392        impl #impl_generics ::core::fmt::Display for #name #ty_generics #where_clause {
1393            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
1394                f.write_str(sqlmodel_core::SqlEnum::to_sql_str(self))
1395            }
1396        }
1397
1398        impl #impl_generics ::core::str::FromStr for #name #ty_generics #where_clause {
1399            type Err = String;
1400
1401            fn from_str(s: &str) -> Result<Self, Self::Err> {
1402                sqlmodel_core::SqlEnum::from_sql_str(s)
1403            }
1404        }
1405    })
1406}
1407
/// Convert a CamelCase identifier to snake_case.
///
/// An underscore is inserted before an upper-case character when:
/// - the previous character is lower-case (`"FooBar"` -> `"foo_bar"`), or
/// - the previous character is upper-case and the next one is lower-case,
///   i.e. at the end of an acronym (`"HTTPStatus"` -> `"http_status"`).
///
/// Upper-case characters are lowered with the Unicode-aware
/// `char::to_lowercase`, matching the Unicode-aware `is_uppercase` /
/// `is_lowercase` tests used for boundary detection. All other characters
/// (including digits and existing underscores) are copied through unchanged.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + 4);
    let mut prev: Option<char> = None;
    let mut chars = s.chars().peekable();

    while let Some(ch) = chars.next() {
        if ch.is_uppercase() {
            // No separator before the very first character.
            if let Some(before) = prev {
                let next_lower = chars.peek().is_some_and(|c| c.is_lowercase());
                // "FooBar" -> "foo_bar": insert _ when prev is lowercase
                // "HTTPStatus" -> "http_status": insert _ when next is lowercase (acronym boundary)
                if before.is_lowercase() || (next_lower && before.is_uppercase()) {
                    result.push('_');
                }
            }
            // `to_lowercase` may yield more than one char for some code points.
            result.extend(ch.to_lowercase());
        } else {
            result.push(ch);
        }
        prev = Some(ch);
    }

    result
}
1429
1430/// Attribute macro for defining SQL functions in handlers.
1431///
1432/// # Example
1433///
1434/// ```ignore
1435/// #[sqlmodel::query]
1436/// async fn get_heroes(cx: &Cx, conn: &impl Connection) -> Vec<Hero> {
1437///     sqlmodel::select!(Hero).all(cx, conn).await
1438/// }
1439/// ```
1440#[proc_macro_attribute]
1441pub fn query(_attr: TokenStream, item: TokenStream) -> TokenStream {
1442    let original = item.clone();
1443
1444    let func: syn::ItemFn = match syn::parse(item) {
1445        Ok(f) => f,
1446        Err(e) => return e.to_compile_error().into(),
1447    };
1448
1449    if func.sig.asyncness.is_none() {
1450        return syn::Error::new_spanned(
1451            func.sig.fn_token,
1452            "#[sqlmodel::query] requires an async fn",
1453        )
1454        .to_compile_error()
1455        .into();
1456    }
1457
1458    let Some(first_arg) = func.sig.inputs.first() else {
1459        return syn::Error::new_spanned(
1460            &func.sig.ident,
1461            "#[sqlmodel::query] requires the first parameter to be `cx: &Cx`",
1462        )
1463        .to_compile_error()
1464        .into();
1465    };
1466
1467    let first_ty = match first_arg {
1468        syn::FnArg::Typed(pat_ty) => &*pat_ty.ty,
1469        syn::FnArg::Receiver(recv) => {
1470            return syn::Error::new_spanned(
1471                recv,
1472                "#[sqlmodel::query] does not support methods; use a free function",
1473            )
1474            .to_compile_error()
1475            .into();
1476        }
1477    };
1478
1479    if !is_ref_to_cx(first_ty) {
1480        return syn::Error::new_spanned(
1481            first_ty,
1482            "#[sqlmodel::query] requires the first parameter to be `cx: &Cx`",
1483        )
1484        .to_compile_error()
1485        .into();
1486    }
1487
1488    original
1489}
1490
1491fn is_ref_to_cx(ty: &syn::Type) -> bool {
1492    let syn::Type::Reference(r) = ty else {
1493        return false;
1494    };
1495    is_cx_path(&r.elem)
1496}
1497
1498fn is_cx_path(ty: &syn::Type) -> bool {
1499    let syn::Type::Path(p) = ty else {
1500        return false;
1501    };
1502    p.path.segments.last().is_some_and(|seg| seg.ident == "Cx")
1503}