Skip to main content

sqlx_gen/codegen/
entity_parser.rs

1use std::path::Path;
2
3use quote::ToTokens;
4
/// Represents a field parsed from a generated entity struct.
#[derive(Debug, Clone)]
pub struct ParsedField {
    /// Rust field name (e.g. "connector_type")
    pub rust_name: String,
    /// Original DB column name. From `#[sqlx(rename = "...")]` if present, otherwise same as rust_name.
    pub column_name: String,
    /// Full Rust type as a string (e.g. "Option<String>", "i32", "uuid::Uuid").
    /// NOTE: produced by stringifying the type's token stream, so multi-token types may
    /// contain spaces between tokens (e.g. "Option < String >", "uuid :: Uuid"); compare with care.
    pub rust_type: String,
    /// Whether the type is wrapped in Option<T> (detected by the type path's last segment being `Option`)
    pub is_nullable: bool,
    /// The inner type if nullable, or the full type if not (same stringified form as `rust_type`)
    pub inner_type: String,
    /// Whether this field is a primary key (`#[sqlx_gen(primary_key)]`)
    pub is_primary_key: bool,
    /// SQL type name for custom types needing a cast (e.g. "agent.connector_usages"),
    /// taken from `#[sqlx_gen(sql_type = "...")]`
    pub sql_type: Option<String>,
    /// Whether the SQL type is an array (needs `[]` suffix in cast); set by the `is_array`
    /// flag in `#[sqlx_gen(...)]`
    pub is_sql_array: bool,
    /// Raw SQL default expression from the DB (e.g. "now()", "'idle'::task_status"),
    /// taken from `#[sqlx_gen(column_default = "...")]`
    pub column_default: Option<String>,
}
27
/// Represents an entity parsed from a generated Rust file.
#[derive(Debug, Clone)]
pub struct ParsedEntity {
    /// Struct identifier as written in the source (typically PascalCase, e.g. "Users", "UserRoles")
    pub struct_name: String,
    /// Original table/view name from `#[sqlx_gen(table = "...")]`.
    /// Falls back to `struct_name` verbatim (no case conversion) when that annotation is absent.
    pub table_name: String,
    /// Schema name from `#[sqlx_gen(schema = "...")]`; `None` when not annotated
    pub schema_name: Option<String>,
    /// Whether this entity represents a view (`#[sqlx_gen(kind = "view")]`);
    /// any other or missing `kind` yields `false`
    pub is_view: bool,
    /// Parsed fields, in struct declaration order
    pub fields: Vec<ParsedField>,
    /// `use` imports from the entity source file (e.g. "use chrono::{DateTime, Utc};"),
    /// with serde/sqlx imports filtered out because the CRUD generator emits those itself
    pub imports: Vec<String>,
}
44
45/// Parse an entity struct from a `.rs` file on disk.
46pub fn parse_entity_file(path: &Path) -> crate::error::Result<ParsedEntity> {
47    let source = std::fs::read_to_string(path).map_err(crate::error::Error::Io)?;
48    parse_entity_source(&source).map_err(|e| {
49        crate::error::Error::Config(format!("{}: {}", path.display(), e))
50    })
51}
52
53/// Parse an entity struct from a Rust source string.
54pub fn parse_entity_source(source: &str) -> Result<ParsedEntity, String> {
55    let syntax = syn::parse_file(source).map_err(|e| format!("Failed to parse: {}", e))?;
56
57    // Collect use imports (excluding serde and sqlx derives)
58    let imports = extract_use_imports(&syntax);
59
60    for item in &syntax.items {
61        if let syn::Item::Struct(item_struct) = item {
62            if has_from_row_derive(item_struct) {
63                let mut entity = extract_entity(item_struct)?;
64                entity.imports = imports;
65                return Ok(entity);
66            }
67        }
68    }
69
70    Err("No struct with sqlx::FromRow derive found".to_string())
71}
72
73/// Check if a struct has `sqlx::FromRow` in its derive attributes.
74fn has_from_row_derive(item: &syn::ItemStruct) -> bool {
75    for attr in &item.attrs {
76        if attr.path().is_ident("derive") {
77            let tokens = attr.meta.to_token_stream().to_string();
78            if tokens.contains("FromRow") {
79                return true;
80            }
81        }
82    }
83    false
84}
85
86/// Extract `use` imports from the source file, excluding serde/sqlx imports
87/// (those are already handled by the CRUD generator).
88fn extract_use_imports(file: &syn::File) -> Vec<String> {
89    file.items
90        .iter()
91        .filter_map(|item| {
92            if let syn::Item::Use(use_item) = item {
93                let text = use_item.to_token_stream().to_string();
94                // Skip serde and sqlx imports — the CRUD generator adds those itself
95                if (text.contains("serde") && !text.contains("serde_")) || text.contains("sqlx") {
96                    return None;
97                }
98                // Normalize spacing: "use chrono :: { DateTime , Utc } ;" → cleaned up
99                let normalized = normalize_use_statement(&text);
100                Some(normalized)
101            } else {
102                None
103            }
104        })
105        .collect()
106}
107
/// Collapse the spacing that token-stream stringification puts into a `use` item,
/// e.g. `use chrono :: { DateTime , Utc } ;` becomes `use chrono::{DateTime, Utc};`.
///
/// The rewrites run in the listed order, and each one is applied to the whole string.
fn normalize_use_statement(s: &str) -> String {
    const REWRITES: [(&str, &str); 7] = [
        (" :: ", "::"),
        (":: ", "::"),
        (" ::", "::"),
        ("{ ", "{"),
        (" }", "}"),
        (" ,", ","),
        (" ;", ";"),
    ];

    REWRITES
        .iter()
        .fold(s.to_owned(), |text, &(from, to)| text.replace(from, to))
}
118
119/// Extract a ParsedEntity from a struct item.
120fn extract_entity(item: &syn::ItemStruct) -> Result<ParsedEntity, String> {
121    let struct_name = item.ident.to_string();
122
123    let (kind, schema_name, table_name) = parse_sqlx_gen_struct_attrs(&item.attrs);
124    let is_view = kind.as_deref() == Some("view");
125
126    // Fall back to struct name if no table annotation
127    let table_name = table_name.unwrap_or_else(|| struct_name.clone());
128
129    let fields = match &item.fields {
130        syn::Fields::Named(named) => {
131            named
132                .named
133                .iter()
134                .map(extract_field)
135                .collect::<Result<Vec<_>, _>>()?
136        }
137        _ => return Err("Expected named fields".to_string()),
138    };
139
140    Ok(ParsedEntity {
141        struct_name,
142        table_name,
143        schema_name,
144        is_view,
145        fields,
146        imports: Vec::new(), // filled by parse_entity_source
147    })
148}
149
150/// Parse `#[sqlx_gen(kind = "...", schema = "...", table = "...")]` from struct attributes.
151/// Returns (kind, schema_name, table_name).
152fn parse_sqlx_gen_struct_attrs(attrs: &[syn::Attribute]) -> (Option<String>, Option<String>, Option<String>) {
153    let mut kind = None;
154    let mut schema_name = None;
155    let mut table_name = None;
156
157    for attr in attrs {
158        if attr.path().is_ident("sqlx_gen") {
159            let tokens = attr.meta.to_token_stream().to_string();
160            if let Some(k) = extract_attr_value(&tokens, "kind") {
161                kind = Some(k);
162            }
163            if let Some(s) = extract_attr_value(&tokens, "schema") {
164                schema_name = Some(s);
165            }
166            if let Some(t) = extract_attr_value(&tokens, "table") {
167                table_name = Some(t);
168            }
169        }
170    }
171
172    (kind, schema_name, table_name)
173}
174
/// Extract a named string value from an attribute token string.
///
/// For the tokens `sqlx_gen(kind = "view", table = "users")` and key "kind", this returns
/// `Some("view")`.
///
/// `tokens` is the `to_string()` rendering of an attribute's tokens, in which a name-value
/// pair appears as `key = "..."`. The lookup has three rules:
/// - The key must be a whole identifier, so `type` does not match `sql_type`.
/// - The key must appear outside every other string literal, so a value that happens to
///   contain `kind = ` is not mistaken for a key.
/// - The literal's escape sequences (`\"`, `\\`, `\'`, `\n`, `\r`, `\t`, `\0`, `\xNN`,
///   `\u{...}`, and line continuations) are decoded, so the result is the literal's real value.
///
/// Returns `None` if the key is absent or its literal is unterminated.
fn extract_attr_value(tokens: &str, key: &str) -> Option<String> {
    let pattern = format!("{} = \"", key);
    let mut chars = tokens.char_indices();
    // Whether the previous character (outside any literal) could continue an identifier.
    let mut prev_is_ident = false;

    while let Some((idx, c)) = chars.next() {
        if c == '"' {
            // Skip an unrelated string literal without looking inside it.
            let mut escaped = false;
            loop {
                let (_, lit_char) = chars.next()?;
                if escaped {
                    escaped = false;
                } else if lit_char == '\\' {
                    escaped = true;
                } else if lit_char == '"' {
                    break;
                }
            }
            prev_is_ident = false;
            continue;
        }
        if !prev_is_ident && tokens[idx..].starts_with(&pattern) {
            return unescape_string_literal_body(&tokens[idx + pattern.len()..]);
        }
        prev_is_ident = c.is_alphanumeric() || c == '_';
    }
    None
}

