Skip to main content

sqlx_gen/codegen/
entity_parser.rs

1use std::path::Path;
2
3use quote::ToTokens;
4
/// Represents a field parsed from a generated entity struct.
///
/// Type strings (`rust_type`, `inner_type`) are produced by stringifying the
/// field's type tokens, so they keep token spacing: `Option<String>` is stored
/// as `"Option < String >"` and `uuid::Uuid` as `"uuid :: Uuid"`.
#[derive(Debug, Clone)]
pub struct ParsedField {
    /// Rust field name (e.g. "connector_type")
    pub rust_name: String,
    /// Original DB column name. From `#[sqlx(rename = "...")]` if present, otherwise same as rust_name.
    pub column_name: String,
    /// Full Rust type as a token string (e.g. "Option < String >", "i32", "uuid :: Uuid")
    pub rust_type: String,
    /// Whether the type is wrapped in `Option<T>`
    pub is_nullable: bool,
    /// The inner type `T` if nullable, or the full type if not (same token spacing as `rust_type`)
    pub inner_type: String,
    /// Whether this field is a primary key (`#[sqlx_gen(primary_key)]`)
    pub is_primary_key: bool,
    /// SQL type name for custom types needing a cast (e.g. "agent.connector_usages"),
    /// taken from `#[sqlx_gen(sql_type = "...")]`
    pub sql_type: Option<String>,
    /// Whether the SQL type is an array (needs `[]` suffix in cast); set by the
    /// `is_array` entry of `#[sqlx_gen(...)]`
    pub is_sql_array: bool,
    /// Raw SQL default expression (e.g. "now()", "'idle'::task_status"), taken from
    /// `#[sqlx_gen(column_default = "...")]`
    pub column_default: Option<String>,
}
27
28/// Represents an entity parsed from a generated Rust file.
29#[derive(Debug, Clone)]
30pub struct ParsedEntity {
31    /// Struct name in PascalCase (e.g. "Users", "UserRoles")
32    pub struct_name: String,
33    /// Original table/view name from `#[sqlx_gen(table = "...")]`
34    pub table_name: String,
35    /// Schema name from `#[sqlx_gen(schema = "...")]`
36    pub schema_name: Option<String>,
37    /// Whether this entity represents a view (`#[sqlx_gen(kind = "view")]`)
38    pub is_view: bool,
39    /// Parsed fields
40    pub fields: Vec<ParsedField>,
41    /// `use` imports from the entity source file (e.g. "use chrono::{DateTime, Utc};")
42    pub imports: Vec<String>,
43}
44
45/// Parse an entity struct from a `.rs` file on disk.
46pub fn parse_entity_file(path: &Path) -> crate::error::Result<ParsedEntity> {
47    let source = std::fs::read_to_string(path).map_err(crate::error::Error::Io)?;
48    parse_entity_source(&source)
49        .map_err(|e| crate::error::Error::Config(format!("{}: {}", path.display(), e)))
50}
51
52/// Parse an entity struct from a Rust source string.
53pub fn parse_entity_source(source: &str) -> Result<ParsedEntity, String> {
54    let syntax = syn::parse_file(source).map_err(|e| format!("Failed to parse: {}", e))?;
55
56    // Collect use imports (excluding serde and sqlx derives)
57    let imports = extract_use_imports(&syntax);
58
59    for item in &syntax.items {
60        if let syn::Item::Struct(item_struct) = item {
61            if has_from_row_derive(item_struct) {
62                let mut entity = extract_entity(item_struct)?;
63                entity.imports = imports;
64                return Ok(entity);
65            }
66        }
67    }
68
69    Err("No struct with sqlx::FromRow derive found".to_string())
70}
71
72/// Check if a struct has `sqlx::FromRow` in its derive attributes.
73fn has_from_row_derive(item: &syn::ItemStruct) -> bool {
74    for attr in &item.attrs {
75        if attr.path().is_ident("derive") {
76            let tokens = attr.meta.to_token_stream().to_string();
77            if tokens.contains("FromRow") {
78                return true;
79            }
80        }
81    }
82    false
83}
84
85/// Extract `use` imports from the source file, excluding serde/sqlx imports
86/// (those are already handled by the CRUD generator).
87fn extract_use_imports(file: &syn::File) -> Vec<String> {
88    file.items
89        .iter()
90        .filter_map(|item| {
91            if let syn::Item::Use(use_item) = item {
92                let text = use_item.to_token_stream().to_string();
93                // Skip serde and sqlx imports — the CRUD generator adds those itself
94                if (text.contains("serde") && !text.contains("serde_")) || text.contains("sqlx") {
95                    return None;
96                }
97                // Normalize spacing: "use chrono :: { DateTime , Utc } ;" → cleaned up
98                let normalized = normalize_use_statement(&text);
99                Some(normalized)
100            } else {
101                None
102            }
103        })
104        .collect()
105}
106
/// Tidy a stringified `use` statement by dropping the spaces that token
/// stringification puts around `::`, inside braces, and before `,` and `;`.
fn normalize_use_statement(s: &str) -> String {
    // Applied in order; earlier rules can change what later ones see.
    const RULES: [(&str, &str); 7] = [
        (" :: ", "::"),
        (":: ", "::"),
        (" ::", "::"),
        ("{ ", "{"),
        (" }", "}"),
        (" ,", ","),
        (" ;", ";"),
    ];
    RULES
        .iter()
        .fold(s.to_string(), |text, (from, to)| text.replace(from, to))
}
117
118/// Extract a ParsedEntity from a struct item.
119fn extract_entity(item: &syn::ItemStruct) -> Result<ParsedEntity, String> {
120    let struct_name = item.ident.to_string();
121
122    let (kind, schema_name, table_name) = parse_sqlx_gen_struct_attrs(&item.attrs);
123    let is_view = kind.as_deref() == Some("view");
124
125    // Fall back to struct name if no table annotation
126    let table_name = table_name.unwrap_or_else(|| struct_name.clone());
127
128    let fields = match &item.fields {
129        syn::Fields::Named(named) => named
130            .named
131            .iter()
132            .map(extract_field)
133            .collect::<Result<Vec<_>, _>>()?,
134        _ => return Err("Expected named fields".to_string()),
135    };
136
137    Ok(ParsedEntity {
138        struct_name,
139        table_name,
140        schema_name,
141        is_view,
142        fields,
143        imports: Vec::new(), // filled by parse_entity_source
144    })
145}
146
147/// Parse `#[sqlx_gen(kind = "...", schema = "...", table = "...")]` from struct attributes.
148/// Returns (kind, schema_name, table_name).
149fn parse_sqlx_gen_struct_attrs(
150    attrs: &[syn::Attribute],
151) -> (Option<String>, Option<String>, Option<String>) {
152    let mut kind = None;
153    let mut schema_name = None;
154    let mut table_name = None;
155
156    for attr in attrs {
157        if attr.path().is_ident("sqlx_gen") {
158            let tokens = attr.meta.to_token_stream().to_string();
159            if let Some(k) = extract_attr_value(&tokens, "kind") {
160                kind = Some(k);
161            }
162            if let Some(s) = extract_attr_value(&tokens, "schema") {
163                schema_name = Some(s);
164            }
165            if let Some(t) = extract_attr_value(&tokens, "table") {
166                table_name = Some(t);
167            }
168        }
169    }
170
171    (kind, schema_name, table_name)
172}
173
/// Extract a named string value from an attribute token string.
///
/// `tokens` is the stringified token stream of an attribute, e.g.
/// `sqlx_gen (kind = "view" , table = "users")`; for `key = "kind"` this
/// returns `Some("view")`.
///
/// The key must appear as a whole identifier outside any string literal. So
/// `table` does not match `my_table`, nor text inside another entry's value.
/// The value is decoded as a Rust string literal: `\"`, `\\`, `\n`, `\t`,
/// `\r`, `\0`, `\x..` and `\u{..}` are resolved.
///
/// Returns `None` when the key is absent, or when its literal is unterminated
/// or contains a malformed escape.
fn extract_attr_value(tokens: &str, key: &str) -> Option<String> {
    let pattern = format!("{} = \"", key);
    let mut in_string = false;
    let mut escaped = false;
    let mut prev: Option<char> = None;

    for (idx, c) in tokens.char_indices() {
        if in_string {
            // Skip over literal contents so a key spelled inside a value never matches.
            if escaped {
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '"' {
                in_string = false;
            }
        } else if c == '"' {
            in_string = true;
        } else if tokens[idx..].starts_with(pattern.as_str())
            && !matches!(prev, Some(p) if p.is_alphanumeric() || p == '_')
        {
            return decode_string_literal_body(&tokens[idx + pattern.len()..]);
        }
        prev = Some(c);
    }
    None
}

/// Decode the body of a Rust string literal, i.e. the text right after its
/// opening `"`, up to the closing quote, resolving escape sequences.
///
/// Returns `None` if the closing quote is missing or an escape is malformed.
fn decode_string_literal_body(body: &str) -> Option<String> {
    let mut out = String::new();
    let mut chars = body.chars();
    while let Some(c) = chars.next() {
        match c {
            '"' => return Some(out),
            '\\' => match chars.next()? {
                'n' => out.push('\n'),
                'r' => out.push('\r'),
                't' => out.push('\t'),
                '0' => out.push('\0'),
                'x' => {
                    let hex: String = chars.by_ref().take(2).collect();
                    out.push(char::from(u8::from_str_radix(&hex, 16).ok()?));
                }
                'u' => {
                    if chars.next()? != '{' {
                        return None;
                    }
                    let hex: String = chars
                        .by_ref()
                        .take_while(|&h| h != '}')
                        .filter(|&h| h != '_')
                        .collect();
                    out.push(char::from_u32(u32::from_str_radix(&hex, 16).ok()?)?);
                }
                // Line continuation: drop the newline and the next line's indentation.
                '\n' => chars = chars.as_str().trim_start().chars(),
                // `\\`, `\"` and `\'` stand for the escaped character itself.
                other => out.push(other),
            },
            other => out.push(other),
        }
    }
    None
}
193
194/// Extract a ParsedField from a syn::Field.
195fn extract_field(field: &syn::Field) -> Result<ParsedField, String> {
196    let rust_name = field.ident.as_ref().ok_or("Unnamed field")?.to_string();
197
198    let column_name = get_sqlx_rename(&field.attrs).unwrap_or_else(|| rust_name.clone());
199    let (is_primary_key, sql_type, is_sql_array, column_default) =
200        parse_sqlx_gen_field_attrs(&field.attrs);
201
202    let rust_type = field.ty.to_token_stream().to_string();
203    let (is_nullable, inner_type) = extract_option_type(&field.ty);
204    let inner_type = if is_nullable {
205        inner_type
206    } else {
207        rust_type.clone()
208    };
209
210    Ok(ParsedField {
211        rust_name,
212        column_name,
213        rust_type,
214        is_nullable,
215        inner_type,
216        is_primary_key,
217        sql_type,
218        is_sql_array,
219        column_default,
220    })
221}
222
223/// Parse `#[sqlx_gen(...)]` attributes on a field.
224/// Returns (is_primary_key, sql_type, is_sql_array, column_default).
225fn parse_sqlx_gen_field_attrs(
226    attrs: &[syn::Attribute],
227) -> (bool, Option<String>, bool, Option<String>) {
228    let mut is_pk = false;
229    let mut sql_type = None;
230    let mut is_array = false;
231    let mut column_default = None;
232
233    for attr in attrs {
234        if attr.path().is_ident("sqlx_gen") {
235            let tokens = attr.meta.to_token_stream().to_string();
236            if tokens.contains("primary_key") {
237                is_pk = true;
238            }
239            if let Some(t) = extract_attr_value(&tokens, "sql_type") {
240                sql_type = Some(t);
241            }
242            if tokens.contains("is_array") {
243                is_array = true;
244            }
245            if let Some(d) = extract_attr_value(&tokens, "column_default") {
246                column_default = Some(d);
247            }
248        }
249    }
250
251    (is_pk, sql_type, is_array, column_default)
252}
253
254/// Extract `#[sqlx(rename = "...")]` value from field attributes.
255fn get_sqlx_rename(attrs: &[syn::Attribute]) -> Option<String> {
256    for attr in attrs {
257        if attr.path().is_ident("sqlx") {
258            let tokens = attr.meta.to_token_stream().to_string();
259            return extract_attr_value(&tokens, "rename");
260        }
261    }
262    None
263}
264
265/// Check if a type is `Option<T>` and extract the inner type.
266fn extract_option_type(ty: &syn::Type) -> (bool, String) {
267    if let syn::Type::Path(type_path) = ty {
268        if let Some(segment) = type_path.path.segments.last() {
269            if segment.ident == "Option" {
270                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
271                    if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
272                        return (true, inner.to_token_stream().to_string());
273                    }
274                }
275            }
276        }
277    }
278    (false, String::new())
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `source`, failing the test if no entity can be extracted.
    fn parse_ok(source: &str) -> ParsedEntity {
        parse_entity_source(source).expect("expected source to yield an entity")
    }

    /// Build a `FromRow` struct `name` tagged as table `table`, with `body` as its field list.
    fn table_src(name: &str, table: &str, body: &str) -> String {
        format!(
            "#[derive(Debug, Clone, sqlx::FromRow)]\n\
             #[sqlx_gen(kind = \"table\", table = \"{}\")]\n\
             pub struct {} {{\n{}\n}}\n",
            table, name, body
        )
    }

    /// Prepend `imports` to an entity source.
    fn with_imports(imports: &str, entity_src: String) -> String {
        format!("{}\n{}", imports, entity_src)
    }

    // --- basic parsing ---

    #[test]
    fn test_parse_simple_table() {
        let entity = parse_ok(&table_src("Users", "users", "pub id: i32, pub name: String,"));
        assert_eq!(entity.struct_name, "Users");
        assert_eq!(entity.table_name, "users");
        assert!(!entity.is_view);
        assert_eq!(entity.fields.len(), 2);
    }

    #[test]
    fn test_parse_view() {
        let entity = parse_ok(
            r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "view", table = "active_users")]
            pub struct ActiveUsers { pub id: i32 }
            "#,
        );
        assert!(entity.is_view);
        assert_eq!(entity.table_name, "active_users");
    }

    #[test]
    fn test_parse_table_not_view() {
        let entity = parse_ok(&table_src("Users", "users", "pub id: i32,"));
        assert!(!entity.is_view);
    }

    // --- primary key ---

    #[test]
    fn test_parse_primary_key() {
        let entity = parse_ok(&table_src(
            "Users",
            "users",
            "#[sqlx_gen(primary_key)] pub id: i32, pub name: String,",
        ));
        assert!(entity.fields[0].is_primary_key);
        assert!(!entity.fields[1].is_primary_key);
    }

    #[test]
    fn test_composite_primary_key() {
        let entity = parse_ok(&table_src(
            "UserRoles",
            "user_roles",
            concat!(
                "#[sqlx_gen(primary_key)] pub user_id: i32,",
                "#[sqlx_gen(primary_key)] pub role_id: i32,"
            ),
        ));
        assert!(entity.fields[0].is_primary_key && entity.fields[1].is_primary_key);
    }

    #[test]
    fn test_no_primary_key() {
        let entity = parse_ok(&table_src("Logs", "logs", "pub message: String,"));
        assert!(!entity.fields[0].is_primary_key);
    }

    // --- sqlx rename ---

    #[test]
    fn test_sqlx_rename() {
        let entity = parse_ok(&table_src(
            "Connector",
            "connector",
            r#"#[sqlx(rename = "type")] pub connector_type: String,"#,
        ));
        let field = &entity.fields[0];
        assert_eq!(field.rust_name, "connector_type");
        assert_eq!(field.column_name, "type");
    }

    #[test]
    fn test_no_rename_uses_field_name() {
        let entity = parse_ok(&table_src("Users", "users", "pub name: String,"));
        let field = &entity.fields[0];
        assert_eq!(field.rust_name, "name");
        assert_eq!(field.column_name, "name");
    }

    // --- nullable types ---

    #[test]
    fn test_option_field_nullable() {
        let entity = parse_ok(&table_src("Users", "users", "pub email: Option<String>,"));
        assert!(entity.fields[0].is_nullable);
        assert_eq!(entity.fields[0].inner_type, "String");
    }

    #[test]
    fn test_non_option_not_nullable() {
        let entity = parse_ok(&table_src("Users", "users", "pub id: i32,"));
        assert!(!entity.fields[0].is_nullable);
        assert_eq!(entity.fields[0].inner_type, "i32");
    }

    #[test]
    fn test_option_complex_type() {
        let entity = parse_ok(&table_src(
            "Users",
            "users",
            "pub created_at: Option<chrono::NaiveDateTime>,",
        ));
        assert!(entity.fields[0].is_nullable);
        assert!(entity.fields[0].inner_type.contains("NaiveDateTime"));
    }

    // --- type preservation ---

    #[test]
    fn test_rust_type_preserved() {
        let entity = parse_ok(&table_src("Users", "users", "pub id: uuid::Uuid,"));
        assert!(entity.fields[0].rust_type.contains("Uuid"));
    }

    // --- error cases ---

    #[test]
    fn test_no_from_row_struct() {
        assert!(parse_entity_source("pub struct NotAnEntity { pub id: i32 }").is_err());
    }

    #[test]
    fn test_empty_source() {
        assert!(parse_entity_source("").is_err());
    }

    // --- fallback table name ---

    #[test]
    fn test_fallback_table_name_to_struct_name() {
        let entity =
            parse_ok("#[derive(Debug, Clone, sqlx::FromRow)] pub struct Users { pub id: i32 }");
        assert_eq!(entity.table_name, "Users");
    }

    // --- combined attributes ---

    #[test]
    fn test_pk_with_rename() {
        let entity = parse_ok(&table_src(
            "Items",
            "items",
            r#"#[sqlx_gen(primary_key)] #[sqlx(rename = "itemID")] pub item_id: i32,"#,
        ));
        let field = &entity.fields[0];
        assert!(field.is_primary_key);
        assert_eq!(field.column_name, "itemID");
        assert_eq!(field.rust_name, "item_id");
    }

    #[test]
    fn test_full_entity() {
        let entity = parse_ok(
            r#"
            #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                pub name: String,
                pub email: Option<String>,
                #[sqlx(rename = "createdAt")]
                pub created_at: chrono::NaiveDateTime,
            }
            "#,
        );
        assert_eq!(entity.struct_name, "Users");
        assert_eq!(entity.table_name, "users");
        assert!(!entity.is_view);
        assert_eq!(entity.fields.len(), 4);

        let (id, name, email, created_at) = (
            &entity.fields[0],
            &entity.fields[1],
            &entity.fields[2],
            &entity.fields[3],
        );
        assert!(id.is_primary_key);
        assert_eq!(id.rust_name, "id");
        assert!(!name.is_primary_key);
        assert_eq!(name.rust_type, "String");
        assert!(email.is_nullable);
        assert_eq!(email.inner_type, "String");
        assert_eq!(created_at.column_name, "createdAt");
        assert_eq!(created_at.rust_name, "created_at");
    }

    // --- imports extraction ---

    #[test]
    fn test_imports_extracted() {
        let source = with_imports(
            "use chrono::{DateTime, Utc};\nuse uuid::Uuid;\nuse serde::{Serialize, Deserialize};",
            table_src("Users", "users", "pub id: Uuid, pub created_at: DateTime<Utc>,"),
        );
        let entity = parse_ok(&source);
        assert_eq!(entity.imports.len(), 2);
        assert!(entity.imports.iter().any(|line| line.contains("chrono")));
        assert!(entity.imports.iter().any(|line| line.contains("uuid")));
        // serde is supplied by the CRUD generator, so it must not be carried over
        assert!(entity.imports.iter().all(|line| !line.contains("serde")));
    }

    #[test]
    fn test_imports_empty_when_none() {
        let entity = parse_ok(&table_src("Users", "users", "pub id: i32,"));
        assert!(entity.imports.is_empty());
    }

    #[test]
    fn test_imports_keep_serde_json() {
        let source = with_imports(
            "use serde::{Serialize, Deserialize};\nuse serde_json::Value;",
            table_src("Users", "users", "pub data: Value,"),
        );
        let entity = parse_ok(&source);
        assert_eq!(entity.imports.len(), 1);
        assert!(entity.imports[0].contains("serde_json"));
    }

    #[test]
    fn test_imports_exclude_sqlx() {
        let source = with_imports(
            "use sqlx::types::Uuid;\nuse chrono::NaiveDateTime;",
            table_src("Users", "users", "pub id: i32,"),
        );
        let entity = parse_ok(&source);
        assert_eq!(entity.imports.len(), 1);
        assert!(entity.imports[0].contains("chrono"));
    }

    // --- column_default parsing ---

    #[test]
    fn test_parse_column_default() {
        let entity = parse_ok(&table_src(
            "Tasks",
            "tasks",
            r#"#[sqlx_gen(primary_key)] pub id: i32,
               #[sqlx_gen(column_default = "now()")] pub created_at: String,"#,
        ));
        assert_eq!(entity.fields[1].column_default, Some("now()".to_string()));
    }

    #[test]
    fn test_parse_no_column_default() {
        let entity = parse_ok(&table_src(
            "Tasks",
            "tasks",
            "#[sqlx_gen(primary_key)] pub id: i32, pub title: String,",
        ));
        assert_eq!(entity.fields[1].column_default, None);
    }

    #[test]
    fn test_parse_column_default_with_cast() {
        let entity = parse_ok(&table_src(
            "Tasks",
            "tasks",
            r#"#[sqlx_gen(primary_key)] pub id: i32,
               #[sqlx_gen(column_default = "'idle'::task_status")] pub status: String,"#,
        ));
        assert_eq!(
            entity.fields[1].column_default,
            Some("'idle'::task_status".to_string())
        );
    }

    #[test]
    fn test_parse_column_default_with_json_quotes() {
        let entity = parse_ok(&table_src(
            "Configs",
            "configs",
            r#"#[sqlx_gen(primary_key)] pub id: i32,
               #[sqlx_gen(column_default = "'{\"1\": \"\", \"2\": \"\"}'::jsonb")]
               pub template_variables: String,"#,
        ));
        assert_eq!(
            entity.fields[1].column_default,
            Some(r#"'{"1": "", "2": ""}'::jsonb"#.to_string())
        );
    }
}