Skip to main content

sqlx_gen/codegen/
entity_parser.rs

1use std::path::Path;
2
3use quote::ToTokens;
4
/// A single column of an entity, recovered from one field of a generated struct.
#[derive(Clone, Debug)]
pub struct ParsedField {
    /// Name of the field in the Rust struct (e.g. "connector_type").
    pub rust_name: String,
    /// Name of the underlying DB column: the `#[sqlx(rename = "...")]` value when the field
    /// carries one, otherwise identical to `rust_name`.
    pub column_name: String,
    /// The complete Rust type rendered as text (e.g. "Option<String>", "i32", "uuid::Uuid").
    pub rust_type: String,
    /// `true` when the field type is an `Option<T>` wrapper.
    pub is_nullable: bool,
    /// `T` for an `Option<T>` field; the complete type for any other field.
    pub inner_type: String,
    /// `true` when the field is marked `#[sqlx_gen(primary_key)]`.
    pub is_primary_key: bool,
    /// SQL type name to cast to, for custom types (e.g. "agent.connector_usages").
    pub sql_type: Option<String>,
    /// `true` when the SQL type is an array, i.e. the cast needs a `[]` suffix.
    pub is_sql_array: bool,
    /// Raw SQL default expression taken from the DB (e.g. "now()", "'idle'::task_status").
    pub column_default: Option<String>,
}
27
28/// Represents an entity parsed from a generated Rust file.
29#[derive(Debug, Clone)]
30pub struct ParsedEntity {
31    /// Struct name in PascalCase (e.g. "Users", "UserRoles")
32    pub struct_name: String,
33    /// Original table/view name from `#[sqlx_gen(table = "...")]`
34    pub table_name: String,
35    /// Schema name from `#[sqlx_gen(schema = "...")]`
36    pub schema_name: Option<String>,
37    /// Whether this entity represents a view (`#[sqlx_gen(kind = "view")]`)
38    pub is_view: bool,
39    /// Parsed fields
40    pub fields: Vec<ParsedField>,
41    /// `use` imports from the entity source file (e.g. "use chrono::{DateTime, Utc};")
42    pub imports: Vec<String>,
43}
44
45/// Parse an entity struct from a `.rs` file on disk.
46pub fn parse_entity_file(path: &Path) -> crate::error::Result<ParsedEntity> {
47    let source = std::fs::read_to_string(path).map_err(crate::error::Error::Io)?;
48    parse_entity_source(&source).map_err(|e| {
49        crate::error::Error::Config(format!("{}: {}", path.display(), e))
50    })
51}
52
53/// Parse an entity struct from a Rust source string.
54pub fn parse_entity_source(source: &str) -> Result<ParsedEntity, String> {
55    let syntax = syn::parse_file(source).map_err(|e| format!("Failed to parse: {}", e))?;
56
57    // Collect use imports (excluding serde and sqlx derives)
58    let imports = extract_use_imports(&syntax);
59
60    for item in &syntax.items {
61        if let syn::Item::Struct(item_struct) = item {
62            if has_from_row_derive(item_struct) {
63                let mut entity = extract_entity(item_struct)?;
64                entity.imports = imports;
65                return Ok(entity);
66            }
67        }
68    }
69
70    Err("No struct with sqlx::FromRow derive found".to_string())
71}
72
73/// Check if a struct has `sqlx::FromRow` in its derive attributes.
74fn has_from_row_derive(item: &syn::ItemStruct) -> bool {
75    for attr in &item.attrs {
76        if attr.path().is_ident("derive") {
77            let tokens = attr.meta.to_token_stream().to_string();
78            if tokens.contains("FromRow") {
79                return true;
80            }
81        }
82    }
83    false
84}
85
86/// Extract `use` imports from the source file, excluding serde/sqlx imports
87/// (those are already handled by the CRUD generator).
88fn extract_use_imports(file: &syn::File) -> Vec<String> {
89    file.items
90        .iter()
91        .filter_map(|item| {
92            if let syn::Item::Use(use_item) = item {
93                let text = use_item.to_token_stream().to_string();
94                // Skip serde and sqlx imports — the CRUD generator adds those itself
95                if (text.contains("serde") && !text.contains("serde_")) || text.contains("sqlx") {
96                    return None;
97                }
98                // Normalize spacing: "use chrono :: { DateTime , Utc } ;" → cleaned up
99                let normalized = normalize_use_statement(&text);
100                Some(normalized)
101            } else {
102                None
103            }
104        })
105        .collect()
106}
107
/// Normalize a tokenized `use` statement by removing the extra spaces around `::`, `{`,
/// `}`, `,` and `;`.
///
/// For example `use chrono :: { DateTime , Utc } ;` becomes `use chrono::{DateTime, Utc};`.
/// A path with a leading `::` keeps the space after the `use` keyword, so
/// `use :: std :: fmt ;` becomes `use ::std::fmt;` rather than `use::std::fmt;`.
fn normalize_use_statement(s: &str) -> String {
    let compact = s
        .replace(" :: ", "::")
        .replace(":: ", "::")
        .replace(" ::", "::")
        .replace("{ ", "{")
        .replace(" }", "}")
        .replace(" ,", ",")
        .replace(" ;", ";");

    // Collapsing " :: " also swallows the space between the `use` keyword and a leading
    // `::`. Put it back, both for a bare `use` and for one preceded by a visibility.
    const GLUED: &str = " use::";
    if let Some(rest) = compact.strip_prefix("use::") {
        return format!("use ::{}", rest);
    }
    match compact.find(GLUED) {
        Some(idx) => format!("{} use ::{}", &compact[..idx], &compact[idx + GLUED.len()..]),
        None => compact,
    }
}
118
119/// Extract a ParsedEntity from a struct item.
120fn extract_entity(item: &syn::ItemStruct) -> Result<ParsedEntity, String> {
121    let struct_name = item.ident.to_string();
122
123    let (kind, schema_name, table_name) = parse_sqlx_gen_struct_attrs(&item.attrs);
124    let is_view = kind.as_deref() == Some("view");
125
126    // Fall back to struct name if no table annotation
127    let table_name = table_name.unwrap_or_else(|| struct_name.clone());
128
129    let fields = match &item.fields {
130        syn::Fields::Named(named) => {
131            named
132                .named
133                .iter()
134                .map(extract_field)
135                .collect::<Result<Vec<_>, _>>()?
136        }
137        _ => return Err("Expected named fields".to_string()),
138    };
139
140    Ok(ParsedEntity {
141        struct_name,
142        table_name,
143        schema_name,
144        is_view,
145        fields,
146        imports: Vec::new(), // filled by parse_entity_source
147    })
148}
149
150/// Parse `#[sqlx_gen(kind = "...", schema = "...", table = "...")]` from struct attributes.
151/// Returns (kind, schema_name, table_name).
152fn parse_sqlx_gen_struct_attrs(attrs: &[syn::Attribute]) -> (Option<String>, Option<String>, Option<String>) {
153    let mut kind = None;
154    let mut schema_name = None;
155    let mut table_name = None;
156
157    for attr in attrs {
158        if attr.path().is_ident("sqlx_gen") {
159            let tokens = attr.meta.to_token_stream().to_string();
160            if let Some(k) = extract_attr_value(&tokens, "kind") {
161                kind = Some(k);
162            }
163            if let Some(s) = extract_attr_value(&tokens, "schema") {
164                schema_name = Some(s);
165            }
166            if let Some(t) = extract_attr_value(&tokens, "table") {
167                table_name = Some(t);
168            }
169        }
170    }
171
172    (kind, schema_name, table_name)
173}
174
/// Extract a named string value from an attribute token string.
/// e.g. extract_attr_value(`sqlx_gen(kind = "view", table = "users")`, "kind") -> Some("view")
///
/// The token string is scanned left to right, looking for `key = "` outside of string
/// literals:
/// - `key` must start at an identifier boundary, so `type` does not match `sql_type`;
/// - text inside other string values is skipped, so a value can never be mistaken for a key;
/// - the value runs to the first *unescaped* `"`, and its escape sequences are decoded, so
///   `"nextval('\"Users_id_seq\"'::regclass)"` yields `nextval('"Users_id_seq"'::regclass)`.
///
/// Returns `None` when the key is absent or a string literal is unterminated.
fn extract_attr_value(tokens: &str, key: &str) -> Option<String> {
    let pattern = format!("{} = \"", key);
    let mut pos = 0;
    // True when the previous character could be part of an identifier, i.e. the current
    // position is in the middle of a longer name.
    let mut after_ident_char = false;

    while pos < tokens.len() {
        let rest = &tokens[pos..];
        if !after_ident_char && rest.starts_with(&pattern) {
            return read_quoted_body(&rest[pattern.len()..]).map(|(value, _)| value);
        }

        let ch = rest.chars().next()?;
        pos += ch.len_utf8();
        if ch == '"' {
            // Some other key's value: jump past its closing quote.
            let (_, consumed) = read_quoted_body(&tokens[pos..])?;
            pos += consumed;
            after_ident_char = false;
        } else {
            after_ident_char = ch.is_alphanumeric() || ch == '_';
        }
    }

    None
}

