1use std::path::Path;
2
3use quote::ToTokens;
4
/// A single named field of an entity struct, as recovered from its source text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedField {
    /// The field's Rust identifier.
    pub rust_name: String,
    /// The database column name: the `#[sqlx(rename = "...")]` value when present,
    /// otherwise the same as `rust_name`.
    pub column_name: String,
    /// The field's full type, rendered from its token stream (tokens may be
    /// space-separated, e.g. `Option < String >`).
    pub rust_type: String,
    /// Whether the type's last path segment is `Option` with a type argument.
    pub is_nullable: bool,
    /// For `Option<T>` fields the rendered `T`; otherwise identical to `rust_type`.
    pub inner_type: String,
    /// Set when a `#[sqlx_gen(...)]` attribute on the field mentions `primary_key`.
    pub is_primary_key: bool,
    /// The `#[sqlx_gen(sql_type = "...")]` value, if any.
    pub sql_type: Option<String>,
    /// Set when a `#[sqlx_gen(...)]` attribute on the field mentions `is_array`.
    pub is_sql_array: bool,
    /// The `#[sqlx_gen(column_default = "...")]` value, if any.
    pub column_default: Option<String>,
}

/// An entity struct (table or view) recovered from a generated source file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedEntity {
    /// The struct's Rust identifier.
    pub struct_name: String,
    /// The `#[sqlx_gen(table = "...")]` value, falling back to `struct_name`.
    pub table_name: String,
    /// The `#[sqlx_gen(schema = "...")]` value, if any.
    pub schema_name: Option<String>,
    /// True when the struct is annotated `#[sqlx_gen(kind = "view")]`.
    pub is_view: bool,
    /// The struct's named fields, in declaration order.
    pub fields: Vec<ParsedField>,
    /// Normalized top-level `use` statements from the source file, with serde
    /// and sqlx imports excluded.
    pub imports: Vec<String>,
}
44
45pub fn parse_entity_file(path: &Path) -> crate::error::Result<ParsedEntity> {
47 let source = std::fs::read_to_string(path).map_err(crate::error::Error::Io)?;
48 parse_entity_source(&source).map_err(|e| {
49 crate::error::Error::Config(format!("{}: {}", path.display(), e))
50 })
51}
52
53pub fn parse_entity_source(source: &str) -> Result<ParsedEntity, String> {
55 let syntax = syn::parse_file(source).map_err(|e| format!("Failed to parse: {}", e))?;
56
57 let imports = extract_use_imports(&syntax);
59
60 for item in &syntax.items {
61 if let syn::Item::Struct(item_struct) = item {
62 if has_from_row_derive(item_struct) {
63 let mut entity = extract_entity(item_struct)?;
64 entity.imports = imports;
65 return Ok(entity);
66 }
67 }
68 }
69
70 Err("No struct with sqlx::FromRow derive found".to_string())
71}
72
73fn has_from_row_derive(item: &syn::ItemStruct) -> bool {
75 for attr in &item.attrs {
76 if attr.path().is_ident("derive") {
77 let tokens = attr.meta.to_token_stream().to_string();
78 if tokens.contains("FromRow") {
79 return true;
80 }
81 }
82 }
83 false
84}
85
86fn extract_use_imports(file: &syn::File) -> Vec<String> {
89 file.items
90 .iter()
91 .filter_map(|item| {
92 if let syn::Item::Use(use_item) = item {
93 let text = use_item.to_token_stream().to_string();
94 if (text.contains("serde") && !text.contains("serde_")) || text.contains("sqlx") {
96 return None;
97 }
98 let normalized = normalize_use_statement(&text);
100 Some(normalized)
101 } else {
102 None
103 }
104 })
105 .collect()
106}
107
/// Tidies the space-separated text produced by rendering a `use` item's
/// token stream, e.g. `use a :: { B , C } ;` becomes `use a::{B, C};`.
///
/// The replacements are applied in order; each one runs over the output of
/// the previous.
fn normalize_use_statement(s: &str) -> String {
    const REPLACEMENTS: [(&str, &str); 7] = [
        (" :: ", "::"),
        (":: ", "::"),
        (" ::", "::"),
        ("{ ", "{"),
        (" }", "}"),
        (" ,", ","),
        (" ;", ";"),
    ];
    REPLACEMENTS
        .iter()
        .fold(s.to_string(), |text, &(from, to)| text.replace(from, to))
}
118
119fn extract_entity(item: &syn::ItemStruct) -> Result<ParsedEntity, String> {
121 let struct_name = item.ident.to_string();
122
123 let (kind, schema_name, table_name) = parse_sqlx_gen_struct_attrs(&item.attrs);
124 let is_view = kind.as_deref() == Some("view");
125
126 let table_name = table_name.unwrap_or_else(|| struct_name.clone());
128
129 let fields = match &item.fields {
130 syn::Fields::Named(named) => {
131 named
132 .named
133 .iter()
134 .map(extract_field)
135 .collect::<Result<Vec<_>, _>>()?
136 }
137 _ => return Err("Expected named fields".to_string()),
138 };
139
140 Ok(ParsedEntity {
141 struct_name,
142 table_name,
143 schema_name,
144 is_view,
145 fields,
146 imports: Vec::new(), })
148}
149
150fn parse_sqlx_gen_struct_attrs(attrs: &[syn::Attribute]) -> (Option<String>, Option<String>, Option<String>) {
153 let mut kind = None;
154 let mut schema_name = None;
155 let mut table_name = None;
156
157 for attr in attrs {
158 if attr.path().is_ident("sqlx_gen") {
159 let tokens = attr.meta.to_token_stream().to_string();
160 if let Some(k) = extract_attr_value(&tokens, "kind") {
161 kind = Some(k);
162 }
163 if let Some(s) = extract_attr_value(&tokens, "schema") {
164 schema_name = Some(s);
165 }
166 if let Some(t) = extract_attr_value(&tokens, "table") {
167 table_name = Some(t);
168 }
169 }
170 }
171
172 (kind, schema_name, table_name)
173}
174
/// Extracts the string value assigned to `key` in a stringified attribute,
/// e.g. `users` from `sqlx_gen (kind = "table" , table = "users")`.
///
/// Matching rules:
/// - `key` only matches as a whole identifier, so `table` does not match
///   `my_table = "..."`.
/// - Text inside other keys' string literals is never treated as a key.
///
/// Standard Rust string escapes (`\"`, `\\`, `\n`, `\r`, `\t`, `\0`, `\'`,
/// `\u{..}`) are decoded. This lets values containing double quotes, such as
/// `nextval('"Users_id_seq"'::regclass)`, round-trip intact.
///
/// Returns `None` if the key is absent or its literal is malformed.
fn extract_attr_value(tokens: &str, key: &str) -> Option<String> {
    let pattern = format!("{} = \"", key);
    let bytes = tokens.as_bytes();
    // Non-ASCII bytes are treated as identifier characters (conservative boundary).
    let is_ident_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_' || b >= 0x80;

    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'"' {
            // Skip another key's string literal wholesale so its contents are
            // never mistaken for `key = "`.
            let (_, consumed) = decode_string_literal(&tokens[i + 1..])?;
            i += 1 + consumed;
            continue;
        }
        let at_ident_start = i == 0 || !is_ident_byte(bytes[i - 1]);
        if at_ident_start && bytes[i..].starts_with(pattern.as_bytes()) {
            // `pattern` ends with an ASCII quote, so this slice is on a char boundary.
            return decode_string_literal(&tokens[i + pattern.len()..]).map(|(value, _)| value);
        }
        i += 1;
    }
    None
}

