Skip to main content

sqlx_gen/codegen/
entity_parser.rs

1use std::path::Path;
2
3use quote::ToTokens;
4
/// One column of an entity, as recovered from a field of a generated struct.
#[derive(Clone, Debug)]
pub struct ParsedField {
    /// Identifier of the field in the Rust struct (e.g. "connector_type").
    pub rust_name: String,
    /// Database column name: the value of `#[sqlx(rename = "...")]` when the
    /// field carries one, otherwise identical to `rust_name`.
    pub column_name: String,
    /// The field's complete Rust type rendered as text (e.g. "i32", "String").
    /// The text is produced from the type's token stream, so multi-token types
    /// may be rendered with spaces between tokens.
    pub rust_type: String,
    /// `true` when the type is `Option<T>`.
    pub is_nullable: bool,
    /// For nullable fields the `T` of `Option<T>`; otherwise the same as `rust_type`.
    pub inner_type: String,
    /// `true` when the field is marked `#[sqlx_gen(primary_key)]`.
    pub is_primary_key: bool,
    /// SQL type name taken from `#[sqlx_gen(sql_type = "...")]`, used for custom
    /// types that need a cast (e.g. "agent.connector_usages").
    pub sql_type: Option<String>,
    /// `true` when the SQL type is an array (the cast needs a `[]` suffix).
    pub is_sql_array: bool,
}
25
26/// Represents an entity parsed from a generated Rust file.
27#[derive(Debug, Clone)]
28pub struct ParsedEntity {
29    /// Struct name in PascalCase (e.g. "Users", "UserRoles")
30    pub struct_name: String,
31    /// Original table/view name from `#[sqlx_gen(table = "...")]`
32    pub table_name: String,
33    /// Schema name from `#[sqlx_gen(schema = "...")]`
34    pub schema_name: Option<String>,
35    /// Whether this entity represents a view (`#[sqlx_gen(kind = "view")]`)
36    pub is_view: bool,
37    /// Parsed fields
38    pub fields: Vec<ParsedField>,
39    /// `use` imports from the entity source file (e.g. "use chrono::{DateTime, Utc};")
40    pub imports: Vec<String>,
41}
42
43/// Parse an entity struct from a `.rs` file on disk.
44pub fn parse_entity_file(path: &Path) -> crate::error::Result<ParsedEntity> {
45    let source = std::fs::read_to_string(path).map_err(crate::error::Error::Io)?;
46    parse_entity_source(&source).map_err(|e| {
47        crate::error::Error::Config(format!("{}: {}", path.display(), e))
48    })
49}
50
51/// Parse an entity struct from a Rust source string.
52pub fn parse_entity_source(source: &str) -> Result<ParsedEntity, String> {
53    let syntax = syn::parse_file(source).map_err(|e| format!("Failed to parse: {}", e))?;
54
55    // Collect use imports (excluding serde and sqlx derives)
56    let imports = extract_use_imports(&syntax);
57
58    for item in &syntax.items {
59        if let syn::Item::Struct(item_struct) = item {
60            if has_from_row_derive(item_struct) {
61                let mut entity = extract_entity(item_struct)?;
62                entity.imports = imports;
63                return Ok(entity);
64            }
65        }
66    }
67
68    Err("No struct with sqlx::FromRow derive found".to_string())
69}
70
71/// Check if a struct has `sqlx::FromRow` in its derive attributes.
72fn has_from_row_derive(item: &syn::ItemStruct) -> bool {
73    for attr in &item.attrs {
74        if attr.path().is_ident("derive") {
75            let tokens = attr.meta.to_token_stream().to_string();
76            if tokens.contains("FromRow") {
77                return true;
78            }
79        }
80    }
81    false
82}
83
84/// Extract `use` imports from the source file, excluding serde/sqlx imports
85/// (those are already handled by the CRUD generator).
86fn extract_use_imports(file: &syn::File) -> Vec<String> {
87    file.items
88        .iter()
89        .filter_map(|item| {
90            if let syn::Item::Use(use_item) = item {
91                let text = use_item.to_token_stream().to_string();
92                // Skip serde and sqlx imports — the CRUD generator adds those itself
93                if (text.contains("serde") && !text.contains("serde_")) || text.contains("sqlx") {
94                    return None;
95                }
96                // Normalize spacing: "use chrono :: { DateTime , Utc } ;" → cleaned up
97                let normalized = normalize_use_statement(&text);
98                Some(normalized)
99            } else {
100                None
101            }
102        })
103        .collect()
104}
105
/// Normalize a tokenized `use` statement by removing the extra spaces that
/// token-stream rendering puts around `::`, `{`, `}`, `,` and `;`.
///
/// For example `use chrono :: { DateTime , Utc } ;` becomes
/// `use chrono::{DateTime, Utc};`.
///
/// A path with a leading `::` (`use :: std :: io ;`) keeps the space between
/// the `use` keyword and the path, yielding `use ::std::io;` rather than
/// `use::std::io;`.
fn normalize_use_statement(s: &str) -> String {
    let compact = s
        .replace(" :: ", "::")
        .replace(":: ", "::")
        .replace(" ::", "::")
        .replace("{ ", "{")
        .replace(" }", "}")
        .replace(" ,", ",")
        .replace(" ;", ";");

    // The collapse above glues a leading `::` onto the keyword (`use::std`).
    // Restore the space, but only where `use` is the keyword itself (start of
    // the text or preceded by whitespace), not the tail of a path segment such
    // as `mouse::Button`.
    let keyword_pos = compact
        .match_indices("use::")
        .map(|(pos, _)| pos)
        .find(|&pos| pos == 0 || compact[..pos].ends_with(char::is_whitespace));

    match keyword_pos {
        Some(pos) => format!(
            "{}use ::{}",
            &compact[..pos],
            &compact[pos + "use::".len()..]
        ),
        None => compact,
    }
}
116
117/// Extract a ParsedEntity from a struct item.
118fn extract_entity(item: &syn::ItemStruct) -> Result<ParsedEntity, String> {
119    let struct_name = item.ident.to_string();
120
121    let (kind, schema_name, table_name) = parse_sqlx_gen_struct_attrs(&item.attrs);
122    let is_view = kind.as_deref() == Some("view");
123
124    // Fall back to struct name if no table annotation
125    let table_name = table_name.unwrap_or_else(|| struct_name.clone());
126
127    let fields = match &item.fields {
128        syn::Fields::Named(named) => {
129            named
130                .named
131                .iter()
132                .map(extract_field)
133                .collect::<Result<Vec<_>, _>>()?
134        }
135        _ => return Err("Expected named fields".to_string()),
136    };
137
138    Ok(ParsedEntity {
139        struct_name,
140        table_name,
141        schema_name,
142        is_view,
143        fields,
144        imports: Vec::new(), // filled by parse_entity_source
145    })
146}
147
148/// Parse `#[sqlx_gen(kind = "...", schema = "...", table = "...")]` from struct attributes.
149/// Returns (kind, schema_name, table_name).
150fn parse_sqlx_gen_struct_attrs(attrs: &[syn::Attribute]) -> (Option<String>, Option<String>, Option<String>) {
151    let mut kind = None;
152    let mut schema_name = None;
153    let mut table_name = None;
154
155    for attr in attrs {
156        if attr.path().is_ident("sqlx_gen") {
157            let tokens = attr.meta.to_token_stream().to_string();
158            if let Some(k) = extract_attr_value(&tokens, "kind") {
159                kind = Some(k);
160            }
161            if let Some(s) = extract_attr_value(&tokens, "schema") {
162                schema_name = Some(s);
163            }
164            if let Some(t) = extract_attr_value(&tokens, "table") {
165                table_name = Some(t);
166            }
167        }
168    }
169
170    (kind, schema_name, table_name)
171}
172
/// Extract a named string value from an attribute token string.
///
/// e.g. `extract_attr_value(r#"sqlx_gen(kind = "view", table = "users")"#, "kind")`
/// returns `Some("view")`.
///
/// The key must appear as a whole identifier followed by ` = "`: looking up
/// `table` does not match inside `sql_table = "..."`. Returns `None` when the
/// key is absent or its value has no closing quote. Escape sequences inside
/// the value are not interpreted; the value ends at the first `"`.
fn extract_attr_value(tokens: &str, key: &str) -> Option<String> {
    let pattern = format!("{} = \"", key);
    let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';

    let mut search_from = 0;
    while let Some(offset) = tokens[search_from..].find(&pattern) {
        let key_start = search_from + offset;
        let value_start = key_start + pattern.len();

        // Reject matches where the key is only the tail of a longer identifier.
        let preceded_by_ident = tokens[..key_start]
            .chars()
            .next_back()
            .map_or(false, is_ident_char);

        if key.is_empty() || !preceded_by_ident {
            let rest = &tokens[value_start..];
            return rest.find('"').map(|end| rest[..end].to_string());
        }

        // Keep looking after this false match.
        search_from = value_start;
    }

    None
}
182
183/// Extract a ParsedField from a syn::Field.
184fn extract_field(field: &syn::Field) -> Result<ParsedField, String> {
185    let rust_name = field
186        .ident
187        .as_ref()
188        .ok_or("Unnamed field")?
189        .to_string();
190
191    let column_name = get_sqlx_rename(&field.attrs).unwrap_or_else(|| rust_name.clone());
192    let (is_primary_key, sql_type, is_sql_array) = parse_sqlx_gen_field_attrs(&field.attrs);
193
194    let rust_type = field.ty.to_token_stream().to_string();
195    let (is_nullable, inner_type) = extract_option_type(&field.ty);
196    let inner_type = if is_nullable {
197        inner_type
198    } else {
199        rust_type.clone()
200    };
201
202    Ok(ParsedField {
203        rust_name,
204        column_name,
205        rust_type,
206        is_nullable,
207        inner_type,
208        is_primary_key,
209        sql_type,
210        is_sql_array,
211    })
212}
213
214/// Parse `#[sqlx_gen(...)]` attributes on a field.
215/// Returns (is_primary_key, sql_type, is_sql_array).
216fn parse_sqlx_gen_field_attrs(attrs: &[syn::Attribute]) -> (bool, Option<String>, bool) {
217    let mut is_pk = false;
218    let mut sql_type = None;
219    let mut is_array = false;
220
221    for attr in attrs {
222        if attr.path().is_ident("sqlx_gen") {
223            let tokens = attr.meta.to_token_stream().to_string();
224            if tokens.contains("primary_key") {
225                is_pk = true;
226            }
227            if let Some(t) = extract_attr_value(&tokens, "sql_type") {
228                sql_type = Some(t);
229            }
230            if tokens.contains("is_array") {
231                is_array = true;
232            }
233        }
234    }
235
236    (is_pk, sql_type, is_array)
237}
238
239/// Extract `#[sqlx(rename = "...")]` value from field attributes.
240fn get_sqlx_rename(attrs: &[syn::Attribute]) -> Option<String> {
241    for attr in attrs {
242        if attr.path().is_ident("sqlx") {
243            let tokens = attr.meta.to_token_stream().to_string();
244            return extract_attr_value(&tokens, "rename");
245        }
246    }
247    None
248}
249
250/// Check if a type is `Option<T>` and extract the inner type.
251fn extract_option_type(ty: &syn::Type) -> (bool, String) {
252    if let syn::Type::Path(type_path) = ty {
253        if let Some(segment) = type_path.path.segments.last() {
254            if segment.ident == "Option" {
255                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
256                    if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
257                        return (true, inner.to_token_stream().to_string());
258                    }
259                }
260            }
261        }
262    }
263    (false, String::new())
264}
265
/// Unit tests for the entity parser. Each test feeds an inline Rust snippet to
/// `parse_entity_source` and checks the resulting `ParsedEntity`.
#[cfg(test)]
mod tests {
    use super::*;

