1use std::path::Path;
2
3use quote::ToTokens;
4
/// One named field of an entity struct, as recovered from its source code.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedField {
    /// Rust identifier of the field.
    pub rust_name: String,
    /// Database column name: the `#[sqlx(rename = "...")]` value when present,
    /// otherwise the Rust identifier.
    pub column_name: String,
    /// The field's full type rendered with `ToTokens::to_string`; tokens may be
    /// space-separated by the printer, so compare loosely.
    pub rust_type: String,
    /// `true` when the type's last path segment is `Option<...>`.
    pub is_nullable: bool,
    /// The type inside `Option<...>` for nullable fields, otherwise identical to
    /// `rust_type`.
    pub inner_type: String,
    /// Set by the `primary_key` flag in a field-level `#[sqlx_gen(...)]` attribute.
    pub is_primary_key: bool,
    /// Value of `sql_type = "..."` in a field-level `#[sqlx_gen(...)]` attribute.
    pub sql_type: Option<String>,
    /// Set by the `is_array` flag in a field-level `#[sqlx_gen(...)]` attribute.
    pub is_sql_array: bool,
    /// Value of `column_default = "..."` in a field-level `#[sqlx_gen(...)]` attribute.
    pub column_default: Option<String>,
}

/// An entity struct (the first `FromRow` struct of a file) recovered from source.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedEntity {
    /// Rust identifier of the struct.
    pub struct_name: String,
    /// Value of `table = "..."` in the struct-level `#[sqlx_gen(...)]` attribute,
    /// falling back to the struct name when absent.
    pub table_name: String,
    /// Value of `schema = "..."` in the struct-level `#[sqlx_gen(...)]` attribute.
    pub schema_name: Option<String>,
    /// `true` when the struct-level attribute has `kind = "view"`.
    pub is_view: bool,
    /// Named fields in declaration order.
    pub fields: Vec<ParsedField>,
    /// Normalized `use` statements carried over from the source file, with
    /// serde/sqlx imports filtered out.
    pub imports: Vec<String>,
}
44
45pub fn parse_entity_file(path: &Path) -> crate::error::Result<ParsedEntity> {
47 let source = std::fs::read_to_string(path).map_err(crate::error::Error::Io)?;
48 parse_entity_source(&source)
49 .map_err(|e| crate::error::Error::Config(format!("{}: {}", path.display(), e)))
50}
51
52pub fn parse_entity_source(source: &str) -> Result<ParsedEntity, String> {
54 let syntax = syn::parse_file(source).map_err(|e| format!("Failed to parse: {}", e))?;
55
56 let imports = extract_use_imports(&syntax);
58
59 for item in &syntax.items {
60 if let syn::Item::Struct(item_struct) = item {
61 if has_from_row_derive(item_struct) {
62 let mut entity = extract_entity(item_struct)?;
63 entity.imports = imports;
64 return Ok(entity);
65 }
66 }
67 }
68
69 Err("No struct with sqlx::FromRow derive found".to_string())
70}
71
72fn has_from_row_derive(item: &syn::ItemStruct) -> bool {
74 for attr in &item.attrs {
75 if attr.path().is_ident("derive") {
76 let tokens = attr.meta.to_token_stream().to_string();
77 if tokens.contains("FromRow") {
78 return true;
79 }
80 }
81 }
82 false
83}
84
85fn extract_use_imports(file: &syn::File) -> Vec<String> {
88 file.items
89 .iter()
90 .filter_map(|item| {
91 if let syn::Item::Use(use_item) = item {
92 let text = use_item.to_token_stream().to_string();
93 if (text.contains("serde") && !text.contains("serde_")) || text.contains("sqlx") {
95 return None;
96 }
97 let normalized = normalize_use_statement(&text);
99 Some(normalized)
100 } else {
101 None
102 }
103 })
104 .collect()
105}
106
/// Removes the extra spaces that token-stream printing puts into a `use`
/// statement, e.g. `use a :: { B , C } ;` becomes `use a::{B, C};`.
///
/// The rewrites are applied in a fixed order; each one replaces all occurrences.
fn normalize_use_statement(s: &str) -> String {
    const REWRITES: [(&str, &str); 7] = [
        (" :: ", "::"),
        (":: ", "::"),
        (" ::", "::"),
        ("{ ", "{"),
        (" }", "}"),
        (" ,", ","),
        (" ;", ";"),
    ];

    REWRITES
        .iter()
        .fold(s.to_string(), |text, &(from, to)| text.replace(from, to))
}
117
118fn extract_entity(item: &syn::ItemStruct) -> Result<ParsedEntity, String> {
120 let struct_name = item.ident.to_string();
121
122 let (kind, schema_name, table_name) = parse_sqlx_gen_struct_attrs(&item.attrs);
123 let is_view = kind.as_deref() == Some("view");
124
125 let table_name = table_name.unwrap_or_else(|| struct_name.clone());
127
128 let fields = match &item.fields {
129 syn::Fields::Named(named) => named
130 .named
131 .iter()
132 .map(extract_field)
133 .collect::<Result<Vec<_>, _>>()?,
134 _ => return Err("Expected named fields".to_string()),
135 };
136
137 Ok(ParsedEntity {
138 struct_name,
139 table_name,
140 schema_name,
141 is_view,
142 fields,
143 imports: Vec::new(), })
145}
146
147fn parse_sqlx_gen_struct_attrs(
150 attrs: &[syn::Attribute],
151) -> (Option<String>, Option<String>, Option<String>) {
152 let mut kind = None;
153 let mut schema_name = None;
154 let mut table_name = None;
155
156 for attr in attrs {
157 if attr.path().is_ident("sqlx_gen") {
158 let tokens = attr.meta.to_token_stream().to_string();
159 if let Some(k) = extract_attr_value(&tokens, "kind") {
160 kind = Some(k);
161 }
162 if let Some(s) = extract_attr_value(&tokens, "schema") {
163 schema_name = Some(s);
164 }
165 if let Some(t) = extract_attr_value(&tokens, "table") {
166 table_name = Some(t);
167 }
168 }
169 }
170
171 (kind, schema_name, table_name)
172}
173
/// Extracts the string value assigned to `key` in a stringified attribute, e.g.
/// `"view"` for key `kind` in `sqlx_gen (kind = "view")`.
///
/// `tokens` is expected to be the `ToTokens` rendering of an attribute, where each
/// assignment appears as `key = "literal"`.
///
/// `key` only matches a whole identifier outside any string literal. So `table`
/// does not match `foreign_table`, nor text inside another key's value.
///
/// The literal is decoded with Rust escape rules (`\"`, `\\`, `\n`, `\x..`,
/// `\u{..}`, line continuations). Raw string and char literals are not recognized.
///
/// Returns `None` when the key is absent or its literal is unterminated.
fn extract_attr_value(tokens: &str, key: &str) -> Option<String> {
    let pattern = format!("{} = \"", key);
    let bytes = tokens.as_bytes();
    let mut pos = 0;
    while pos < bytes.len() {
        if bytes[pos] == b'"' {
            // Step over an unrelated literal so its contents are never read as a key.
            let (_, after) = decode_string_literal(tokens, pos + 1)?;
            pos = after;
            continue;
        }
        let at_boundary = pos == 0 || !is_ident_byte(bytes[pos - 1]);
        if at_boundary && bytes[pos..].starts_with(pattern.as_bytes()) {
            return decode_string_literal(tokens, pos + pattern.len()).map(|(value, _)| value);
        }
        pos += 1;
    }
    None
}