/// Decodes the body of a Rust string literal whose opening `"` has already
/// been consumed.
///
/// Returns the unescaped text together with the number of bytes read,
/// including the closing quote. Returns `None` if the literal is
/// unterminated or contains a malformed `\u{..}` escape. Unrecognised
/// escapes are kept verbatim.
fn decode_string_literal(body: &str) -> Option<(String, usize)> {
    let mut value = String::new();
    let mut chars = body.char_indices();
    while let Some((idx, ch)) = chars.next() {
        match ch {
            '"' => return Some((value, idx + 1)),
            '\\' => {
                let (_, escaped) = chars.next()?;
                match escaped {
                    'n' => value.push('\n'),
                    'r' => value.push('\r'),
                    't' => value.push('\t'),
                    '0' => value.push('\0'),
                    '\\' | '"' | '\'' => value.push(escaped),
                    'u' => {
                        if chars.next()?.1 != '{' {
                            return None;
                        }
                        let mut hex = String::new();
                        loop {
                            match chars.next()?.1 {
                                '}' => break,
                                digit => hex.push(digit),
                            }
                        }
                        // Rust allows `_` separators inside `\u{..}` escapes.
                        let code = u32::from_str_radix(&hex.replace('_', ""), 16).ok()?;
                        value.push(char::from_u32(code)?);
                    }
                    other => {
                        value.push('\\');
                        value.push(other);
                    }
                }
            }
            _ => value.push(ch),
        }
    }
    None
}
184
185fn extract_field(field: &syn::Field) -> Result<ParsedField, String> {
187 let rust_name = field
188 .ident
189 .as_ref()
190 .ok_or("Unnamed field")?
191 .to_string();
192
193 let column_name = get_sqlx_rename(&field.attrs).unwrap_or_else(|| rust_name.clone());
194 let (is_primary_key, sql_type, is_sql_array, column_default) = parse_sqlx_gen_field_attrs(&field.attrs);
195
196 let rust_type = field.ty.to_token_stream().to_string();
197 let (is_nullable, inner_type) = extract_option_type(&field.ty);
198 let inner_type = if is_nullable {
199 inner_type
200 } else {
201 rust_type.clone()
202 };
203
204 Ok(ParsedField {
205 rust_name,
206 column_name,
207 rust_type,
208 is_nullable,
209 inner_type,
210 is_primary_key,
211 sql_type,
212 is_sql_array,
213 column_default,
214 })
215}
216
217fn parse_sqlx_gen_field_attrs(attrs: &[syn::Attribute]) -> (bool, Option<String>, bool, Option<String>) {
220 let mut is_pk = false;
221 let mut sql_type = None;
222 let mut is_array = false;
223 let mut column_default = None;
224
225 for attr in attrs {
226 if attr.path().is_ident("sqlx_gen") {
227 let tokens = attr.meta.to_token_stream().to_string();
228 if tokens.contains("primary_key") {
229 is_pk = true;
230 }
231 if let Some(t) = extract_attr_value(&tokens, "sql_type") {
232 sql_type = Some(t);
233 }
234 if tokens.contains("is_array") {
235 is_array = true;
236 }
237 if let Some(d) = extract_attr_value(&tokens, "column_default") {
238 column_default = Some(d);
239 }
240 }
241 }
242
243 (is_pk, sql_type, is_array, column_default)
244}
245
246fn get_sqlx_rename(attrs: &[syn::Attribute]) -> Option<String> {
248 for attr in attrs {
249 if attr.path().is_ident("sqlx") {
250 let tokens = attr.meta.to_token_stream().to_string();
251 return extract_attr_value(&tokens, "rename");
252 }
253 }
254 None
255}
256
257fn extract_option_type(ty: &syn::Type) -> (bool, String) {
259 if let syn::Type::Path(type_path) = ty {
260 if let Some(segment) = type_path.path.segments.last() {
261 if segment.ident == "Option" {
262 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
263 if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
264 return (true, inner.to_token_stream().to_string());
265 }
266 }
267 }
268 }
269 }
270 (false, String::new())
271}
272
// Unit tests for `parse_entity_source`, driven by small inline entity fixtures.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Entity-level parsing: struct name, table name, table vs. view ----

    #[test]
    fn test_parse_simple_table() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "users")]
        pub struct Users {
            pub id: i32,
            pub name: String,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.struct_name, "Users");
        assert_eq!(entity.table_name, "users");
        assert!(!entity.is_view);
        assert_eq!(entity.fields.len(), 2);
    }

    #[test]
    fn test_parse_view() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "view", table = "active_users")]
        pub struct ActiveUsers {
            pub id: i32,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.is_view);
        assert_eq!(entity.table_name, "active_users");
    }

    #[test]
    fn test_parse_table_not_view() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "users")]
        pub struct Users {
            pub id: i32,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.is_view);
    }

    // ---- Primary keys via field-level `#[sqlx_gen(primary_key)]` ----

    #[test]
    fn test_parse_primary_key() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "users")]
        pub struct Users {
            #[sqlx_gen(primary_key)]
            pub id: i32,
            pub name: String,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_primary_key);
        assert!(!entity.fields[1].is_primary_key);
    }

    #[test]
    fn test_composite_primary_key() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "user_roles")]
        pub struct UserRoles {
            #[sqlx_gen(primary_key)]
            pub user_id: i32,
            #[sqlx_gen(primary_key)]
            pub role_id: i32,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_primary_key);
        assert!(entity.fields[1].is_primary_key);
    }

    #[test]
    fn test_no_primary_key() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "logs")]
        pub struct Logs {
            pub message: String,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.fields[0].is_primary_key);
    }

    // ---- Column names: `#[sqlx(rename = "...")]` vs. the field identifier ----

    #[test]
    fn test_sqlx_rename() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "connector")]
        pub struct Connector {
            #[sqlx(rename = "type")]
            pub connector_type: String,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.fields[0].rust_name, "connector_type");
        assert_eq!(entity.fields[0].column_name, "type");
    }

    #[test]
    fn test_no_rename_uses_field_name() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "users")]
        pub struct Users {
            pub name: String,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.fields[0].rust_name, "name");
        assert_eq!(entity.fields[0].column_name, "name");
    }

    // ---- Types: `Option<T>` nullability, inner type, and rendered Rust type ----

    #[test]
    fn test_option_field_nullable() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "users")]
        pub struct Users {
            pub email: Option<String>,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_nullable);
        assert_eq!(entity.fields[0].inner_type, "String");
    }

    #[test]
    fn test_non_option_not_nullable() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "users")]
        pub struct Users {
            pub id: i32,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.fields[0].is_nullable);
        assert_eq!(entity.fields[0].inner_type, "i32");
    }

    // Path types are rendered with token spacing, so only a substring is checked.
    #[test]
    fn test_option_complex_type() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "users")]
        pub struct Users {
            pub created_at: Option<chrono::NaiveDateTime>,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_nullable);
        assert!(entity.fields[0].inner_type.contains("NaiveDateTime"));
    }

    #[test]
    fn test_rust_type_preserved() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "users")]
        pub struct Users {
            pub id: uuid::Uuid,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].rust_type.contains("Uuid"));
    }

    // ---- Error cases ----

    #[test]
    fn test_no_from_row_struct() {
        let source = r#"
        pub struct NotAnEntity {
            pub id: i32,
        }
        "#;
        let result = parse_entity_source(source);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_source() {
        let result = parse_entity_source("");
        assert!(result.is_err());
    }

    // ---- Defaults and combined attributes ----

    // Without `#[sqlx_gen(table = ...)]`, the struct name is used as the table name.
    #[test]
    fn test_fallback_table_name_to_struct_name() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        pub struct Users {
            pub id: i32,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.table_name, "Users");
    }

    #[test]
    fn test_pk_with_rename() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "items")]
        pub struct Items {
            #[sqlx_gen(primary_key)]
            #[sqlx(rename = "itemID")]
            pub item_id: i32,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        let f = &entity.fields[0];
        assert!(f.is_primary_key);
        assert_eq!(f.column_name, "itemID");
        assert_eq!(f.rust_name, "item_id");
    }

    #[test]
    fn test_full_entity() {
        let source = r#"
        #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "users")]
        pub struct Users {
            #[sqlx_gen(primary_key)]
            pub id: i32,
            pub name: String,
            pub email: Option<String>,
            #[sqlx(rename = "createdAt")]
            pub created_at: chrono::NaiveDateTime,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.struct_name, "Users");
        assert_eq!(entity.table_name, "users");
        assert!(!entity.is_view);
        assert_eq!(entity.fields.len(), 4);

        assert!(entity.fields[0].is_primary_key);
        assert_eq!(entity.fields[0].rust_name, "id");

        assert!(!entity.fields[1].is_primary_key);
        assert_eq!(entity.fields[1].rust_type, "String");

        assert!(entity.fields[2].is_nullable);
        assert_eq!(entity.fields[2].inner_type, "String");

        assert_eq!(entity.fields[3].column_name, "createdAt");
        assert_eq!(entity.fields[3].rust_name, "created_at");
    }

    // ---- Import extraction: serde and sqlx imports are dropped, others kept ----

    #[test]
    fn test_imports_extracted() {
        let source = r#"
        use chrono::{DateTime, Utc};
        use uuid::Uuid;
        use serde::{Serialize, Deserialize};

        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "users")]
        pub struct Users {
            pub id: Uuid,
            pub created_at: DateTime<Utc>,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 2);
        assert!(entity.imports.iter().any(|i| i.contains("chrono")));
        assert!(entity.imports.iter().any(|i| i.contains("uuid")));
        assert!(!entity.imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_imports_empty_when_none() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "users")]
        pub struct Users {
            pub id: i32,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.imports.is_empty());
    }

    // `serde_json` must survive even though `serde` itself is filtered out.
    #[test]
    fn test_imports_keep_serde_json() {
        let source = r#"
        use serde::{Serialize, Deserialize};
        use serde_json::Value;

        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "users")]
        pub struct Users {
            pub data: Value,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 1);
        assert!(entity.imports[0].contains("serde_json"));
    }

    #[test]
    fn test_imports_exclude_sqlx() {
        let source = r#"
        use sqlx::types::Uuid;
        use chrono::NaiveDateTime;

        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "users")]
        pub struct Users {
            pub id: i32,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 1);
        assert!(entity.imports[0].contains("chrono"));
    }

    // ---- Column defaults via `#[sqlx_gen(column_default = "...")]` ----

    #[test]
    fn test_parse_column_default() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "tasks")]
        pub struct Tasks {
            #[sqlx_gen(primary_key)]
            pub id: i32,
            #[sqlx_gen(column_default = "now()")]
            pub created_at: String,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        let created_at = &entity.fields[1];
        assert_eq!(created_at.column_default, Some("now()".to_string()));
    }

    #[test]
    fn test_parse_no_column_default() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "tasks")]
        pub struct Tasks {
            #[sqlx_gen(primary_key)]
            pub id: i32,
            pub title: String,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        let title = &entity.fields[1];
        assert_eq!(title.column_default, None);
    }

    // Postgres-style casts (`'value'::type`) must be preserved verbatim.
    #[test]
    fn test_parse_column_default_with_cast() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "tasks")]
        pub struct Tasks {
            #[sqlx_gen(primary_key)]
            pub id: i32,
            #[sqlx_gen(column_default = "'idle'::task_status")]
            pub status: String,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        let status = &entity.fields[1];
        assert_eq!(status.column_default, Some("'idle'::task_status".to_string()));
    }
}