Skip to main content

sqlx_gen/codegen/
entity_parser.rs

1use std::path::Path;
2
3use quote::ToTokens;
4
/// Represents a field parsed from a generated entity struct.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedField {
    /// Rust field name (e.g. "connector_type")
    pub rust_name: String,
    /// Original DB column name. From `#[sqlx(rename = "...")]` if present, otherwise same as rust_name.
    pub column_name: String,
    /// Full Rust type rendered from its token stream. Tokens are separated by spaces, so
    /// `Option<String>` is stored as "Option < String >" and `uuid::Uuid` as "uuid :: Uuid";
    /// a single-token type such as "i32" is stored unchanged.
    pub rust_type: String,
    /// Whether the type is wrapped in `Option<T>` (the last path segment is `Option`)
    pub is_nullable: bool,
    /// The inner `T` if nullable, or the full type (same as `rust_type`) if not.
    /// Uses the same space-separated token rendering as `rust_type`.
    pub inner_type: String,
    /// Whether this field is a primary key (`#[sqlx_gen(primary_key)]`)
    pub is_primary_key: bool,
}

/// Represents an entity parsed from a generated Rust file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedEntity {
    /// Struct name as written in the source (e.g. "Users", "UserRoles")
    pub struct_name: String,
    /// Table/view name from `#[sqlx_gen(table = "...")]`; falls back to `struct_name`
    /// when no `table` key is present.
    pub table_name: String,
    /// Schema name from `#[sqlx_gen(schema = "...")]`, if present
    pub schema_name: Option<String>,
    /// Whether this entity represents a view (`#[sqlx_gen(kind = "view")]`)
    pub is_view: bool,
    /// Parsed fields, in declaration order
    pub fields: Vec<ParsedField>,
    /// Normalized `use` imports from the entity source file (e.g. "use chrono::{DateTime, Utc};"),
    /// excluding serde/sqlx imports, which the CRUD generator adds itself.
    pub imports: Vec<String>,
}
38
39/// Parse an entity struct from a `.rs` file on disk.
40pub fn parse_entity_file(path: &Path) -> crate::error::Result<ParsedEntity> {
41    let source = std::fs::read_to_string(path).map_err(crate::error::Error::Io)?;
42    parse_entity_source(&source).map_err(|e| {
43        crate::error::Error::Config(format!("{}: {}", path.display(), e))
44    })
45}
46
47/// Parse an entity struct from a Rust source string.
48pub fn parse_entity_source(source: &str) -> Result<ParsedEntity, String> {
49    let syntax = syn::parse_file(source).map_err(|e| format!("Failed to parse: {}", e))?;
50
51    // Collect use imports (excluding serde and sqlx derives)
52    let imports = extract_use_imports(&syntax);
53
54    for item in &syntax.items {
55        if let syn::Item::Struct(item_struct) = item {
56            if has_from_row_derive(item_struct) {
57                let mut entity = extract_entity(item_struct)?;
58                entity.imports = imports;
59                return Ok(entity);
60            }
61        }
62    }
63
64    Err("No struct with sqlx::FromRow derive found".to_string())
65}
66
67/// Check if a struct has `sqlx::FromRow` in its derive attributes.
68fn has_from_row_derive(item: &syn::ItemStruct) -> bool {
69    for attr in &item.attrs {
70        if attr.path().is_ident("derive") {
71            let tokens = attr.meta.to_token_stream().to_string();
72            if tokens.contains("FromRow") {
73                return true;
74            }
75        }
76    }
77    false
78}
79
80/// Extract `use` imports from the source file, excluding serde/sqlx imports
81/// (those are already handled by the CRUD generator).
82fn extract_use_imports(file: &syn::File) -> Vec<String> {
83    file.items
84        .iter()
85        .filter_map(|item| {
86            if let syn::Item::Use(use_item) = item {
87                let text = use_item.to_token_stream().to_string();
88                // Skip serde and sqlx imports — the CRUD generator adds those itself
89                if (text.contains("serde") && !text.contains("serde_")) || text.contains("sqlx") {
90                    return None;
91                }
92                // Normalize spacing: "use chrono :: { DateTime , Utc } ;" → cleaned up
93                let normalized = normalize_use_statement(&text);
94                Some(normalized)
95            } else {
96                None
97            }
98        })
99        .collect()
100}
101
/// Collapse the spacing produced by tokenizing a `use` statement.
///
/// Rewrites are applied in order to remove the spaces around `::`, after `{`, before `}`,
/// before `,`, and before `;`. For example, "use chrono :: { DateTime , Utc } ;" becomes
/// "use chrono::{DateTime, Utc};".
fn normalize_use_statement(s: &str) -> String {
    const REWRITES: [(&str, &str); 7] = [
        (" :: ", "::"),
        (":: ", "::"),
        (" ::", "::"),
        ("{ ", "{"),
        (" }", "}"),
        (" ,", ","),
        (" ;", ";"),
    ];

    REWRITES
        .iter()
        .fold(s.to_string(), |text, &(from, to)| text.replace(from, to))
}
112
113/// Extract a ParsedEntity from a struct item.
114fn extract_entity(item: &syn::ItemStruct) -> Result<ParsedEntity, String> {
115    let struct_name = item.ident.to_string();
116
117    let (kind, schema_name, table_name) = parse_sqlx_gen_struct_attrs(&item.attrs);
118    let is_view = kind.as_deref() == Some("view");
119
120    // Fall back to struct name if no table annotation
121    let table_name = table_name.unwrap_or_else(|| struct_name.clone());
122
123    let fields = match &item.fields {
124        syn::Fields::Named(named) => {
125            named
126                .named
127                .iter()
128                .map(extract_field)
129                .collect::<Result<Vec<_>, _>>()?
130        }
131        _ => return Err("Expected named fields".to_string()),
132    };
133
134    Ok(ParsedEntity {
135        struct_name,
136        table_name,
137        schema_name,
138        is_view,
139        fields,
140        imports: Vec::new(), // filled by parse_entity_source
141    })
142}
143
144/// Parse `#[sqlx_gen(kind = "...", schema = "...", table = "...")]` from struct attributes.
145/// Returns (kind, schema_name, table_name).
146fn parse_sqlx_gen_struct_attrs(attrs: &[syn::Attribute]) -> (Option<String>, Option<String>, Option<String>) {
147    let mut kind = None;
148    let mut schema_name = None;
149    let mut table_name = None;
150
151    for attr in attrs {
152        if attr.path().is_ident("sqlx_gen") {
153            let tokens = attr.meta.to_token_stream().to_string();
154            if let Some(k) = extract_attr_value(&tokens, "kind") {
155                kind = Some(k);
156            }
157            if let Some(s) = extract_attr_value(&tokens, "schema") {
158                schema_name = Some(s);
159            }
160            if let Some(t) = extract_attr_value(&tokens, "table") {
161                table_name = Some(t);
162            }
163        }
164    }
165
166    (kind, schema_name, table_name)
167}
168
/// Extract the string value bound to `key` in a stringified attribute.
///
/// `tokens` is the space-separated token rendering of an attribute, e.g.
/// `sqlx_gen (kind = "view" , table = "users")`. Looking up `"kind"` in that string
/// returns `Some("view")`.
///
/// The key must appear as a whole identifier, so looking up `table` ignores
/// `ref_table = "..."`. Returns `None` when the key is absent or its value has no closing
/// quote. Escape sequences inside the literal are not decoded.
fn extract_attr_value(tokens: &str, key: &str) -> Option<String> {
    let pattern = format!("{} = \"", key);
    tokens
        .match_indices(pattern.as_str())
        .find_map(|(start, _)| {
            // Reject matches that are the tail of a longer identifier (e.g. `ref_table`).
            let preceded_by_ident_char = matches!(
                tokens[..start].chars().next_back(),
                Some(c) if c.is_alphanumeric() || c == '_'
            );
            if preceded_by_ident_char {
                return None;
            }
            let value = &tokens[start + pattern.len()..];
            value.find('"').map(|end| value[..end].to_string())
        })
}
178
179/// Extract a ParsedField from a syn::Field.
180fn extract_field(field: &syn::Field) -> Result<ParsedField, String> {
181    let rust_name = field
182        .ident
183        .as_ref()
184        .ok_or("Unnamed field")?
185        .to_string();
186
187    let column_name = get_sqlx_rename(&field.attrs).unwrap_or_else(|| rust_name.clone());
188    let is_primary_key = has_sqlx_gen_primary_key(&field.attrs);
189
190    let rust_type = field.ty.to_token_stream().to_string();
191    let (is_nullable, inner_type) = extract_option_type(&field.ty);
192    let inner_type = if is_nullable {
193        inner_type
194    } else {
195        rust_type.clone()
196    };
197
198    Ok(ParsedField {
199        rust_name,
200        column_name,
201        rust_type,
202        is_nullable,
203        inner_type,
204        is_primary_key,
205    })
206}
207
208/// Check for `#[sqlx_gen(primary_key)]` on a field.
209fn has_sqlx_gen_primary_key(attrs: &[syn::Attribute]) -> bool {
210    for attr in attrs {
211        if attr.path().is_ident("sqlx_gen") {
212            let tokens = attr.meta.to_token_stream().to_string();
213            if tokens.contains("primary_key") {
214                return true;
215            }
216        }
217    }
218    false
219}
220
221/// Extract `#[sqlx(rename = "...")]` value from field attributes.
222fn get_sqlx_rename(attrs: &[syn::Attribute]) -> Option<String> {
223    for attr in attrs {
224        if attr.path().is_ident("sqlx") {
225            let tokens = attr.meta.to_token_stream().to_string();
226            return extract_attr_value(&tokens, "rename");
227        }
228    }
229    None
230}
231
232/// Check if a type is `Option<T>` and extract the inner type.
233fn extract_option_type(ty: &syn::Type) -> (bool, String) {
234    if let syn::Type::Path(type_path) = ty {
235        if let Some(segment) = type_path.path.segments.last() {
236            if segment.ident == "Option" {
237                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
238                    if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
239                        return (true, inner.to_token_stream().to_string());
240                    }
241                }
242            }
243        }
244    }
245    (false, String::new())
246}
247
#[cfg(test)]
mod tests {
    // Unit tests for the entity parser. Each test passes an inline Rust source string
    // to `parse_entity_source` and checks the resulting `ParsedEntity`.
    use super::*;