/// Decodes the body of a (non-raw) Rust string literal that starts at byte
/// `start`, just after the opening quote.
///
/// Returns the decoded text and the byte index just past the closing quote, or
/// `None` if the literal is unterminated or has a malformed escape.
fn decode_string_literal(s: &str, start: usize) -> Option<(String, usize)> {
    let body = s.get(start..)?;
    let mut out = String::new();
    let mut chars = body.char_indices().peekable();
    while let Some((idx, c)) = chars.next() {
        match c {
            '"' => return Some((out, start + idx + 1)),
            '\\' => {
                let (_, escape) = chars.next()?;
                match escape {
                    'n' => out.push('\n'),
                    'r' => out.push('\r'),
                    't' => out.push('\t'),
                    '0' => out.push('\0'),
                    '\\' | '"' | '\'' => out.push(escape),
                    'x' => {
                        let hi = chars.next()?.1.to_digit(16)?;
                        let lo = chars.next()?.1.to_digit(16)?;
                        out.push(char::from_u32(hi * 16 + lo)?);
                    }
                    'u' => {
                        if chars.next()?.1 != '{' {
                            return None;
                        }
                        let mut code: u32 = 0;
                        loop {
                            let (_, digit) = chars.next()?;
                            match digit {
                                '}' => break,
                                '_' => {}
                                _ => {
                                    code = code.checked_mul(16)?.checked_add(digit.to_digit(16)?)?
                                }
                            }
                        }
                        out.push(char::from_u32(code)?);
                    }
                    '\n' | '\r' => {
                        // Line continuation: drop the newline and the following indentation.
                        while let Some(&(_, next)) = chars.peek() {
                            if next.is_whitespace() {
                                chars.next();
                            } else {
                                break;
                            }
                        }
                    }
                    other => {
                        // Not a valid Rust escape; keep it verbatim rather than lose text.
                        out.push('\\');
                        out.push(other);
                    }
                }
            }
            _ => out.push(c),
        }
    }
    None
}

