1use std::path::Path;
2
3use quote::ToTokens;
4
/// One named field of an entity struct, as recovered from its Rust source.
#[derive(Debug, Clone)]
pub struct ParsedField {
    /// Field identifier exactly as written in the struct.
    pub rust_name: String,
    /// Database column name: the `#[sqlx(rename = "...")]` value when one is
    /// found on the field, otherwise identical to `rust_name`.
    pub column_name: String,
    /// The field's full type, stringified from its token stream (multi-token
    /// types may therefore contain spaces between tokens).
    pub rust_type: String,
    /// `true` when the type's last path segment is `Option<...>`.
    pub is_nullable: bool,
    /// The type inside `Option<...>` when `is_nullable`, otherwise the same
    /// string as `rust_type`.
    pub inner_type: String,
    /// Set when a `#[sqlx_gen(...)]` attribute on the field marks it `primary_key`.
    pub is_primary_key: bool,
    /// Raw value of `#[sqlx_gen(sql_type = "...")]`, if present.
    pub sql_type: Option<String>,
    /// Set when a `#[sqlx_gen(...)]` attribute on the field mentions `is_array`.
    pub is_sql_array: bool,
}
25
/// An entity struct (a `sqlx::FromRow` struct) recovered from a source file.
#[derive(Debug, Clone)]
pub struct ParsedEntity {
    /// Identifier of the struct.
    pub struct_name: String,
    /// Value of `#[sqlx_gen(table = "...")]`, falling back to `struct_name`.
    pub table_name: String,
    /// Value of `#[sqlx_gen(schema = "...")]`, if present.
    pub schema_name: Option<String>,
    /// `true` when the struct is annotated `#[sqlx_gen(kind = "view")]`.
    pub is_view: bool,
    /// Named fields in declaration order.
    pub fields: Vec<ParsedField>,
    /// Normalized top-level `use` statements from the same file, with
    /// `serde`/`sqlx` imports filtered out.
    pub imports: Vec<String>,
}
42
43pub fn parse_entity_file(path: &Path) -> crate::error::Result<ParsedEntity> {
45 let source = std::fs::read_to_string(path).map_err(crate::error::Error::Io)?;
46 parse_entity_source(&source).map_err(|e| {
47 crate::error::Error::Config(format!("{}: {}", path.display(), e))
48 })
49}
50
51pub fn parse_entity_source(source: &str) -> Result<ParsedEntity, String> {
53 let syntax = syn::parse_file(source).map_err(|e| format!("Failed to parse: {}", e))?;
54
55 let imports = extract_use_imports(&syntax);
57
58 for item in &syntax.items {
59 if let syn::Item::Struct(item_struct) = item {
60 if has_from_row_derive(item_struct) {
61 let mut entity = extract_entity(item_struct)?;
62 entity.imports = imports;
63 return Ok(entity);
64 }
65 }
66 }
67
68 Err("No struct with sqlx::FromRow derive found".to_string())
69}
70
71fn has_from_row_derive(item: &syn::ItemStruct) -> bool {
73 for attr in &item.attrs {
74 if attr.path().is_ident("derive") {
75 let tokens = attr.meta.to_token_stream().to_string();
76 if tokens.contains("FromRow") {
77 return true;
78 }
79 }
80 }
81 false
82}
83
84fn extract_use_imports(file: &syn::File) -> Vec<String> {
87 file.items
88 .iter()
89 .filter_map(|item| {
90 if let syn::Item::Use(use_item) = item {
91 let text = use_item.to_token_stream().to_string();
92 if (text.contains("serde") && !text.contains("serde_")) || text.contains("sqlx") {
94 return None;
95 }
96 let normalized = normalize_use_statement(&text);
98 Some(normalized)
99 } else {
100 None
101 }
102 })
103 .collect()
104}
105
/// Tightens the spacing of a stringified token stream for a `use` item, e.g.
/// `use a :: { b , c } ;` becomes `use a::{b, c};`.
///
/// The rewrites are applied in order, each over the whole string.
fn normalize_use_statement(s: &str) -> String {
    const REWRITES: [(&str, &str); 7] = [
        (" :: ", "::"),
        (":: ", "::"),
        (" ::", "::"),
        ("{ ", "{"),
        (" }", "}"),
        (" ,", ","),
        (" ;", ";"),
    ];
    REWRITES
        .iter()
        .fold(s.to_string(), |acc, &(from, to)| acc.replace(from, to))
}
116
117fn extract_entity(item: &syn::ItemStruct) -> Result<ParsedEntity, String> {
119 let struct_name = item.ident.to_string();
120
121 let (kind, schema_name, table_name) = parse_sqlx_gen_struct_attrs(&item.attrs);
122 let is_view = kind.as_deref() == Some("view");
123
124 let table_name = table_name.unwrap_or_else(|| struct_name.clone());
126
127 let fields = match &item.fields {
128 syn::Fields::Named(named) => {
129 named
130 .named
131 .iter()
132 .map(extract_field)
133 .collect::<Result<Vec<_>, _>>()?
134 }
135 _ => return Err("Expected named fields".to_string()),
136 };
137
138 Ok(ParsedEntity {
139 struct_name,
140 table_name,
141 schema_name,
142 is_view,
143 fields,
144 imports: Vec::new(), })
146}
147
148fn parse_sqlx_gen_struct_attrs(attrs: &[syn::Attribute]) -> (Option<String>, Option<String>, Option<String>) {
151 let mut kind = None;
152 let mut schema_name = None;
153 let mut table_name = None;
154
155 for attr in attrs {
156 if attr.path().is_ident("sqlx_gen") {
157 let tokens = attr.meta.to_token_stream().to_string();
158 if let Some(k) = extract_attr_value(&tokens, "kind") {
159 kind = Some(k);
160 }
161 if let Some(s) = extract_attr_value(&tokens, "schema") {
162 schema_name = Some(s);
163 }
164 if let Some(t) = extract_attr_value(&tokens, "table") {
165 table_name = Some(t);
166 }
167 }
168 }
169
170 (kind, schema_name, table_name)
171}
172
/// Extracts the quoted value of `key = "..."` from a stringified attribute.
///
/// The search pattern assumes the `key = "value"` spacing produced by token
/// stream stringification. The key must start at an identifier boundary, so
/// `table` does not match inside `sub_table = "..."`. The value runs up to
/// the next `"` and is returned raw, with no escape decoding.
///
/// Returns `None` when the key is absent or its value is unterminated.
fn extract_attr_value(tokens: &str, key: &str) -> Option<String> {
    let pattern = format!("{} = \"", key);
    let (key_start, _) = tokens.match_indices(pattern.as_str()).find(|&(idx, _)| {
        // Reject hits that are the tail of a longer identifier.
        !matches!(
            tokens[..idx].chars().next_back(),
            Some(c) if c.is_alphanumeric() || c == '_'
        )
    })?;
    let rest = &tokens[key_start + pattern.len()..];
    let end = rest.find('"')?;
    Some(rest[..end].to_string())
}
182
183fn extract_field(field: &syn::Field) -> Result<ParsedField, String> {
185 let rust_name = field
186 .ident
187 .as_ref()
188 .ok_or("Unnamed field")?
189 .to_string();
190
191 let column_name = get_sqlx_rename(&field.attrs).unwrap_or_else(|| rust_name.clone());
192 let (is_primary_key, sql_type, is_sql_array) = parse_sqlx_gen_field_attrs(&field.attrs);
193
194 let rust_type = field.ty.to_token_stream().to_string();
195 let (is_nullable, inner_type) = extract_option_type(&field.ty);
196 let inner_type = if is_nullable {
197 inner_type
198 } else {
199 rust_type.clone()
200 };
201
202 Ok(ParsedField {
203 rust_name,
204 column_name,
205 rust_type,
206 is_nullable,
207 inner_type,
208 is_primary_key,
209 sql_type,
210 is_sql_array,
211 })
212}
213
214fn parse_sqlx_gen_field_attrs(attrs: &[syn::Attribute]) -> (bool, Option<String>, bool) {
217 let mut is_pk = false;
218 let mut sql_type = None;
219 let mut is_array = false;
220
221 for attr in attrs {
222 if attr.path().is_ident("sqlx_gen") {
223 let tokens = attr.meta.to_token_stream().to_string();
224 if tokens.contains("primary_key") {
225 is_pk = true;
226 }
227 if let Some(t) = extract_attr_value(&tokens, "sql_type") {
228 sql_type = Some(t);
229 }
230 if tokens.contains("is_array") {
231 is_array = true;
232 }
233 }
234 }
235
236 (is_pk, sql_type, is_array)
237}
238
239fn get_sqlx_rename(attrs: &[syn::Attribute]) -> Option<String> {
241 for attr in attrs {
242 if attr.path().is_ident("sqlx") {
243 let tokens = attr.meta.to_token_stream().to_string();
244 return extract_attr_value(&tokens, "rename");
245 }
246 }
247 None
248}
249
250fn extract_option_type(ty: &syn::Type) -> (bool, String) {
252 if let syn::Type::Path(type_path) = ty {
253 if let Some(segment) = type_path.path.segments.last() {
254 if segment.ident == "Option" {
255 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
256 if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
257 return (true, inner.to_token_stream().to_string());
258 }
259 }
260 }
261 }
262 }
263 (false, String::new())
264}
265
// Unit tests for `parse_entity_source`: each test feeds a small Rust source
// string through the parser and checks the resulting `ParsedEntity`.
#[cfg(test)]
mod tests {
    use super::*;