    // --- basic parsing ---

    #[test]
    fn test_parse_simple_table() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
                pub name: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.struct_name, "Users");
        assert_eq!(entity.table_name, "users");
        assert!(!entity.is_view);
        assert_eq!(entity.fields.len(), 2);
    }

    #[test]
    fn test_parse_view() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "view", table = "active_users")]
            pub struct ActiveUsers {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.is_view);
        assert_eq!(entity.table_name, "active_users");
    }

    #[test]
    fn test_parse_table_not_view() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.is_view);
    }

    // --- schema ---

    #[test]
    fn test_schema_name_parsed() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", schema = "agent", table = "connector")]
            pub struct Connector {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.schema_name.as_deref(), Some("agent"));
        assert_eq!(entity.table_name, "connector");
    }

    #[test]
    fn test_schema_name_absent() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.schema_name.is_none());
    }

    // --- primary key ---

    #[test]
    fn test_parse_primary_key() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                pub name: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_primary_key);
        assert!(!entity.fields[1].is_primary_key);
    }

    #[test]
    fn test_composite_primary_key() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "user_roles")]
            pub struct UserRoles {
                #[sqlx_gen(primary_key)]
                pub user_id: i32,
                #[sqlx_gen(primary_key)]
                pub role_id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_primary_key);
        assert!(entity.fields[1].is_primary_key);
    }

    #[test]
    fn test_no_primary_key() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "logs")]
            pub struct Logs {
                pub message: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.fields[0].is_primary_key);
    }

    // --- sql_type / is_array ---

    #[test]
    fn test_sql_type_and_array() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "connector")]
            pub struct Connector {
                #[sqlx_gen(sql_type = "agent.connector_usages", is_array)]
                pub usages: Vec<String>,
                pub name: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();

        let usages = &entity.fields[0];
        assert_eq!(usages.sql_type.as_deref(), Some("agent.connector_usages"));
        assert!(usages.is_sql_array);
        assert!(!usages.is_primary_key);

        let name = &entity.fields[1];
        assert!(name.sql_type.is_none());
        assert!(!name.is_sql_array);
    }

    // --- sqlx rename ---

    #[test]
    fn test_sqlx_rename() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "connector")]
            pub struct Connector {
                #[sqlx(rename = "type")]
                pub connector_type: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.fields[0].rust_name, "connector_type");
        assert_eq!(entity.fields[0].column_name, "type");
    }

    #[test]
    fn test_no_rename_uses_field_name() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub name: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.fields[0].rust_name, "name");
        assert_eq!(entity.fields[0].column_name, "name");
    }

    // --- nullable types ---

    #[test]
    fn test_option_field_nullable() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub email: Option<String>,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_nullable);
        assert_eq!(entity.fields[0].inner_type, "String");
    }

    #[test]
    fn test_non_option_not_nullable() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.fields[0].is_nullable);
        assert_eq!(entity.fields[0].inner_type, "i32");
    }

    #[test]
    fn test_option_complex_type() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub created_at: Option<chrono::NaiveDateTime>,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_nullable);
        assert!(entity.fields[0].inner_type.contains("NaiveDateTime"));
    }

    // --- type preservation ---

    #[test]
    fn test_rust_type_preserved() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: uuid::Uuid,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].rust_type.contains("Uuid"));
    }

    // --- error cases ---

    #[test]
    fn test_no_from_row_struct() {
        let source = r#"
            pub struct NotAnEntity {
                pub id: i32,
            }
        "#;
        let result = parse_entity_source(source);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_source() {
        let result = parse_entity_source("");
        assert!(result.is_err());
    }

    #[test]
    fn test_tuple_struct_rejected() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            pub struct Pair(i32, i32);
        "#;
        let err = parse_entity_source(source).unwrap_err();
        assert_eq!(err, "Expected named fields");
    }

    #[test]
    fn test_invalid_syntax_reports_parse_failure() {
        let err = parse_entity_source("pub struct {").unwrap_err();
        assert!(err.starts_with("Failed to parse: "));
    }

    // --- fallback table name ---

    #[test]
    fn test_fallback_table_name_to_struct_name() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.table_name, "Users");
    }

    // --- combined attributes ---

    #[test]
    fn test_pk_with_rename() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "items")]
            pub struct Items {
                #[sqlx_gen(primary_key)]
                #[sqlx(rename = "itemID")]
                pub item_id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        let f = &entity.fields[0];
        assert!(f.is_primary_key);
        assert_eq!(f.column_name, "itemID");
        assert_eq!(f.rust_name, "item_id");
    }

    #[test]
    fn test_full_entity() {
        let source = r#"
            #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                pub name: String,
                pub email: Option<String>,
                #[sqlx(rename = "createdAt")]
                pub created_at: chrono::NaiveDateTime,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.struct_name, "Users");
        assert_eq!(entity.table_name, "users");
        assert!(!entity.is_view);
        assert_eq!(entity.fields.len(), 4);

        assert!(entity.fields[0].is_primary_key);
        assert_eq!(entity.fields[0].rust_name, "id");

        assert!(!entity.fields[1].is_primary_key);
        assert_eq!(entity.fields[1].rust_type, "String");

        assert!(entity.fields[2].is_nullable);
        assert_eq!(entity.fields[2].inner_type, "String");

        assert_eq!(entity.fields[3].column_name, "createdAt");
        assert_eq!(entity.fields[3].rust_name, "created_at");
    }

    // --- imports extraction ---

    #[test]
    fn test_imports_extracted() {
        let source = r#"
            use chrono::{DateTime, Utc};
            use uuid::Uuid;
            use serde::{Serialize, Deserialize};

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: Uuid,
                pub created_at: DateTime<Utc>,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 2);
        assert!(entity.imports.iter().any(|i| i.contains("chrono")));
        assert!(entity.imports.iter().any(|i| i.contains("uuid")));
        // serde should be excluded
        assert!(!entity.imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_imports_empty_when_none() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.imports.is_empty());
    }

    #[test]
    fn test_imports_keep_serde_json() {
        let source = r#"
            use serde::{Serialize, Deserialize};
            use serde_json::Value;

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub data: Value,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 1);
        assert!(entity.imports[0].contains("serde_json"));
    }

    #[test]
    fn test_imports_exclude_sqlx() {
        let source = r#"
            use sqlx::types::Uuid;
            use chrono::NaiveDateTime;

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 1);
        assert!(entity.imports[0].contains("chrono"));
    }
}
610}