/// Decode the body of a (non-raw) Rust string literal, starting just after its opening quote
/// and stopping at the first unescaped closing quote.
///
/// Returns `None` if there is no closing quote. Malformed `\x`/`\u` escapes cannot occur in a
/// file that parsed successfully; if they do, they decode to U+FFFD. Unknown escapes are
/// kept verbatim.
fn unescape_string_literal_body(body: &str) -> Option<String> {
    let mut out = String::with_capacity(body.len());
    let mut chars = body.chars().peekable();

    while let Some(c) = chars.next() {
        match c {
            '"' => return Some(out),
            '\\' => match chars.next()? {
                '"' => out.push('"'),
                '\\' => out.push('\\'),
                '\'' => out.push('\''),
                'n' => out.push('\n'),
                'r' => out.push('\r'),
                't' => out.push('\t'),
                '0' => out.push('\0'),
                'x' => {
                    // `\xNN`: exactly two hex digits, ASCII range only.
                    let code = (0..2).try_fold(0u32, |acc, _| {
                        Some(acc * 16 + chars.next_if(char::is_ascii_hexdigit)?.to_digit(16)?)
                    });
                    out.push(
                        code.and_then(char::from_u32)
                            .filter(char::is_ascii)
                            .unwrap_or('\u{FFFD}'),
                    );
                }
                'u' => {
                    // `\u{X..}`: hex digits, optionally separated by underscores.
                    let code = chars.next_if_eq(&'{').and_then(|_| {
                        let mut value: u32 = 0;
                        let mut digits = 0u32;
                        while let Some(d) = chars.next_if(|d| d.is_ascii_hexdigit() || *d == '_') {
                            if let Some(v) = d.to_digit(16) {
                                value = value.checked_mul(16)?.checked_add(v)?;
                                digits += 1;
                            }
                        }
                        chars.next_if_eq(&'}')?;
                        if digits == 0 {
                            None
                        } else {
                            Some(value)
                        }
                    });
                    out.push(code.and_then(char::from_u32).unwrap_or('\u{FFFD}'));
                }
                // Line continuation: a backslash before a newline swallows the newline
                // and any leading whitespace on the next line.
                '\n' | '\r' => while chars.next_if(|w| w.is_whitespace()).is_some() {},
                other => {
                    out.push('\\');
                    out.push(other);
                }
            },
            other => out.push(other),
        }
    }
    None
}
194
195/// Extract a ParsedField from a syn::Field.
196fn extract_field(field: &syn::Field) -> Result<ParsedField, String> {
197    let rust_name = field
198        .ident
199        .as_ref()
200        .ok_or("Unnamed field")?
201        .to_string();
202
203    let column_name = get_sqlx_rename(&field.attrs).unwrap_or_else(|| rust_name.clone());
204    let (is_primary_key, sql_type, is_sql_array, column_default) = parse_sqlx_gen_field_attrs(&field.attrs);
205
206    let rust_type = field.ty.to_token_stream().to_string();
207    let (is_nullable, inner_type) = extract_option_type(&field.ty);
208    let inner_type = if is_nullable {
209        inner_type
210    } else {
211        rust_type.clone()
212    };
213
214    Ok(ParsedField {
215        rust_name,
216        column_name,
217        rust_type,
218        is_nullable,
219        inner_type,
220        is_primary_key,
221        sql_type,
222        is_sql_array,
223        column_default,
224    })
225}
226
227/// Parse `#[sqlx_gen(...)]` attributes on a field.
228/// Returns (is_primary_key, sql_type, is_sql_array, column_default).
229fn parse_sqlx_gen_field_attrs(attrs: &[syn::Attribute]) -> (bool, Option<String>, bool, Option<String>) {
230    let mut is_pk = false;
231    let mut sql_type = None;
232    let mut is_array = false;
233    let mut column_default = None;
234
235    for attr in attrs {
236        if attr.path().is_ident("sqlx_gen") {
237            let tokens = attr.meta.to_token_stream().to_string();
238            if tokens.contains("primary_key") {
239                is_pk = true;
240            }
241            if let Some(t) = extract_attr_value(&tokens, "sql_type") {
242                sql_type = Some(t);
243            }
244            if tokens.contains("is_array") {
245                is_array = true;
246            }
247            if let Some(d) = extract_attr_value(&tokens, "column_default") {
248                column_default = Some(d);
249            }
250        }
251    }
252
253    (is_pk, sql_type, is_array, column_default)
254}
255
256/// Extract `#[sqlx(rename = "...")]` value from field attributes.
257fn get_sqlx_rename(attrs: &[syn::Attribute]) -> Option<String> {
258    for attr in attrs {
259        if attr.path().is_ident("sqlx") {
260            let tokens = attr.meta.to_token_stream().to_string();
261            return extract_attr_value(&tokens, "rename");
262        }
263    }
264    None
265}
266
267/// Check if a type is `Option<T>` and extract the inner type.
268fn extract_option_type(ty: &syn::Type) -> (bool, String) {
269    if let syn::Type::Path(type_path) = ty {
270        if let Some(segment) = type_path.path.segments.last() {
271            if segment.ident == "Option" {
272                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
273                    if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
274                        return (true, inner.to_token_stream().to_string());
275                    }
276                }
277            }
278        }
279    }
280    (false, String::new())
281}
282
// Unit tests: each test feeds an inline entity source string through
// `parse_entity_source` and checks the resulting `ParsedEntity` / `ParsedField`s.
#[cfg(test)]
mod tests {
    use super::*;

