1use std::path::Path;
2
3use quote::ToTokens;
4
/// One named field of a parsed entity struct, as extracted by this parser.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedField {
    /// The Rust identifier of the field.
    pub rust_name: String,
    /// Database column name: taken from `#[sqlx(rename = "...")]` when present,
    /// otherwise identical to `rust_name`.
    pub column_name: String,
    /// The field type rendered from its token stream (token spacing included,
    /// e.g. `Option < String >`).
    pub rust_type: String,
    /// `true` when the type's last path segment is `Option` with a type argument.
    pub is_nullable: bool,
    /// For `Option<T>` fields the rendered `T`; otherwise the same as `rust_type`.
    pub inner_type: String,
    /// Set by the bare `#[sqlx_gen(primary_key)]` flag.
    pub is_primary_key: bool,
    /// Value of `#[sqlx_gen(sql_type = "...")]`, if any.
    pub sql_type: Option<String>,
    /// Set by the bare `#[sqlx_gen(is_array)]` flag.
    pub is_sql_array: bool,
    /// Value of `#[sqlx_gen(column_default = "...")]`, if any.
    pub column_default: Option<String>,
}

/// A struct deriving `FromRow`, as parsed from an entity source file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedEntity {
    /// The struct's identifier.
    pub struct_name: String,
    /// Value of `#[sqlx_gen(table = "...")]`, falling back to `struct_name`.
    pub table_name: String,
    /// Value of `#[sqlx_gen(schema = "...")]`, if any.
    pub schema_name: Option<String>,
    /// `true` when the struct is annotated with `#[sqlx_gen(kind = "view")]`.
    pub is_view: bool,
    /// The struct's named fields, in declaration order.
    pub fields: Vec<ParsedField>,
    /// Top-level `use` statements of the source file (serde/sqlx imports
    /// excluded), normalized to compact spacing.
    pub imports: Vec<String>,
}
44
45pub fn parse_entity_file(path: &Path) -> crate::error::Result<ParsedEntity> {
47 let source = std::fs::read_to_string(path).map_err(crate::error::Error::Io)?;
48 parse_entity_source(&source).map_err(|e| {
49 crate::error::Error::Config(format!("{}: {}", path.display(), e))
50 })
51}
52
53pub fn parse_entity_source(source: &str) -> Result<ParsedEntity, String> {
55 let syntax = syn::parse_file(source).map_err(|e| format!("Failed to parse: {}", e))?;
56
57 let imports = extract_use_imports(&syntax);
59
60 for item in &syntax.items {
61 if let syn::Item::Struct(item_struct) = item {
62 if has_from_row_derive(item_struct) {
63 let mut entity = extract_entity(item_struct)?;
64 entity.imports = imports;
65 return Ok(entity);
66 }
67 }
68 }
69
70 Err("No struct with sqlx::FromRow derive found".to_string())
71}
72
73fn has_from_row_derive(item: &syn::ItemStruct) -> bool {
75 for attr in &item.attrs {
76 if attr.path().is_ident("derive") {
77 let tokens = attr.meta.to_token_stream().to_string();
78 if tokens.contains("FromRow") {
79 return true;
80 }
81 }
82 }
83 false
84}
85
86fn extract_use_imports(file: &syn::File) -> Vec<String> {
89 file.items
90 .iter()
91 .filter_map(|item| {
92 if let syn::Item::Use(use_item) = item {
93 let text = use_item.to_token_stream().to_string();
94 if (text.contains("serde") && !text.contains("serde_")) || text.contains("sqlx") {
96 return None;
97 }
98 let normalized = normalize_use_statement(&text);
100 Some(normalized)
101 } else {
102 None
103 }
104 })
105 .collect()
106}
107
/// Tightens the token spacing produced by `TokenStream::to_string` on a `use`
/// item (e.g. `use a :: { B , C } ;` becomes `use a::{B, C};`).
fn normalize_use_statement(s: &str) -> String {
    // Applied strictly in this order; each rewrite sees the previous result.
    const REWRITES: [(&str, &str); 7] = [
        (" :: ", "::"),
        (":: ", "::"),
        (" ::", "::"),
        ("{ ", "{"),
        (" }", "}"),
        (" ,", ","),
        (" ;", ";"),
    ];

    REWRITES
        .iter()
        .fold(s.to_owned(), |text, &(from, to)| text.replace(from, to))
}
118
119fn extract_entity(item: &syn::ItemStruct) -> Result<ParsedEntity, String> {
121 let struct_name = item.ident.to_string();
122
123 let (kind, schema_name, table_name) = parse_sqlx_gen_struct_attrs(&item.attrs);
124 let is_view = kind.as_deref() == Some("view");
125
126 let table_name = table_name.unwrap_or_else(|| struct_name.clone());
128
129 let fields = match &item.fields {
130 syn::Fields::Named(named) => {
131 named
132 .named
133 .iter()
134 .map(extract_field)
135 .collect::<Result<Vec<_>, _>>()?
136 }
137 _ => return Err("Expected named fields".to_string()),
138 };
139
140 Ok(ParsedEntity {
141 struct_name,
142 table_name,
143 schema_name,
144 is_view,
145 fields,
146 imports: Vec::new(), })
148}
149
150fn parse_sqlx_gen_struct_attrs(attrs: &[syn::Attribute]) -> (Option<String>, Option<String>, Option<String>) {
153 let mut kind = None;
154 let mut schema_name = None;
155 let mut table_name = None;
156
157 for attr in attrs {
158 if attr.path().is_ident("sqlx_gen") {
159 let tokens = attr.meta.to_token_stream().to_string();
160 if let Some(k) = extract_attr_value(&tokens, "kind") {
161 kind = Some(k);
162 }
163 if let Some(s) = extract_attr_value(&tokens, "schema") {
164 schema_name = Some(s);
165 }
166 if let Some(t) = extract_attr_value(&tokens, "table") {
167 table_name = Some(t);
168 }
169 }
170 }
171
172 (kind, schema_name, table_name)
173}
174
/// Extracts the string value of `key = "..."` from an attribute's rendered
/// token string (e.g. `sqlx_gen (kind = "table" , table = "users")`).
///
/// - `key` must appear as a whole identifier: looking up `table` ignores
///   `sql_table = "..."`.
/// - The closing quote is found by honouring escape sequences, so a value
///   ending in an escaped backslash (`"a\\"`) is terminated correctly instead
///   of running into the next key.
/// - Only `\"` is unescaped in the returned value; other escape sequences are
///   returned exactly as written.
///
/// Returns `None` when the key is absent or its value is unterminated.
fn extract_attr_value(tokens: &str, key: &str) -> Option<String> {
    let pattern = format!("{} = \"", key);
    let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';

    let (key_pos, _) = tokens.match_indices(pattern.as_str()).find(|&(pos, _)| {
        tokens[..pos]
            .chars()
            .next_back()
            .map_or(true, |prev| !is_ident_char(prev))
    })?;
    let rest = &tokens[key_pos + pattern.len()..];

    // Walk the literal body; a backslash always consumes the following char.
    let mut escaped = false;
    let end = rest.char_indices().find_map(|(i, c)| {
        if escaped {
            escaped = false;
            None
        } else if c == '\\' {
            escaped = true;
            None
        } else if c == '"' {
            Some(i)
        } else {
            None
        }
    })?;

    Some(rest[..end].replace("\\\"", "\""))
}
194
195fn extract_field(field: &syn::Field) -> Result<ParsedField, String> {
197 let rust_name = field
198 .ident
199 .as_ref()
200 .ok_or("Unnamed field")?
201 .to_string();
202
203 let column_name = get_sqlx_rename(&field.attrs).unwrap_or_else(|| rust_name.clone());
204 let (is_primary_key, sql_type, is_sql_array, column_default) = parse_sqlx_gen_field_attrs(&field.attrs);
205
206 let rust_type = field.ty.to_token_stream().to_string();
207 let (is_nullable, inner_type) = extract_option_type(&field.ty);
208 let inner_type = if is_nullable {
209 inner_type
210 } else {
211 rust_type.clone()
212 };
213
214 Ok(ParsedField {
215 rust_name,
216 column_name,
217 rust_type,
218 is_nullable,
219 inner_type,
220 is_primary_key,
221 sql_type,
222 is_sql_array,
223 column_default,
224 })
225}
226
227fn parse_sqlx_gen_field_attrs(attrs: &[syn::Attribute]) -> (bool, Option<String>, bool, Option<String>) {
230 let mut is_pk = false;
231 let mut sql_type = None;
232 let mut is_array = false;
233 let mut column_default = None;
234
235 for attr in attrs {
236 if attr.path().is_ident("sqlx_gen") {
237 let tokens = attr.meta.to_token_stream().to_string();
238 if tokens.contains("primary_key") {
239 is_pk = true;
240 }
241 if let Some(t) = extract_attr_value(&tokens, "sql_type") {
242 sql_type = Some(t);
243 }
244 if tokens.contains("is_array") {
245 is_array = true;
246 }
247 if let Some(d) = extract_attr_value(&tokens, "column_default") {
248 column_default = Some(d);
249 }
250 }
251 }
252
253 (is_pk, sql_type, is_array, column_default)
254}
255
256fn get_sqlx_rename(attrs: &[syn::Attribute]) -> Option<String> {
258 for attr in attrs {
259 if attr.path().is_ident("sqlx") {
260 let tokens = attr.meta.to_token_stream().to_string();
261 return extract_attr_value(&tokens, "rename");
262 }
263 }
264 None
265}
266
267fn extract_option_type(ty: &syn::Type) -> (bool, String) {
269 if let syn::Type::Path(type_path) = ty {
270 if let Some(segment) = type_path.path.segments.last() {
271 if segment.ident == "Option" {
272 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
273 if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
274 return (true, inner.to_token_stream().to_string());
275 }
276 }
277 }
278 }
279 }
280 (false, String::new())
281}
282
#[cfg(test)]
mod tests {
    // Unit tests for the entity parser. Each test feeds a small Rust source
    // snippet to `parse_entity_source` and inspects the resulting entity.
    use super::*;