/// Returns `true` for bytes that can continue an identifier. Non-ASCII bytes
/// count, so Unicode identifiers are never split.
fn is_ident_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_' || !b.is_ascii()
}
193
194fn extract_field(field: &syn::Field) -> Result<ParsedField, String> {
196 let rust_name = field.ident.as_ref().ok_or("Unnamed field")?.to_string();
197
198 let column_name = get_sqlx_rename(&field.attrs).unwrap_or_else(|| rust_name.clone());
199 let (is_primary_key, sql_type, is_sql_array, column_default) =
200 parse_sqlx_gen_field_attrs(&field.attrs);
201
202 let rust_type = field.ty.to_token_stream().to_string();
203 let (is_nullable, inner_type) = extract_option_type(&field.ty);
204 let inner_type = if is_nullable {
205 inner_type
206 } else {
207 rust_type.clone()
208 };
209
210 Ok(ParsedField {
211 rust_name,
212 column_name,
213 rust_type,
214 is_nullable,
215 inner_type,
216 is_primary_key,
217 sql_type,
218 is_sql_array,
219 column_default,
220 })
221}
222
223fn parse_sqlx_gen_field_attrs(
226 attrs: &[syn::Attribute],
227) -> (bool, Option<String>, bool, Option<String>) {
228 let mut is_pk = false;
229 let mut sql_type = None;
230 let mut is_array = false;
231 let mut column_default = None;
232
233 for attr in attrs {
234 if attr.path().is_ident("sqlx_gen") {
235 let tokens = attr.meta.to_token_stream().to_string();
236 if tokens.contains("primary_key") {
237 is_pk = true;
238 }
239 if let Some(t) = extract_attr_value(&tokens, "sql_type") {
240 sql_type = Some(t);
241 }
242 if tokens.contains("is_array") {
243 is_array = true;
244 }
245 if let Some(d) = extract_attr_value(&tokens, "column_default") {
246 column_default = Some(d);
247 }
248 }
249 }
250
251 (is_pk, sql_type, is_array, column_default)
252}
253
254fn get_sqlx_rename(attrs: &[syn::Attribute]) -> Option<String> {
256 for attr in attrs {
257 if attr.path().is_ident("sqlx") {
258 let tokens = attr.meta.to_token_stream().to_string();
259 return extract_attr_value(&tokens, "rename");
260 }
261 }
262 None
263}
264
265fn extract_option_type(ty: &syn::Type) -> (bool, String) {
267 if let syn::Type::Path(type_path) = ty {
268 if let Some(segment) = type_path.path.segments.last() {
269 if segment.ident == "Option" {
270 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
271 if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
272 return (true, inner.to_token_stream().to_string());
273 }
274 }
275 }
276 }
277 }
278 (false, String::new())
279}
280
#[cfg(test)]
mod tests {
    // Unit tests for `parse_entity_source`. Each test feeds a small struct
    // snippet and checks one aspect of the returned `ParsedEntity`.
    use super::*;