    // --- basic parsing ---

    #[test]
    fn test_parse_simple_table() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
                pub name: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.struct_name, "Users");
        assert_eq!(entity.table_name, "users");
        assert!(!entity.is_view);
        assert_eq!(entity.fields.len(), 2);
    }

    #[test]
    fn test_parse_view() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "view", table = "active_users")]
            pub struct ActiveUsers {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.is_view);
        assert_eq!(entity.table_name, "active_users");
    }

    #[test]
    fn test_parse_table_not_view() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.is_view);
    }

    // --- primary key ---

    #[test]
    fn test_parse_primary_key() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                pub name: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_primary_key);
        assert!(!entity.fields[1].is_primary_key);
    }

    #[test]
    fn test_composite_primary_key() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "user_roles")]
            pub struct UserRoles {
                #[sqlx_gen(primary_key)]
                pub user_id: i32,
                #[sqlx_gen(primary_key)]
                pub role_id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_primary_key);
        assert!(entity.fields[1].is_primary_key);
    }

    #[test]
    fn test_no_primary_key() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "logs")]
            pub struct Logs {
                pub message: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.fields[0].is_primary_key);
    }

    // --- sqlx rename ---

    #[test]
    fn test_sqlx_rename() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "connector")]
            pub struct Connector {
                #[sqlx(rename = "type")]
                pub connector_type: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.fields[0].rust_name, "connector_type");
        assert_eq!(entity.fields[0].column_name, "type");
    }

    #[test]
    fn test_no_rename_uses_field_name() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub name: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.fields[0].rust_name, "name");
        assert_eq!(entity.fields[0].column_name, "name");
    }

    // --- nullable types ---

    #[test]
    fn test_option_field_nullable() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub email: Option<String>,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_nullable);
        assert_eq!(entity.fields[0].inner_type, "String");
    }

    #[test]
    fn test_non_option_not_nullable() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.fields[0].is_nullable);
        assert_eq!(entity.fields[0].inner_type, "i32");
    }

    // Uses `contains` because multi-token types are stringified from token streams
    // and may carry extra spaces around `::`.
    #[test]
    fn test_option_complex_type() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub created_at: Option<chrono::NaiveDateTime>,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_nullable);
        assert!(entity.fields[0].inner_type.contains("NaiveDateTime"));
    }

    // --- type preservation ---

    #[test]
    fn test_rust_type_preserved() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: uuid::Uuid,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].rust_type.contains("Uuid"));
    }

    // --- error cases ---

    #[test]
    fn test_no_from_row_struct() {
        let source = r#"
            pub struct NotAnEntity {
                pub id: i32,
            }
        "#;
        let result = parse_entity_source(source);
        assert!(result.is_err());
    }

    // Empty input must be rejected: there is no FromRow struct to extract.
    #[test]
    fn test_empty_source() {
        let result = parse_entity_source("");
        assert!(result.is_err());
    }

    // --- fallback table name ---

    // Without a `table = ...` annotation the struct identifier is used verbatim
    // (no snake_case conversion).
    #[test]
    fn test_fallback_table_name_to_struct_name() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.table_name, "Users");
    }

    // --- combined attributes ---

    #[test]
    fn test_pk_with_rename() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "items")]
            pub struct Items {
                #[sqlx_gen(primary_key)]
                #[sqlx(rename = "itemID")]
                pub item_id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        let f = &entity.fields[0];
        assert!(f.is_primary_key);
        assert_eq!(f.column_name, "itemID");
        assert_eq!(f.rust_name, "item_id");
    }

    #[test]
    fn test_full_entity() {
        let source = r#"
            #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                pub name: String,
                pub email: Option<String>,
                #[sqlx(rename = "createdAt")]
                pub created_at: chrono::NaiveDateTime,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.struct_name, "Users");
        assert_eq!(entity.table_name, "users");
        assert!(!entity.is_view);
        assert_eq!(entity.fields.len(), 4);

        assert!(entity.fields[0].is_primary_key);
        assert_eq!(entity.fields[0].rust_name, "id");

        assert!(!entity.fields[1].is_primary_key);
        assert_eq!(entity.fields[1].rust_type, "String");

        assert!(entity.fields[2].is_nullable);
        assert_eq!(entity.fields[2].inner_type, "String");

        assert_eq!(entity.fields[3].column_name, "createdAt");
        assert_eq!(entity.fields[3].rust_name, "created_at");
    }

    // --- imports extraction ---

    #[test]
    fn test_imports_extracted() {
        let source = r#"
            use chrono::{DateTime, Utc};
            use uuid::Uuid;
            use serde::{Serialize, Deserialize};

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: Uuid,
                pub created_at: DateTime<Utc>,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 2);
        assert!(entity.imports.iter().any(|i| i.contains("chrono")));
        assert!(entity.imports.iter().any(|i| i.contains("uuid")));
        // serde should be excluded
        assert!(!entity.imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_imports_empty_when_none() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.imports.is_empty());
    }

    // `serde_json` is a separate crate and must survive the serde filter.
    #[test]
    fn test_imports_keep_serde_json() {
        let source = r#"
            use serde::{Serialize, Deserialize};
            use serde_json::Value;

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub data: Value,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 1);
        assert!(entity.imports[0].contains("serde_json"));
    }

    #[test]
    fn test_imports_exclude_sqlx() {
        let source = r#"
            use sqlx::types::Uuid;
            use chrono::NaiveDateTime;

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 1);
        assert!(entity.imports[0].contains("chrono"));
    }

    // --- column_default parsing ---

    #[test]
    fn test_parse_column_default() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "tasks")]
            pub struct Tasks {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                #[sqlx_gen(column_default = "now()")]
                pub created_at: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        let created_at = &entity.fields[1];
        assert_eq!(created_at.column_default, Some("now()".to_string()));
    }

    #[test]
    fn test_parse_no_column_default() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "tasks")]
            pub struct Tasks {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                pub title: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        let title = &entity.fields[1];
        assert_eq!(title.column_default, None);
    }

    #[test]
    fn test_parse_column_default_with_cast() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "tasks")]
            pub struct Tasks {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                #[sqlx_gen(column_default = "'idle'::task_status")]
                pub status: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        let status = &entity.fields[1];
        assert_eq!(status.column_default, Some("'idle'::task_status".to_string()));
    }

    // The attribute literal in the fixture contains `\"` escapes; the parsed value
    // must have them decoded to plain double quotes.
    #[test]
    fn test_parse_column_default_with_json_quotes() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "configs")]
            pub struct Configs {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                #[sqlx_gen(column_default = "'{\"1\": \"\", \"2\": \"\"}'::jsonb")]
                pub template_variables: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        let field = &entity.fields[1];
        assert_eq!(
            field.column_default,
            Some(r#"'{"1": "", "2": ""}'::jsonb"#.to_string())
        );
    }
}