    // ---- Struct-level attributes: table name and view/table kind ----

    #[test]
    fn test_parse_simple_table() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
                pub name: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.struct_name, "Users");
        assert_eq!(entity.table_name, "users");
        assert!(!entity.is_view);
        assert_eq!(entity.fields.len(), 2);
    }

    #[test]
    fn test_parse_view() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "view", table = "active_users")]
            pub struct ActiveUsers {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.is_view);
        assert_eq!(entity.table_name, "active_users");
    }

    #[test]
    fn test_parse_table_not_view() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.is_view);
    }

    // ---- Primary-key flags ----

    #[test]
    fn test_parse_primary_key() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                pub name: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_primary_key);
        assert!(!entity.fields[1].is_primary_key);
    }

    #[test]
    fn test_composite_primary_key() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "user_roles")]
            pub struct UserRoles {
                #[sqlx_gen(primary_key)]
                pub user_id: i32,
                #[sqlx_gen(primary_key)]
                pub role_id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_primary_key);
        assert!(entity.fields[1].is_primary_key);
    }

    #[test]
    fn test_no_primary_key() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "logs")]
            pub struct Logs {
                pub message: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.fields[0].is_primary_key);
    }

    // ---- Column naming via #[sqlx(rename = "...")] ----

    #[test]
    fn test_sqlx_rename() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "connector")]
            pub struct Connector {
                #[sqlx(rename = "type")]
                pub connector_type: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.fields[0].rust_name, "connector_type");
        assert_eq!(entity.fields[0].column_name, "type");
    }

    #[test]
    fn test_no_rename_uses_field_name() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub name: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.fields[0].rust_name, "name");
        assert_eq!(entity.fields[0].column_name, "name");
    }

    // ---- Nullability and type rendering ----

    #[test]
    fn test_option_field_nullable() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub email: Option<String>,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_nullable);
        assert_eq!(entity.fields[0].inner_type, "String");
    }

    #[test]
    fn test_non_option_not_nullable() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.fields[0].is_nullable);
        assert_eq!(entity.fields[0].inner_type, "i32");
    }

    #[test]
    fn test_option_complex_type() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub created_at: Option<chrono::NaiveDateTime>,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_nullable);
        // Rendered token strings contain spaces around `::`, so only check the tail.
        assert!(entity.fields[0].inner_type.contains("NaiveDateTime"));
    }

    #[test]
    fn test_rust_type_preserved() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: uuid::Uuid,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].rust_type.contains("Uuid"));
    }

    // ---- Error cases and fallbacks ----

    #[test]
    fn test_no_from_row_struct() {
        let source = r#"
            pub struct NotAnEntity {
                pub id: i32,
            }
        "#;
        let result = parse_entity_source(source);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_source() {
        let result = parse_entity_source("");
        assert!(result.is_err());
    }

    #[test]
    fn test_fallback_table_name_to_struct_name() {
        // Without a `table = "..."` attribute the struct name is used verbatim.
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.table_name, "Users");
    }

    // ---- Combined attributes ----

    #[test]
    fn test_pk_with_rename() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "items")]
            pub struct Items {
                #[sqlx_gen(primary_key)]
                #[sqlx(rename = "itemID")]
                pub item_id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        let f = &entity.fields[0];
        assert!(f.is_primary_key);
        assert_eq!(f.column_name, "itemID");
        assert_eq!(f.rust_name, "item_id");
    }

    #[test]
    fn test_full_entity() {
        let source = r#"
            #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                pub name: String,
                pub email: Option<String>,
                #[sqlx(rename = "createdAt")]
                pub created_at: chrono::NaiveDateTime,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.struct_name, "Users");
        assert_eq!(entity.table_name, "users");
        assert!(!entity.is_view);
        assert_eq!(entity.fields.len(), 4);

        assert!(entity.fields[0].is_primary_key);
        assert_eq!(entity.fields[0].rust_name, "id");

        assert!(!entity.fields[1].is_primary_key);
        assert_eq!(entity.fields[1].rust_type, "String");

        assert!(entity.fields[2].is_nullable);
        assert_eq!(entity.fields[2].inner_type, "String");

        assert_eq!(entity.fields[3].column_name, "createdAt");
        assert_eq!(entity.fields[3].rust_name, "created_at");
    }

    // ---- `use` import extraction (serde/sqlx imports are filtered out) ----

    #[test]
    fn test_imports_extracted() {
        let source = r#"
            use chrono::{DateTime, Utc};
            use uuid::Uuid;
            use serde::{Serialize, Deserialize};

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: Uuid,
                pub created_at: DateTime<Utc>,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 2);
        assert!(entity.imports.iter().any(|i| i.contains("chrono")));
        assert!(entity.imports.iter().any(|i| i.contains("uuid")));
        assert!(!entity.imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_imports_empty_when_none() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.imports.is_empty());
    }

    #[test]
    fn test_imports_keep_serde_json() {
        // `serde_json` is a different crate from `serde` and must be kept.
        let source = r#"
            use serde::{Serialize, Deserialize};
            use serde_json::Value;

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub data: Value,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 1);
        assert!(entity.imports[0].contains("serde_json"));
    }

    #[test]
    fn test_imports_exclude_sqlx() {
        let source = r#"
            use sqlx::types::Uuid;
            use chrono::NaiveDateTime;

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 1);
        assert!(entity.imports[0].contains("chrono"));
    }

    // ---- Column defaults, including casts and escaped quotes ----

    #[test]
    fn test_parse_column_default() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "tasks")]
            pub struct Tasks {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                #[sqlx_gen(column_default = "now()")]
                pub created_at: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        let created_at = &entity.fields[1];
        assert_eq!(created_at.column_default, Some("now()".to_string()));
    }

    #[test]
    fn test_parse_no_column_default() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "tasks")]
            pub struct Tasks {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                pub title: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        let title = &entity.fields[1];
        assert_eq!(title.column_default, None);
    }

    #[test]
    fn test_parse_column_default_with_cast() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "tasks")]
            pub struct Tasks {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                #[sqlx_gen(column_default = "'idle'::task_status")]
                pub status: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        let status = &entity.fields[1];
        assert_eq!(status.column_default, Some("'idle'::task_status".to_string()));
    }

    #[test]
    fn test_parse_column_default_with_json_quotes() {
        // The attribute value contains escaped `\"`, which must come back unescaped.
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "configs")]
            pub struct Configs {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                #[sqlx_gen(column_default = "'{\"1\": \"\", \"2\": \"\"}'::jsonb")]
                pub template_variables: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        let field = &entity.fields[1];
        assert_eq!(
            field.column_default,
            Some(r#"'{"1": "", "2": ""}'::jsonb"#.to_string())
        );
    }
}