    // --- Struct-level `#[sqlx_gen(kind, table)]`: table vs. view ---

    #[test]
    fn test_parse_simple_table() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
                pub name: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.struct_name, "Users");
        assert_eq!(entity.table_name, "users");
        assert!(!entity.is_view);
        assert_eq!(entity.fields.len(), 2);
    }

    #[test]
    fn test_parse_view() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "view", table = "active_users")]
            pub struct ActiveUsers {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.is_view);
        assert_eq!(entity.table_name, "active_users");
    }

    #[test]
    fn test_parse_table_not_view() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.is_view);
    }

    // --- Field-level `#[sqlx_gen(primary_key)]` ---

    #[test]
    fn test_parse_primary_key() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                pub name: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_primary_key);
        assert!(!entity.fields[1].is_primary_key);
    }

    #[test]
    fn test_composite_primary_key() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "user_roles")]
            pub struct UserRoles {
                #[sqlx_gen(primary_key)]
                pub user_id: i32,
                #[sqlx_gen(primary_key)]
                pub role_id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_primary_key);
        assert!(entity.fields[1].is_primary_key);
    }

    #[test]
    fn test_no_primary_key() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "logs")]
            pub struct Logs {
                pub message: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.fields[0].is_primary_key);
    }

    // --- Column naming via `#[sqlx(rename = ...)]` ---

    #[test]
    fn test_sqlx_rename() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "connector")]
            pub struct Connector {
                #[sqlx(rename = "type")]
                pub connector_type: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.fields[0].rust_name, "connector_type");
        assert_eq!(entity.fields[0].column_name, "type");
    }

    #[test]
    fn test_no_rename_uses_field_name() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub name: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.fields[0].rust_name, "name");
        assert_eq!(entity.fields[0].column_name, "name");
    }

    // --- Nullability and type strings ---

    #[test]
    fn test_option_field_nullable() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub email: Option<String>,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_nullable);
        assert_eq!(entity.fields[0].inner_type, "String");
    }

    #[test]
    fn test_non_option_not_nullable() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.fields[0].is_nullable);
        assert_eq!(entity.fields[0].inner_type, "i32");
    }

    #[test]
    fn test_option_complex_type() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub created_at: Option<chrono::NaiveDateTime>,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_nullable);
        // `contains` rather than equality: multi-token type strings carry token spacing.
        assert!(entity.fields[0].inner_type.contains("NaiveDateTime"));
    }

    #[test]
    fn test_rust_type_preserved() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: uuid::Uuid,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].rust_type.contains("Uuid"));
    }

    // --- Error paths ---

    #[test]
    fn test_no_from_row_struct() {
        let source = r#"
            pub struct NotAnEntity {
                pub id: i32,
            }
        "#;
        let result = parse_entity_source(source);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_source() {
        let result = parse_entity_source("");
        assert!(result.is_err());
    }

    // --- Fallbacks and combined attributes ---

    #[test]
    fn test_fallback_table_name_to_struct_name() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.table_name, "Users");
    }

    #[test]
    fn test_pk_with_rename() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "items")]
            pub struct Items {
                #[sqlx_gen(primary_key)]
                #[sqlx(rename = "itemID")]
                pub item_id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        let f = &entity.fields[0];
        assert!(f.is_primary_key);
        assert_eq!(f.column_name, "itemID");
        assert_eq!(f.rust_name, "item_id");
    }

    #[test]
    fn test_full_entity() {
        let source = r#"
            #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                pub name: String,
                pub email: Option<String>,
                #[sqlx(rename = "createdAt")]
                pub created_at: chrono::NaiveDateTime,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.struct_name, "Users");
        assert_eq!(entity.table_name, "users");
        assert!(!entity.is_view);
        assert_eq!(entity.fields.len(), 4);

        assert!(entity.fields[0].is_primary_key);
        assert_eq!(entity.fields[0].rust_name, "id");

        assert!(!entity.fields[1].is_primary_key);
        assert_eq!(entity.fields[1].rust_type, "String");

        assert!(entity.fields[2].is_nullable);
        assert_eq!(entity.fields[2].inner_type, "String");

        assert_eq!(entity.fields[3].column_name, "createdAt");
        assert_eq!(entity.fields[3].rust_name, "created_at");
    }

    // --- `use` import extraction (serde/sqlx filtered out) ---

    #[test]
    fn test_imports_extracted() {
        let source = r#"
            use chrono::{DateTime, Utc};
            use uuid::Uuid;
            use serde::{Serialize, Deserialize};

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: Uuid,
                pub created_at: DateTime<Utc>,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 2);
        assert!(entity.imports.iter().any(|i| i.contains("chrono")));
        assert!(entity.imports.iter().any(|i| i.contains("uuid")));
        assert!(!entity.imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_imports_empty_when_none() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.imports.is_empty());
    }

    #[test]
    fn test_imports_keep_serde_json() {
        let source = r#"
            use serde::{Serialize, Deserialize};
            use serde_json::Value;

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub data: Value,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 1);
        assert!(entity.imports[0].contains("serde_json"));
    }

    #[test]
    fn test_imports_exclude_sqlx() {
        let source = r#"
            use sqlx::types::Uuid;
            use chrono::NaiveDateTime;

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 1);
        assert!(entity.imports[0].contains("chrono"));
    }
}