Skip to main content

sqlx_gen/codegen/
entity_parser.rs

1use std::path::Path;
2
3use quote::ToTokens;
4
/// One field of a generated entity struct, as recovered by the entity parser.
///
/// Type strings are produced by rendering the field's type tokens, so their
/// spacing follows the token printer rather than the original source text.
#[derive(Debug, Clone)]
pub struct ParsedField {
    /// Identifier of the Rust field, e.g. `connector_type`.
    pub rust_name: String,
    /// Database column backing the field: the value of `#[sqlx(rename = "...")]`
    /// when present, otherwise identical to `rust_name`.
    pub column_name: String,
    /// The field's complete type rendered as a string (e.g. for `i32`, `uuid::Uuid`,
    /// or an `Option<String>`).
    pub rust_type: String,
    /// `true` when the field's type is an `Option<T>`.
    pub is_nullable: bool,
    /// For nullable fields the `T` inside `Option<T>`; otherwise the same text as `rust_type`.
    pub inner_type: String,
    /// `true` when the field carries `#[sqlx_gen(primary_key)]`.
    pub is_primary_key: bool,
}
21
22/// Represents an entity parsed from a generated Rust file.
23#[derive(Debug, Clone)]
24pub struct ParsedEntity {
25    /// Struct name in PascalCase (e.g. "Users", "UserRoles")
26    pub struct_name: String,
27    /// Original table/view name from `#[sqlx_gen(table = "...")]`
28    pub table_name: String,
29    /// Whether this entity represents a view (`#[sqlx_gen(kind = "view")]`)
30    pub is_view: bool,
31    /// Parsed fields
32    pub fields: Vec<ParsedField>,
33    /// `use` imports from the entity source file (e.g. "use chrono::{DateTime, Utc};")
34    pub imports: Vec<String>,
35}
36
37/// Parse an entity struct from a `.rs` file on disk.
38pub fn parse_entity_file(path: &Path) -> crate::error::Result<ParsedEntity> {
39    let source = std::fs::read_to_string(path).map_err(crate::error::Error::Io)?;
40    parse_entity_source(&source).map_err(|e| {
41        crate::error::Error::Config(format!("{}: {}", path.display(), e))
42    })
43}
44
45/// Parse an entity struct from a Rust source string.
46pub fn parse_entity_source(source: &str) -> Result<ParsedEntity, String> {
47    let syntax = syn::parse_file(source).map_err(|e| format!("Failed to parse: {}", e))?;
48
49    // Collect use imports (excluding serde and sqlx derives)
50    let imports = extract_use_imports(&syntax);
51
52    for item in &syntax.items {
53        if let syn::Item::Struct(item_struct) = item {
54            if has_from_row_derive(item_struct) {
55                let mut entity = extract_entity(item_struct)?;
56                entity.imports = imports;
57                return Ok(entity);
58            }
59        }
60    }
61
62    Err("No struct with sqlx::FromRow derive found".to_string())
63}
64
65/// Check if a struct has `sqlx::FromRow` in its derive attributes.
66fn has_from_row_derive(item: &syn::ItemStruct) -> bool {
67    for attr in &item.attrs {
68        if attr.path().is_ident("derive") {
69            let tokens = attr.meta.to_token_stream().to_string();
70            if tokens.contains("FromRow") {
71                return true;
72            }
73        }
74    }
75    false
76}
77
78/// Extract `use` imports from the source file, excluding serde/sqlx imports
79/// (those are already handled by the CRUD generator).
80fn extract_use_imports(file: &syn::File) -> Vec<String> {
81    file.items
82        .iter()
83        .filter_map(|item| {
84            if let syn::Item::Use(use_item) = item {
85                let text = use_item.to_token_stream().to_string();
86                // Skip serde and sqlx imports — the CRUD generator adds those itself
87                if text.contains("serde") || text.contains("sqlx") {
88                    return None;
89                }
90                // Normalize spacing: "use chrono :: { DateTime , Utc } ;" → cleaned up
91                let normalized = normalize_use_statement(&text);
92                Some(normalized)
93            } else {
94                None
95            }
96        })
97        .collect()
98}
99
/// Collapses the spacing that token-stream rendering inserts into a `use` item.
///
/// For example, `use chrono :: { DateTime , Utc } ;` becomes
/// `use chrono::{DateTime, Utc};`. The substitutions run in a fixed order, and
/// each one is applied to the output of the previous one.
fn normalize_use_statement(s: &str) -> String {
    const SUBSTITUTIONS: [(&str, &str); 7] = [
        (" :: ", "::"),
        (":: ", "::"),
        (" ::", "::"),
        ("{ ", "{"),
        (" }", "}"),
        (" ,", ","),
        (" ;", ";"),
    ];

    SUBSTITUTIONS
        .iter()
        .fold(s.to_string(), |text, &(from, to)| text.replace(from, to))
}
110
111/// Extract a ParsedEntity from a struct item.
112fn extract_entity(item: &syn::ItemStruct) -> Result<ParsedEntity, String> {
113    let struct_name = item.ident.to_string();
114
115    let (kind, table_name) = parse_sqlx_gen_struct_attrs(&item.attrs);
116    let is_view = kind.as_deref() == Some("view");
117
118    // Fall back to struct name if no table annotation
119    let table_name = table_name.unwrap_or_else(|| struct_name.clone());
120
121    let fields = match &item.fields {
122        syn::Fields::Named(named) => {
123            named
124                .named
125                .iter()
126                .map(extract_field)
127                .collect::<Result<Vec<_>, _>>()?
128        }
129        _ => return Err("Expected named fields".to_string()),
130    };
131
132    Ok(ParsedEntity {
133        struct_name,
134        table_name,
135        is_view,
136        fields,
137        imports: Vec::new(), // filled by parse_entity_source
138    })
139}
140
141/// Parse `#[sqlx_gen(kind = "...", table = "...")]` from struct attributes.
142/// Returns (kind, table_name).
143fn parse_sqlx_gen_struct_attrs(attrs: &[syn::Attribute]) -> (Option<String>, Option<String>) {
144    let mut kind = None;
145    let mut table_name = None;
146
147    for attr in attrs {
148        if attr.path().is_ident("sqlx_gen") {
149            let tokens = attr.meta.to_token_stream().to_string();
150            if let Some(k) = extract_attr_value(&tokens, "kind") {
151                kind = Some(k);
152            }
153            if let Some(t) = extract_attr_value(&tokens, "table") {
154                table_name = Some(t);
155            }
156        }
157    }
158
159    (kind, table_name)
160}
161
/// Extracts the string value assigned to `key` in a rendered attribute token string.
///
/// For example, with the tokens `sqlx_gen(kind = "view", table = "users")` and the
/// key `"kind"`, this returns `Some("view")`.
///
/// The key must match a whole identifier. An occurrence preceded by an identifier
/// character is skipped and the search continues; for example, `subkind = "..."`
/// is skipped when looking up `kind`. Escaped quotes inside the value are not
/// handled: the value ends at the first `"` after the opening quote.
///
/// Returns `None` if no `key = "..."` pair is present or the value is unterminated.
fn extract_attr_value(tokens: &str, key: &str) -> Option<String> {
    let pattern = format!("{} = \"", key);
    let mut search_from = 0;

    while let Some(offset) = tokens[search_from..].find(&pattern) {
        let key_start = search_from + offset;
        let value_start = key_start + pattern.len();

        // Reject matches that are the tail of a longer identifier (e.g. `subkind`).
        let is_whole_key = tokens[..key_start]
            .chars()
            .next_back()
            .map_or(true, |c| !(c.is_alphanumeric() || c == '_'));

        if is_whole_key {
            let value = &tokens[value_start..];
            let end = value.find('"')?;
            return Some(value[..end].to_string());
        }
        search_from = value_start;
    }
    None
}
171
172/// Extract a ParsedField from a syn::Field.
173fn extract_field(field: &syn::Field) -> Result<ParsedField, String> {
174    let rust_name = field
175        .ident
176        .as_ref()
177        .ok_or("Unnamed field")?
178        .to_string();
179
180    let column_name = get_sqlx_rename(&field.attrs).unwrap_or_else(|| rust_name.clone());
181    let is_primary_key = has_sqlx_gen_primary_key(&field.attrs);
182
183    let rust_type = field.ty.to_token_stream().to_string();
184    let (is_nullable, inner_type) = extract_option_type(&field.ty);
185    let inner_type = if is_nullable {
186        inner_type
187    } else {
188        rust_type.clone()
189    };
190
191    Ok(ParsedField {
192        rust_name,
193        column_name,
194        rust_type,
195        is_nullable,
196        inner_type,
197        is_primary_key,
198    })
199}
200
201/// Check for `#[sqlx_gen(primary_key)]` on a field.
202fn has_sqlx_gen_primary_key(attrs: &[syn::Attribute]) -> bool {
203    for attr in attrs {
204        if attr.path().is_ident("sqlx_gen") {
205            let tokens = attr.meta.to_token_stream().to_string();
206            if tokens.contains("primary_key") {
207                return true;
208            }
209        }
210    }
211    false
212}
213
214/// Extract `#[sqlx(rename = "...")]` value from field attributes.
215fn get_sqlx_rename(attrs: &[syn::Attribute]) -> Option<String> {
216    for attr in attrs {
217        if attr.path().is_ident("sqlx") {
218            let tokens = attr.meta.to_token_stream().to_string();
219            return extract_attr_value(&tokens, "rename");
220        }
221    }
222    None
223}
224
225/// Check if a type is `Option<T>` and extract the inner type.
226fn extract_option_type(ty: &syn::Type) -> (bool, String) {
227    if let syn::Type::Path(type_path) = ty {
228        if let Some(segment) = type_path.path.segments.last() {
229            if segment.ident == "Option" {
230                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
231                    if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
232                        return (true, inner.to_token_stream().to_string());
233                    }
234                }
235            }
236        }
237    }
238    (false, String::new())
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source`, panicking with the parser's message if it is rejected.
    fn parse_ok(source: &str) -> ParsedEntity {
        parse_entity_source(source)
            .unwrap_or_else(|err| panic!("unexpected parse failure: {}", err))
    }

    // Struct-level attributes: kind, table name, field count.

    #[test]
    fn test_parse_simple_table() {
        let entity = parse_ok(
            r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
                pub name: String,
            }
            "#,
        );
        assert_eq!(entity.struct_name, "Users");
        assert_eq!(entity.table_name, "users");
        assert!(!entity.is_view);
        assert_eq!(entity.fields.len(), 2);
    }

    #[test]
    fn test_parse_view() {
        let entity = parse_ok(
            r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "view", table = "active_users")]
            pub struct ActiveUsers {
                pub id: i32,
            }
            "#,
        );
        assert!(entity.is_view);
        assert_eq!(entity.table_name, "active_users");
    }

    #[test]
    fn test_parse_table_not_view() {
        let entity = parse_ok(
            r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
            "#,
        );
        assert!(!entity.is_view);
    }

    // Primary-key flags on fields.

    #[test]
    fn test_parse_primary_key() {
        let entity = parse_ok(
            r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                pub name: String,
            }
            "#,
        );
        let pk_flags: Vec<bool> = entity.fields.iter().map(|f| f.is_primary_key).collect();
        assert_eq!(pk_flags, [true, false]);
    }

    #[test]
    fn test_composite_primary_key() {
        let entity = parse_ok(
            r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "user_roles")]
            pub struct UserRoles {
                #[sqlx_gen(primary_key)]
                pub user_id: i32,
                #[sqlx_gen(primary_key)]
                pub role_id: i32,
            }
            "#,
        );
        assert!(entity.fields.iter().take(2).all(|f| f.is_primary_key));
    }

    #[test]
    fn test_no_primary_key() {
        let entity = parse_ok(
            r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "logs")]
            pub struct Logs {
                pub message: String,
            }
            "#,
        );
        assert!(!entity.fields[0].is_primary_key);
    }

    // Column naming via #[sqlx(rename = "...")].

    #[test]
    fn test_sqlx_rename() {
        let entity = parse_ok(
            r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "connector")]
            pub struct Connector {
                #[sqlx(rename = "type")]
                pub connector_type: String,
            }
            "#,
        );
        let field = &entity.fields[0];
        assert_eq!(
            (field.rust_name.as_str(), field.column_name.as_str()),
            ("connector_type", "type")
        );
    }

    #[test]
    fn test_no_rename_uses_field_name() {
        let entity = parse_ok(
            r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub name: String,
            }
            "#,
        );
        let field = &entity.fields[0];
        assert_eq!(field.rust_name, "name");
        assert_eq!(field.column_name, field.rust_name);
    }

    // Option<T> detection.

    #[test]
    fn test_option_field_nullable() {
        let entity = parse_ok(
            r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub email: Option<String>,
            }
            "#,
        );
        let field = &entity.fields[0];
        assert!(field.is_nullable);
        assert_eq!(field.inner_type, "String");
    }

    #[test]
    fn test_non_option_not_nullable() {
        let entity = parse_ok(
            r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
            "#,
        );
        let field = &entity.fields[0];
        assert!(!field.is_nullable);
        assert_eq!(field.inner_type, "i32");
    }

    #[test]
    fn test_option_complex_type() {
        let entity = parse_ok(
            r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub created_at: Option<chrono::NaiveDateTime>,
            }
            "#,
        );
        let field = &entity.fields[0];
        assert!(field.is_nullable);
        assert!(field.inner_type.contains("NaiveDateTime"));
    }

    // Type text is carried through.

    #[test]
    fn test_rust_type_preserved() {
        let entity = parse_ok(
            r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: uuid::Uuid,
            }
            "#,
        );
        assert!(entity.fields[0].rust_type.contains("Uuid"));
    }

    // Rejected inputs.

    #[test]
    fn test_no_from_row_struct() {
        let outcome = parse_entity_source(
            r#"
            pub struct NotAnEntity {
                pub id: i32,
            }
            "#,
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn test_empty_source() {
        assert!(parse_entity_source("").is_err());
    }

    // Table name fallback.

    #[test]
    fn test_fallback_table_name_to_struct_name() {
        let entity = parse_ok(
            r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            pub struct Users {
                pub id: i32,
            }
            "#,
        );
        assert_eq!(entity.table_name, entity.struct_name);
        assert_eq!(entity.table_name, "Users");
    }

    // Several attributes combined.

    #[test]
    fn test_pk_with_rename() {
        let entity = parse_ok(
            r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "items")]
            pub struct Items {
                #[sqlx_gen(primary_key)]
                #[sqlx(rename = "itemID")]
                pub item_id: i32,
            }
            "#,
        );
        let field = &entity.fields[0];
        assert!(field.is_primary_key);
        assert_eq!(field.column_name, "itemID");
        assert_eq!(field.rust_name, "item_id");
    }

    #[test]
    fn test_full_entity() {
        let entity = parse_ok(
            r#"
            #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                pub name: String,
                pub email: Option<String>,
                #[sqlx(rename = "createdAt")]
                pub created_at: chrono::NaiveDateTime,
            }
            "#,
        );
        assert_eq!(entity.struct_name, "Users");
        assert_eq!(entity.table_name, "users");
        assert!(!entity.is_view);

        let [id, name, email, created_at] = match entity.fields.as_slice() {
            [a, b, c, d] => [a, b, c, d],
            other => panic!("expected 4 fields, got {}", other.len()),
        };

        assert!(id.is_primary_key);
        assert_eq!(id.rust_name, "id");

        assert!(!name.is_primary_key);
        assert_eq!(name.rust_type, "String");

        assert!(email.is_nullable);
        assert_eq!(email.inner_type, "String");

        assert_eq!(created_at.column_name, "createdAt");
        assert_eq!(created_at.rust_name, "created_at");
    }

    // `use` import collection.

    #[test]
    fn test_imports_extracted() {
        let entity = parse_ok(
            r#"
            use chrono::{DateTime, Utc};
            use uuid::Uuid;
            use serde::{Serialize, Deserialize};

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: Uuid,
                pub created_at: DateTime<Utc>,
            }
            "#,
        );
        let mentions = |needle: &str| entity.imports.iter().any(|line| line.contains(needle));
        assert_eq!(entity.imports.len(), 2);
        assert!(mentions("chrono"));
        assert!(mentions("uuid"));
        // serde imports are provided by the generator and must be dropped.
        assert!(!mentions("serde"));
    }

    #[test]
    fn test_imports_empty_when_none() {
        let entity = parse_ok(
            r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
            "#,
        );
        assert!(entity.imports.is_empty());
    }

    #[test]
    fn test_imports_exclude_sqlx() {
        let entity = parse_ok(
            r#"
            use sqlx::types::Uuid;
            use chrono::NaiveDateTime;

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
            "#,
        );
        match entity.imports.as_slice() {
            [only] => assert!(only.contains("chrono")),
            other => panic!("expected exactly one import, got {:?}", other),
        }
    }
}
564        let entity = parse_entity_source(source).unwrap();
565        assert_eq!(entity.imports.len(), 1);
566        assert!(entity.imports[0].contains("chrono"));
567    }
568}