/// Decode the body of a string literal, given the text right after its opening quote.
///
/// Returns the decoded value and the number of bytes consumed (closing quote included), or
/// `None` when no closing quote is found. The simple escapes (`\"`, `\'`, `\\`, `\n`, `\t`,
/// `\r`, `\0`) are decoded; any other escape sequence is copied through verbatim.
fn read_quoted_body(body: &str) -> Option<(String, usize)> {
    let mut value = String::new();
    let mut chars = body.char_indices();

    while let Some((idx, ch)) = chars.next() {
        match ch {
            '"' => return Some((value, idx + 1)),
            '\\' => {
                let (_, escaped) = chars.next()?;
                match escaped {
                    'n' => value.push('\n'),
                    't' => value.push('\t'),
                    'r' => value.push('\r'),
                    '0' => value.push('\0'),
                    '"' | '\'' | '\\' => value.push(escaped),
                    other => {
                        value.push('\\');
                        value.push(other);
                    }
                }
            }
            other => value.push(other),
        }
    }

    None
}
184
185/// Extract a ParsedField from a syn::Field.
186fn extract_field(field: &syn::Field) -> Result<ParsedField, String> {
187    let rust_name = field
188        .ident
189        .as_ref()
190        .ok_or("Unnamed field")?
191        .to_string();
192
193    let column_name = get_sqlx_rename(&field.attrs).unwrap_or_else(|| rust_name.clone());
194    let (is_primary_key, sql_type, is_sql_array, column_default) = parse_sqlx_gen_field_attrs(&field.attrs);
195
196    let rust_type = field.ty.to_token_stream().to_string();
197    let (is_nullable, inner_type) = extract_option_type(&field.ty);
198    let inner_type = if is_nullable {
199        inner_type
200    } else {
201        rust_type.clone()
202    };
203
204    Ok(ParsedField {
205        rust_name,
206        column_name,
207        rust_type,
208        is_nullable,
209        inner_type,
210        is_primary_key,
211        sql_type,
212        is_sql_array,
213        column_default,
214    })
215}
216
217/// Parse `#[sqlx_gen(...)]` attributes on a field.
218/// Returns (is_primary_key, sql_type, is_sql_array, column_default).
219fn parse_sqlx_gen_field_attrs(attrs: &[syn::Attribute]) -> (bool, Option<String>, bool, Option<String>) {
220    let mut is_pk = false;
221    let mut sql_type = None;
222    let mut is_array = false;
223    let mut column_default = None;
224
225    for attr in attrs {
226        if attr.path().is_ident("sqlx_gen") {
227            let tokens = attr.meta.to_token_stream().to_string();
228            if tokens.contains("primary_key") {
229                is_pk = true;
230            }
231            if let Some(t) = extract_attr_value(&tokens, "sql_type") {
232                sql_type = Some(t);
233            }
234            if tokens.contains("is_array") {
235                is_array = true;
236            }
237            if let Some(d) = extract_attr_value(&tokens, "column_default") {
238                column_default = Some(d);
239            }
240        }
241    }
242
243    (is_pk, sql_type, is_array, column_default)
244}
245
246/// Extract `#[sqlx(rename = "...")]` value from field attributes.
247fn get_sqlx_rename(attrs: &[syn::Attribute]) -> Option<String> {
248    for attr in attrs {
249        if attr.path().is_ident("sqlx") {
250            let tokens = attr.meta.to_token_stream().to_string();
251            return extract_attr_value(&tokens, "rename");
252        }
253    }
254    None
255}
256
257/// Check if a type is `Option<T>` and extract the inner type.
258fn extract_option_type(ty: &syn::Type) -> (bool, String) {
259    if let syn::Type::Path(type_path) = ty {
260        if let Some(segment) = type_path.path.segments.last() {
261            if segment.ident == "Option" {
262                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
263                    if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
264                        return (true, inner.to_token_stream().to_string());
265                    }
266                }
267            }
268        }
269    }
270    (false, String::new())
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `source`, panicking when it does not contain a valid entity.
    fn parse(source: &str) -> ParsedEntity {
        parse_entity_source(source).unwrap()
    }

    // --- basic parsing ---

    /// A plain table yields its struct name, table name and all fields.
    #[test]
    fn test_parse_simple_table() {
        let src = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
                pub name: String,
            }
        "#;
        let users = parse(src);
        assert_eq!(users.struct_name, "Users");
        assert_eq!(users.table_name, "users");
        assert_eq!(users.is_view, false);
        assert_eq!(users.fields.len(), 2);
    }

    /// `kind = "view"` marks the entity as a view.
    #[test]
    fn test_parse_view() {
        let src = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "view", table = "active_users")]
            pub struct ActiveUsers {
                pub id: i32,
            }
        "#;
        let view = parse(src);
        assert_eq!(view.is_view, true);
        assert_eq!(view.table_name, "active_users");
    }

    /// `kind = "table"` does not mark the entity as a view.
    #[test]
    fn test_parse_table_not_view() {
        let src = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        assert_eq!(parse(src).is_view, false);
    }

    // --- primary key ---

    /// Only the annotated field is flagged as primary key.
    #[test]
    fn test_parse_primary_key() {
        let src = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                pub name: String,
            }
        "#;
        let fields = parse(src).fields;
        assert_eq!(fields[0].is_primary_key, true);
        assert_eq!(fields[1].is_primary_key, false);
    }

    /// Several fields may be flagged as primary key at once.
    #[test]
    fn test_composite_primary_key() {
        let src = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "user_roles")]
            pub struct UserRoles {
                #[sqlx_gen(primary_key)]
                pub user_id: i32,
                #[sqlx_gen(primary_key)]
                pub role_id: i32,
            }
        "#;
        let fields = parse(src).fields;
        assert_eq!(fields[0].is_primary_key, true);
        assert_eq!(fields[1].is_primary_key, true);
    }

    /// A field without the annotation is not a primary key.
    #[test]
    fn test_no_primary_key() {
        let src = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "logs")]
            pub struct Logs {
                pub message: String,
            }
        "#;
        let fields = parse(src).fields;
        assert_eq!(fields[0].is_primary_key, false);
    }

    // --- sqlx rename ---

    /// `#[sqlx(rename = "...")]` supplies the column name.
    #[test]
    fn test_sqlx_rename() {
        let src = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "connector")]
            pub struct Connector {
                #[sqlx(rename = "type")]
                pub connector_type: String,
            }
        "#;
        let fields = parse(src).fields;
        let renamed = &fields[0];
        assert_eq!(renamed.rust_name, "connector_type");
        assert_eq!(renamed.column_name, "type");
    }

    /// Without a rename the column name equals the field name.
    #[test]
    fn test_no_rename_uses_field_name() {
        let src = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub name: String,
            }
        "#;
        let fields = parse(src).fields;
        let plain = &fields[0];
        assert_eq!(plain.rust_name, "name");
        assert_eq!(plain.column_name, "name");
    }

    // --- nullable types ---

    /// `Option<T>` fields are nullable and expose `T` as the inner type.
    #[test]
    fn test_option_field_nullable() {
        let src = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub email: Option<String>,
            }
        "#;
        let fields = parse(src).fields;
        let email = &fields[0];
        assert_eq!(email.is_nullable, true);
        assert_eq!(email.inner_type, "String");
    }

    /// Non-`Option` fields are not nullable; the inner type is the full type.
    #[test]
    fn test_non_option_not_nullable() {
        let src = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let fields = parse(src).fields;
        let id = &fields[0];
        assert_eq!(id.is_nullable, false);
        assert_eq!(id.inner_type, "i32");
    }

    /// The inner type of an `Option` may itself be a path type.
    #[test]
    fn test_option_complex_type() {
        let src = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub created_at: Option<chrono::NaiveDateTime>,
            }
        "#;
        let fields = parse(src).fields;
        let created_at = &fields[0];
        assert_eq!(created_at.is_nullable, true);
        assert!(created_at.inner_type.contains("NaiveDateTime"));
    }

    // --- type preservation ---

    /// The rendered Rust type keeps the type's name.
    #[test]
    fn test_rust_type_preserved() {
        let src = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: uuid::Uuid,
            }
        "#;
        let fields = parse(src).fields;
        assert!(fields[0].rust_type.contains("Uuid"));
    }

    // --- error cases ---

    /// A struct without a `FromRow` derive is rejected.
    #[test]
    fn test_no_from_row_struct() {
        let src = r#"
            pub struct NotAnEntity {
                pub id: i32,
            }
        "#;
        assert!(parse_entity_source(src).is_err());
    }

    /// Empty input is rejected.
    #[test]
    fn test_empty_source() {
        assert!(parse_entity_source("").is_err());
    }

    // --- fallback table name ---

    /// Without a table annotation the struct name is used as the table name.
    #[test]
    fn test_fallback_table_name_to_struct_name() {
        let src = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            pub struct Users {
                pub id: i32,
            }
        "#;
        assert_eq!(parse(src).table_name, "Users");
    }

    // --- combined attributes ---

    /// A primary-key annotation and a rename can sit on the same field.
    #[test]
    fn test_pk_with_rename() {
        let src = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "items")]
            pub struct Items {
                #[sqlx_gen(primary_key)]
                #[sqlx(rename = "itemID")]
                pub item_id: i32,
            }
        "#;
        let items = parse(src);
        let item_id = &items.fields[0];
        assert_eq!(item_id.is_primary_key, true);
        assert_eq!(item_id.column_name, "itemID");
        assert_eq!(item_id.rust_name, "item_id");
    }

    /// End-to-end check on an entity mixing every supported feature.
    #[test]
    fn test_full_entity() {
        let src = r#"
            #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                pub name: String,
                pub email: Option<String>,
                #[sqlx(rename = "createdAt")]
                pub created_at: chrono::NaiveDateTime,
            }
        "#;
        let users = parse(src);
        assert_eq!(users.struct_name, "Users");
        assert_eq!(users.table_name, "users");
        assert_eq!(users.is_view, false);
        assert_eq!(users.fields.len(), 4);

        let id = &users.fields[0];
        assert_eq!(id.is_primary_key, true);
        assert_eq!(id.rust_name, "id");

        let name = &users.fields[1];
        assert_eq!(name.is_primary_key, false);
        assert_eq!(name.rust_type, "String");

        let email = &users.fields[2];
        assert_eq!(email.is_nullable, true);
        assert_eq!(email.inner_type, "String");

        let created_at = &users.fields[3];
        assert_eq!(created_at.column_name, "createdAt");
        assert_eq!(created_at.rust_name, "created_at");
    }

    // --- imports extraction ---

    /// Non-serde, non-sqlx imports are collected.
    #[test]
    fn test_imports_extracted() {
        let src = r#"
            use chrono::{DateTime, Utc};
            use uuid::Uuid;
            use serde::{Serialize, Deserialize};

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: Uuid,
                pub created_at: DateTime<Utc>,
            }
        "#;
        let imports = parse(src).imports;
        assert_eq!(imports.len(), 2);
        assert!(imports.iter().any(|line| line.contains("chrono")));
        assert!(imports.iter().any(|line| line.contains("uuid")));
        // serde should be excluded
        assert!(imports.iter().all(|line| !line.contains("serde")));
    }

    /// A file without `use` statements yields no imports.
    #[test]
    fn test_imports_empty_when_none() {
        let src = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        assert!(parse(src).imports.is_empty());
    }

    /// `serde_json` survives the serde filter.
    #[test]
    fn test_imports_keep_serde_json() {
        let src = r#"
            use serde::{Serialize, Deserialize};
            use serde_json::Value;

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub data: Value,
            }
        "#;
        let imports = parse(src).imports;
        assert_eq!(imports.len(), 1);
        assert!(imports[0].contains("serde_json"));
    }

    /// sqlx imports are dropped.
    #[test]
    fn test_imports_exclude_sqlx() {
        let src = r#"
            use sqlx::types::Uuid;
            use chrono::NaiveDateTime;

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let imports = parse(src).imports;
        assert_eq!(imports.len(), 1);
        assert!(imports[0].contains("chrono"));
    }

    // --- column_default parsing ---

    /// A `column_default` annotation is captured verbatim.
    #[test]
    fn test_parse_column_default() {
        let src = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "tasks")]
            pub struct Tasks {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                #[sqlx_gen(column_default = "now()")]
                pub created_at: String,
            }
        "#;
        let tasks = parse(src);
        let default = &tasks.fields[1].column_default;
        assert_eq!(*default, Some("now()".to_string()));
    }

    /// A field without the annotation has no default.
    #[test]
    fn test_parse_no_column_default() {
        let src = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "tasks")]
            pub struct Tasks {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                pub title: String,
            }
        "#;
        let tasks = parse(src);
        let default = &tasks.fields[1].column_default;
        assert_eq!(*default, None);
    }

    /// A default containing single quotes and a cast is captured verbatim.
    #[test]
    fn test_parse_column_default_with_cast() {
        let src = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "tasks")]
            pub struct Tasks {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                #[sqlx_gen(column_default = "'idle'::task_status")]
                pub status: String,
            }
        "#;
        let tasks = parse(src);
        let default = &tasks.fields[1].column_default;
        assert_eq!(*default, Some("'idle'::task_status".to_string()));
    }
}