1use std::path::Path;
2
3use quote::ToTokens;
4
/// One named field of an entity struct, as recovered from its source text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedField {
    /// The field's Rust identifier.
    pub rust_name: String,
    /// The database column name: the `#[sqlx(rename = "...")]` value when present,
    /// otherwise the same text as `rust_name`.
    pub column_name: String,
    /// The field's full declared type, rendered from its token stream.
    pub rust_type: String,
    /// `true` when the declared type is `Option<...>` (matched on the last path segment).
    pub is_nullable: bool,
    /// The type wrapped by `Option<...>` for nullable fields; otherwise equal to `rust_type`.
    pub inner_type: String,
    /// `true` when the field is marked with a `#[sqlx_gen(primary_key)]` attribute.
    pub is_primary_key: bool,
}

/// An entity struct (one deriving `sqlx::FromRow`) recovered from a source file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedEntity {
    /// The struct's identifier.
    pub struct_name: String,
    /// The `table` value from `#[sqlx_gen(...)]`, falling back to `struct_name`.
    pub table_name: String,
    /// The `schema` value from `#[sqlx_gen(...)]`, if one was given.
    pub schema_name: Option<String>,
    /// `true` when the struct is annotated `#[sqlx_gen(kind = "view")]`.
    pub is_view: bool,
    /// The struct's named fields, in declaration order.
    pub fields: Vec<ParsedField>,
    /// The file's top-level `use` statements, normalized, minus serde/sqlx imports.
    pub imports: Vec<String>,
}
38
39pub fn parse_entity_file(path: &Path) -> crate::error::Result<ParsedEntity> {
41 let source = std::fs::read_to_string(path).map_err(crate::error::Error::Io)?;
42 parse_entity_source(&source).map_err(|e| {
43 crate::error::Error::Config(format!("{}: {}", path.display(), e))
44 })
45}
46
47pub fn parse_entity_source(source: &str) -> Result<ParsedEntity, String> {
49 let syntax = syn::parse_file(source).map_err(|e| format!("Failed to parse: {}", e))?;
50
51 let imports = extract_use_imports(&syntax);
53
54 for item in &syntax.items {
55 if let syn::Item::Struct(item_struct) = item {
56 if has_from_row_derive(item_struct) {
57 let mut entity = extract_entity(item_struct)?;
58 entity.imports = imports;
59 return Ok(entity);
60 }
61 }
62 }
63
64 Err("No struct with sqlx::FromRow derive found".to_string())
65}
66
67fn has_from_row_derive(item: &syn::ItemStruct) -> bool {
69 for attr in &item.attrs {
70 if attr.path().is_ident("derive") {
71 let tokens = attr.meta.to_token_stream().to_string();
72 if tokens.contains("FromRow") {
73 return true;
74 }
75 }
76 }
77 false
78}
79
80fn extract_use_imports(file: &syn::File) -> Vec<String> {
83 file.items
84 .iter()
85 .filter_map(|item| {
86 if let syn::Item::Use(use_item) = item {
87 let text = use_item.to_token_stream().to_string();
88 if (text.contains("serde") && !text.contains("serde_")) || text.contains("sqlx") {
90 return None;
91 }
92 let normalized = normalize_use_statement(&text);
94 Some(normalized)
95 } else {
96 None
97 }
98 })
99 .collect()
100}
101
/// Tidies the spacing produced by token-stream stringification of a `use` item,
/// e.g. `use chrono :: { DateTime , Utc } ;` becomes `use chrono::{DateTime, Utc};`.
fn normalize_use_statement(s: &str) -> String {
    // Applied in order. The two-sided `" :: "` must collapse before the one-sided forms.
    const REWRITES: [(&str, &str); 7] = [
        (" :: ", "::"),
        (":: ", "::"),
        (" ::", "::"),
        ("{ ", "{"),
        (" }", "}"),
        (" ,", ","),
        (" ;", ";"),
    ];

    REWRITES
        .iter()
        .fold(s.to_string(), |text, &(from, to)| text.replace(from, to))
}
112
113fn extract_entity(item: &syn::ItemStruct) -> Result<ParsedEntity, String> {
115 let struct_name = item.ident.to_string();
116
117 let (kind, schema_name, table_name) = parse_sqlx_gen_struct_attrs(&item.attrs);
118 let is_view = kind.as_deref() == Some("view");
119
120 let table_name = table_name.unwrap_or_else(|| struct_name.clone());
122
123 let fields = match &item.fields {
124 syn::Fields::Named(named) => {
125 named
126 .named
127 .iter()
128 .map(extract_field)
129 .collect::<Result<Vec<_>, _>>()?
130 }
131 _ => return Err("Expected named fields".to_string()),
132 };
133
134 Ok(ParsedEntity {
135 struct_name,
136 table_name,
137 schema_name,
138 is_view,
139 fields,
140 imports: Vec::new(), })
142}
143
144fn parse_sqlx_gen_struct_attrs(attrs: &[syn::Attribute]) -> (Option<String>, Option<String>, Option<String>) {
147 let mut kind = None;
148 let mut schema_name = None;
149 let mut table_name = None;
150
151 for attr in attrs {
152 if attr.path().is_ident("sqlx_gen") {
153 let tokens = attr.meta.to_token_stream().to_string();
154 if let Some(k) = extract_attr_value(&tokens, "kind") {
155 kind = Some(k);
156 }
157 if let Some(s) = extract_attr_value(&tokens, "schema") {
158 schema_name = Some(s);
159 }
160 if let Some(t) = extract_attr_value(&tokens, "table") {
161 table_name = Some(t);
162 }
163 }
164 }
165
166 (kind, schema_name, table_name)
167}
168
/// Returns the string value of `key = "..."` inside stringified attribute tokens.
///
/// `tokens` is expected to use the `key = "value"` spacing produced by
/// token-stream stringification. A match only counts when `key` is not preceded
/// by an identifier character. This keeps `kind` from matching inside `subkind`,
/// for example. The value ends at the next `"`, so escaped quotes are not supported.
/// Returns `None` if the key is absent or its value is unterminated.
fn extract_attr_value(tokens: &str, key: &str) -> Option<String> {
    let pattern = format!("{} = \"", key);
    tokens.match_indices(&pattern).find_map(|(idx, _)| {
        let inside_longer_ident = matches!(
            tokens[..idx].chars().next_back(),
            Some(c) if c.is_alphanumeric() || c == '_'
        );
        if inside_longer_ident {
            return None;
        }
        let rest = &tokens[idx + pattern.len()..];
        rest.find('"').map(|end| rest[..end].to_string())
    })
}
178
179fn extract_field(field: &syn::Field) -> Result<ParsedField, String> {
181 let rust_name = field
182 .ident
183 .as_ref()
184 .ok_or("Unnamed field")?
185 .to_string();
186
187 let column_name = get_sqlx_rename(&field.attrs).unwrap_or_else(|| rust_name.clone());
188 let is_primary_key = has_sqlx_gen_primary_key(&field.attrs);
189
190 let rust_type = field.ty.to_token_stream().to_string();
191 let (is_nullable, inner_type) = extract_option_type(&field.ty);
192 let inner_type = if is_nullable {
193 inner_type
194 } else {
195 rust_type.clone()
196 };
197
198 Ok(ParsedField {
199 rust_name,
200 column_name,
201 rust_type,
202 is_nullable,
203 inner_type,
204 is_primary_key,
205 })
206}
207
208fn has_sqlx_gen_primary_key(attrs: &[syn::Attribute]) -> bool {
210 for attr in attrs {
211 if attr.path().is_ident("sqlx_gen") {
212 let tokens = attr.meta.to_token_stream().to_string();
213 if tokens.contains("primary_key") {
214 return true;
215 }
216 }
217 }
218 false
219}
220
221fn get_sqlx_rename(attrs: &[syn::Attribute]) -> Option<String> {
223 for attr in attrs {
224 if attr.path().is_ident("sqlx") {
225 let tokens = attr.meta.to_token_stream().to_string();
226 return extract_attr_value(&tokens, "rename");
227 }
228 }
229 None
230}
231
232fn extract_option_type(ty: &syn::Type) -> (bool, String) {
234 if let syn::Type::Path(type_path) = ty {
235 if let Some(segment) = type_path.path.segments.last() {
236 if segment.ident == "Option" {
237 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
238 if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
239 return (true, inner.to_token_stream().to_string());
240 }
241 }
242 }
243 }
244 }
245 (false, String::new())
246}
247
// Unit tests for the entity parser. Each test feeds a small Rust source string
// to `parse_entity_source` and inspects the resulting `ParsedEntity`.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Struct-level `#[sqlx_gen(...)]` attributes: table name and view/table kind ----

    #[test]
    fn test_parse_simple_table() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
                pub name: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.struct_name, "Users");
        assert_eq!(entity.table_name, "users");
        assert!(!entity.is_view);
        assert_eq!(entity.fields.len(), 2);
    }

    #[test]
    fn test_parse_view() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "view", table = "active_users")]
            pub struct ActiveUsers {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.is_view);
        assert_eq!(entity.table_name, "active_users");
    }

    #[test]
    fn test_parse_table_not_view() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.is_view);
    }

    // ---- Field-level `#[sqlx_gen(primary_key)]` markers ----

    #[test]
    fn test_parse_primary_key() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                pub name: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_primary_key);
        assert!(!entity.fields[1].is_primary_key);
    }

    // Multiple fields may be marked; each is reported independently.
    #[test]
    fn test_composite_primary_key() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "user_roles")]
            pub struct UserRoles {
                #[sqlx_gen(primary_key)]
                pub user_id: i32,
                #[sqlx_gen(primary_key)]
                pub role_id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_primary_key);
        assert!(entity.fields[1].is_primary_key);
    }

    #[test]
    fn test_no_primary_key() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "logs")]
            pub struct Logs {
                pub message: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.fields[0].is_primary_key);
    }

    // ---- Column naming via `#[sqlx(rename = "...")]` ----

    #[test]
    fn test_sqlx_rename() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "connector")]
            pub struct Connector {
                #[sqlx(rename = "type")]
                pub connector_type: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.fields[0].rust_name, "connector_type");
        assert_eq!(entity.fields[0].column_name, "type");
    }

    #[test]
    fn test_no_rename_uses_field_name() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub name: String,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.fields[0].rust_name, "name");
        assert_eq!(entity.fields[0].column_name, "name");
    }

    // ---- Nullability and type text ----

    #[test]
    fn test_option_field_nullable() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub email: Option<String>,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_nullable);
        assert_eq!(entity.fields[0].inner_type, "String");
    }

    // For non-Option fields, inner_type mirrors the full type.
    #[test]
    fn test_non_option_not_nullable() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.fields[0].is_nullable);
        assert_eq!(entity.fields[0].inner_type, "i32");
    }

    // `contains` is used because token stringification inserts spaces around `::`.
    #[test]
    fn test_option_complex_type() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub created_at: Option<chrono::NaiveDateTime>,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_nullable);
        assert!(entity.fields[0].inner_type.contains("NaiveDateTime"));
    }

    #[test]
    fn test_rust_type_preserved() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: uuid::Uuid,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].rust_type.contains("Uuid"));
    }

    // ---- Error cases: no FromRow struct, empty input ----

    #[test]
    fn test_no_from_row_struct() {
        let source = r#"
            pub struct NotAnEntity {
                pub id: i32,
            }
        "#;
        let result = parse_entity_source(source);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_source() {
        let result = parse_entity_source("");
        assert!(result.is_err());
    }

    // ---- Fallbacks and combined attributes ----

    // Without `#[sqlx_gen(table = ...)]` the struct name is used verbatim.
    #[test]
    fn test_fallback_table_name_to_struct_name() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.table_name, "Users");
    }

    #[test]
    fn test_pk_with_rename() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "items")]
            pub struct Items {
                #[sqlx_gen(primary_key)]
                #[sqlx(rename = "itemID")]
                pub item_id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        let f = &entity.fields[0];
        assert!(f.is_primary_key);
        assert_eq!(f.column_name, "itemID");
        assert_eq!(f.rust_name, "item_id");
    }

    #[test]
    fn test_full_entity() {
        let source = r#"
            #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                #[sqlx_gen(primary_key)]
                pub id: i32,
                pub name: String,
                pub email: Option<String>,
                #[sqlx(rename = "createdAt")]
                pub created_at: chrono::NaiveDateTime,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.struct_name, "Users");
        assert_eq!(entity.table_name, "users");
        assert!(!entity.is_view);
        assert_eq!(entity.fields.len(), 4);

        assert!(entity.fields[0].is_primary_key);
        assert_eq!(entity.fields[0].rust_name, "id");

        assert!(!entity.fields[1].is_primary_key);
        assert_eq!(entity.fields[1].rust_type, "String");

        assert!(entity.fields[2].is_nullable);
        assert_eq!(entity.fields[2].inner_type, "String");

        assert_eq!(entity.fields[3].column_name, "createdAt");
        assert_eq!(entity.fields[3].rust_name, "created_at");
    }

    // ---- Import extraction: serde and sqlx imports are filtered, others kept ----

    #[test]
    fn test_imports_extracted() {
        let source = r#"
            use chrono::{DateTime, Utc};
            use uuid::Uuid;
            use serde::{Serialize, Deserialize};

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: Uuid,
                pub created_at: DateTime<Utc>,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 2);
        assert!(entity.imports.iter().any(|i| i.contains("chrono")));
        assert!(entity.imports.iter().any(|i| i.contains("uuid")));
        assert!(!entity.imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_imports_empty_when_none() {
        let source = r#"
            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.imports.is_empty());
    }

    // `serde_json` is a distinct crate from `serde` and must survive the filter.
    #[test]
    fn test_imports_keep_serde_json() {
        let source = r#"
            use serde::{Serialize, Deserialize};
            use serde_json::Value;

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub data: Value,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 1);
        assert!(entity.imports[0].contains("serde_json"));
    }

    #[test]
    fn test_imports_exclude_sqlx() {
        let source = r#"
            use sqlx::types::Uuid;
            use chrono::NaiveDateTime;

            #[derive(Debug, Clone, sqlx::FromRow)]
            #[sqlx_gen(kind = "table", table = "users")]
            pub struct Users {
                pub id: i32,
            }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 1);
        assert!(entity.imports[0].contains("chrono"));
    }
}