    // --- struct-level attributes: kind / table ---

    // A `kind = "table"` struct: name, table name, view flag and field count.
    #[test]
    fn test_parse_simple_table() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
                pub name: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.struct_name, "Users");
        assert_eq!(entity.table_name, "users");
        assert!(!entity.is_view);
        assert_eq!(entity.fields.len(), 2);
    }

    // `kind = "view"` sets `is_view`; `table` still supplies the name.
    #[test]
    fn test_parse_view() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "view", table = "active_users")]
            pub struct ActiveUsers {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.is_view);
        assert_eq!(entity.table_name, "active_users");
    }

    // `kind = "table"` must not be reported as a view.
    #[test]
    fn test_parse_table_not_view() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.is_view);
    }

    // --- primary keys ---

    // Only the field carrying `#[sqlx_gen(primary_key)]` is flagged.
    #[test]
    fn test_parse_primary_key() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                pub name: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_primary_key);
        assert!(!entity.fields[1].is_primary_key);
    }

    // Several fields may each be flagged as part of a composite key.
    #[test]
    fn test_composite_primary_key() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "user_roles")]
            pub struct UserRoles {
                #[sqlx_gen(primary_key)]
                pub user_id: i32,
                #[sqlx_gen(primary_key)]
                pub role_id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_primary_key);
        assert!(entity.fields[1].is_primary_key);
    }

    // Without the attribute, no field is a primary key.
    #[test]
    fn test_no_primary_key() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "logs")]
            pub struct Logs {
                pub message: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.fields[0].is_primary_key);
    }

    // --- column names ---

    // `#[sqlx(rename = "...")]` sets the column name; the Rust name is unchanged.
    #[test]
    fn test_sqlx_rename() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "connector")]
            pub struct Connector {
                #[sqlx(rename = "type")]
                pub connector_type: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.fields[0].rust_name, "connector_type");
        assert_eq!(entity.fields[0].column_name, "type");
    }

    // Without a rename, the column name equals the field name.
    #[test]
    fn test_no_rename_uses_field_name() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub name: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.fields[0].rust_name, "name");
        assert_eq!(entity.fields[0].column_name, "name");
    }

    // --- types and nullability ---

    // `Option<T>` marks the field nullable and exposes `T` as `inner_type`.
    #[test]
    fn test_option_field_nullable() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub email: Option<String>,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_nullable);
        assert_eq!(entity.fields[0].inner_type, "String");
    }

    // A non-Option field is not nullable and its inner type is the full type.
    #[test]
    fn test_non_option_not_nullable() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.fields[0].is_nullable);
        assert_eq!(entity.fields[0].inner_type, "i32");
    }

    // A path type inside `Option` is recovered. Only `contains` is checked
    // because the rendered tokens may be space-separated.
    #[test]
    fn test_option_complex_type() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub created_at: Option<chrono::NaiveDateTime>,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_nullable);
        assert!(entity.fields[0].inner_type.contains("NaiveDateTime"));
    }

    // The full written type is kept in `rust_type`.
    #[test]
    fn test_rust_type_preserved() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: uuid::Uuid,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].rust_type.contains("Uuid"));
    }

    // --- error cases and fallbacks ---

    // A struct without a `FromRow` derive is rejected.
    #[test]
    fn test_no_from_row_struct() {
        let source = r#"
            pub struct NotAnEntity {
                pub id: i32,
            }
        "#;
        let result = parse_entity_source(source);
        assert!(result.is_err());
    }

    // Empty input yields an error rather than an empty entity.
    #[test]
    fn test_empty_source() {
        let result = parse_entity_source("");
        assert!(result.is_err());
    }

    // Without a `#[sqlx_gen(table = ...)]`, the struct name becomes the table name.
    #[test]
    fn test_fallback_table_name_to_struct_name() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.table_name, "Users");
    }

    // --- combined attributes ---

    // `primary_key` and `rename` on the same field are both honored.
    #[test]
    fn test_pk_with_rename() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "items")]
            pub struct Items {
                #[sqlx_gen(primary_key)]
                #[sqlx(rename = "itemID")]
                pub item_id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        let f = &entity.fields[0];
        assert!(f.is_primary_key);
        assert_eq!(f.column_name, "itemID");
        assert_eq!(f.rust_name, "item_id");
    }

    // A realistic entity mixing extra derives, a key, an Option and a rename.
    #[test]
    fn test_full_entity() {
        let source = r#"
            #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                pub name: String,
                pub email: Option<String>,
                #[sqlx(rename = "createdAt")]
                pub created_at: chrono::NaiveDateTime,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.struct_name, "Users");
        assert_eq!(entity.table_name, "users");
        assert!(!entity.is_view);
        assert_eq!(entity.fields.len(), 4);

        assert!(entity.fields[0].is_primary_key);
        assert_eq!(entity.fields[0].rust_name, "id");

        assert!(!entity.fields[1].is_primary_key);
        assert_eq!(entity.fields[1].rust_type, "String");

        assert!(entity.fields[2].is_nullable);
        assert_eq!(entity.fields[2].inner_type, "String");

        assert_eq!(entity.fields[3].column_name, "createdAt");
        assert_eq!(entity.fields[3].rust_name, "created_at");
    }

    // --- import extraction ---

    // Non-serde imports are kept; the serde import is dropped.
    #[test]
    fn test_imports_extracted() {
        let source = r#"
            use chrono::{DateTime, Utc};
            use uuid::Uuid;
            use serde::{Serialize, Deserialize};

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: Uuid,
                pub created_at: DateTime<Utc>,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 2);
        assert!(entity.imports.iter().any(|i| i.contains("chrono")));
        assert!(entity.imports.iter().any(|i| i.contains("uuid")));
        assert!(!entity.imports.iter().any(|i| i.contains("serde")));
    }

    // A file without `use` items yields no imports.
    #[test]
    fn test_imports_empty_when_none() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.imports.is_empty());
    }

    // `serde_json` is a different crate from `serde` and must be kept.
    #[test]
    fn test_imports_keep_serde_json() {
        let source = r#"
            use serde::{Serialize, Deserialize};
            use serde_json::Value;

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub data: Value,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 1);
        assert!(entity.imports[0].contains("serde_json"));
    }

    // Imports from `sqlx` are dropped.
    #[test]
    fn test_imports_exclude_sqlx() {
        let source = r#"
            use sqlx::types::Uuid;
            use chrono::NaiveDateTime;

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 1);
        assert!(entity.imports[0].contains("chrono"));
    }

    // --- column defaults ---

    // `column_default = "..."` is captured verbatim.
    #[test]
    fn test_parse_column_default() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "tasks")]
            pub struct Tasks {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                #[sqlx_gen(column_default = "now()")]
                pub created_at: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        let created_at = &entity.fields[1];
        assert_eq!(created_at.column_default, Some("now()".to_string()));
    }

    // A field without `column_default` reports `None`.
    #[test]
    fn test_parse_no_column_default() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "tasks")]
            pub struct Tasks {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                pub title: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        let title = &entity.fields[1];
        assert_eq!(title.column_default, None);
    }

    // A Postgres cast expression (`::type`) survives extraction unchanged.
    #[test]
    fn test_parse_column_default_with_cast() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "tasks")]
            pub struct Tasks {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                #[sqlx_gen(column_default = "'idle'::task_status")]
                pub status: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        let status = &entity.fields[1];
        assert_eq!(
            status.column_default,
            Some("'idle'::task_status".to_string())
        );
    }

    // Escaped double quotes inside the literal are unescaped in the result.
    #[test]
    fn test_parse_column_default_with_json_quotes() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "configs")]
            pub struct Configs {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                #[sqlx_gen(column_default = "'{\"1\": \"\", \"2\": \"\"}'::jsonb")]
                pub template_variables: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        let field = &entity.fields[1];
        assert_eq!(
            field.column_default,
            Some(r#"'{"1": "", "2": ""}'::jsonb"#.to_string())
        );
    }
}