    // --- basic parsing ---

    #[test]
    fn test_parse_simple_table() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
                pub name: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.struct_name, "Users");
        assert_eq!(entity.table_name, "users");
        assert!(!entity.is_view);
        assert_eq!(entity.fields.len(), 2);
    }

    #[test]
    fn test_parse_view() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "view", table = "active_users")]
            pub struct ActiveUsers {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.is_view);
        assert_eq!(entity.table_name, "active_users");
    }

    #[test]
    fn test_parse_table_not_view() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.is_view);
    }

    // --- primary key ---

    #[test]
    fn test_parse_primary_key() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                pub name: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_primary_key);
        assert!(!entity.fields[1].is_primary_key);
    }

    // Each field of a composite key carries its own `#[sqlx_gen(primary_key)]`.
    #[test]
    fn test_composite_primary_key() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "user_roles")]
            pub struct UserRoles {
                #[sqlx_gen(primary_key)]
                pub user_id: i32,
                #[sqlx_gen(primary_key)]
                pub role_id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_primary_key);
        assert!(entity.fields[1].is_primary_key);
    }

    #[test]
    fn test_no_primary_key() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "logs")]
            pub struct Logs {
                pub message: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.fields[0].is_primary_key);
    }

    // --- sqlx rename ---

    // `#[sqlx(rename = "...")]` sets the DB column name; the Rust field name is unchanged.
    #[test]
    fn test_sqlx_rename() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "connector")]
            pub struct Connector {
                #[sqlx(rename = "type")]
                pub connector_type: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.fields[0].rust_name, "connector_type");
        assert_eq!(entity.fields[0].column_name, "type");
    }

    #[test]
    fn test_no_rename_uses_field_name() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub name: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.fields[0].rust_name, "name");
        assert_eq!(entity.fields[0].column_name, "name");
    }

    // --- nullable types ---

    #[test]
    fn test_option_field_nullable() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub email: Option<String>,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_nullable);
        assert_eq!(entity.fields[0].inner_type, "String");
    }

    // For non-Option fields, `inner_type` equals the full type string.
    #[test]
    fn test_non_option_not_nullable() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.fields[0].is_nullable);
        assert_eq!(entity.fields[0].inner_type, "i32");
    }

    // Uses `contains` rather than equality, presumably because the tokenized type string
    // puts spaces around `::` (e.g. "chrono :: NaiveDateTime").
    #[test]
    fn test_option_complex_type() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub created_at: Option<chrono::NaiveDateTime>,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_nullable);
        assert!(entity.fields[0].inner_type.contains("NaiveDateTime"));
    }

    // --- type preservation ---

    #[test]
    fn test_rust_type_preserved() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: uuid::Uuid,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].rust_type.contains("Uuid"));
    }

    // --- error cases ---

    // A struct without a `FromRow` derive is skipped, so nothing is found.
    #[test]
    fn test_no_from_row_struct() {
        let source = r#"
            pub struct NotAnEntity {
                pub id: i32,
            }
        "#;
        let result = parse_entity_source(source);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_source() {
        let result = parse_entity_source("");
        assert!(result.is_err());
    }

    // --- fallback table name ---

    // Without `#[sqlx_gen(table = ...)]`, the table name is the struct identifier as-is
    // (no case conversion).
    #[test]
    fn test_fallback_table_name_to_struct_name() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.table_name, "Users");
    }

    // --- combined attributes ---

    #[test]
    fn test_pk_with_rename() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "items")]
            pub struct Items {
                #[sqlx_gen(primary_key)]
                #[sqlx(rename = "itemID")]
                pub item_id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        let f = &entity.fields[0];
        assert!(f.is_primary_key);
        assert_eq!(f.column_name, "itemID");
        assert_eq!(f.rust_name, "item_id");
    }

    // End-to-end check combining derives, a primary key, a nullable field and a rename.
    #[test]
    fn test_full_entity() {
        let source = r#"
            #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                pub name: String,
                pub email: Option<String>,
                #[sqlx(rename = "createdAt")]
                pub created_at: chrono::NaiveDateTime,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.struct_name, "Users");
        assert_eq!(entity.table_name, "users");
        assert!(!entity.is_view);
        assert_eq!(entity.fields.len(), 4);

        assert!(entity.fields[0].is_primary_key);
        assert_eq!(entity.fields[0].rust_name, "id");

        assert!(!entity.fields[1].is_primary_key);
        assert_eq!(entity.fields[1].rust_type, "String");

        assert!(entity.fields[2].is_nullable);
        assert_eq!(entity.fields[2].inner_type, "String");

        assert_eq!(entity.fields[3].column_name, "createdAt");
        assert_eq!(entity.fields[3].rust_name, "created_at");
    }

    // --- imports extraction ---

    #[test]
    fn test_imports_extracted() {
        let source = r#"
            use chrono::{DateTime, Utc};
            use uuid::Uuid;
            use serde::{Serialize, Deserialize};

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: Uuid,
                pub created_at: DateTime<Utc>,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 2);
        assert!(entity.imports.iter().any(|i| i.contains("chrono")));
        assert!(entity.imports.iter().any(|i| i.contains("uuid")));
        // serde should be excluded
        assert!(!entity.imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_imports_empty_when_none() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.imports.is_empty());
    }

    // `serde_json` is a different crate from `serde` and must not be filtered out.
    #[test]
    fn test_imports_keep_serde_json() {
        let source = r#"
            use serde::{Serialize, Deserialize};
            use serde_json::Value;

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub data: Value,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 1);
        assert!(entity.imports[0].contains("serde_json"));
    }

    #[test]
    fn test_imports_exclude_sqlx() {
        let source = r#"
            use sqlx::types::Uuid;
            use chrono::NaiveDateTime;

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 1);
        assert!(entity.imports[0].contains("chrono"));
    }
}