1use std::path::Path;
2
3use quote::ToTokens;
4
/// One named field of an entity struct, as extracted from its source code.
#[derive(Debug, Clone)]
pub struct ParsedField {
    /// The Rust identifier of the field.
    pub rust_name: String,
    /// The database column name: the value of `#[sqlx(rename = "...")]` when
    /// present, otherwise identical to `rust_name`.
    pub column_name: String,
    /// The field's complete type, rendered from its token stream (tokens may
    /// therefore be space-separated, e.g. `Option < String >`).
    pub rust_type: String,
    /// `true` when the field's type is a path whose last segment is `Option`
    /// with a type argument.
    pub is_nullable: bool,
    /// The type wrapped by `Option<...>` when `is_nullable` is set; otherwise
    /// the same string as `rust_type`.
    pub inner_type: String,
    /// `true` when a `#[sqlx_gen(...)]` attribute on the field mentions
    /// `primary_key`.
    pub is_primary_key: bool,
}
21
/// An entity struct (the first `FromRow`-deriving struct of a source file)
/// together with the metadata needed to regenerate code for it.
#[derive(Debug, Clone)]
pub struct ParsedEntity {
    /// The Rust identifier of the struct.
    pub struct_name: String,
    /// The table/view name from `#[sqlx_gen(table = "...")]`, falling back to
    /// `struct_name` when the attribute is absent.
    pub table_name: String,
    /// `true` when the struct is annotated `#[sqlx_gen(kind = "view")]`.
    pub is_view: bool,
    /// The struct's named fields, in declaration order.
    pub fields: Vec<ParsedField>,
    /// Normalized top-level `use` statements from the same source file,
    /// excluding `serde`/`sqlx` imports.
    pub imports: Vec<String>,
}
36
37pub fn parse_entity_file(path: &Path) -> crate::error::Result<ParsedEntity> {
39 let source = std::fs::read_to_string(path).map_err(crate::error::Error::Io)?;
40 parse_entity_source(&source).map_err(|e| {
41 crate::error::Error::Config(format!("{}: {}", path.display(), e))
42 })
43}
44
45pub fn parse_entity_source(source: &str) -> Result<ParsedEntity, String> {
47 let syntax = syn::parse_file(source).map_err(|e| format!("Failed to parse: {}", e))?;
48
49 let imports = extract_use_imports(&syntax);
51
52 for item in &syntax.items {
53 if let syn::Item::Struct(item_struct) = item {
54 if has_from_row_derive(item_struct) {
55 let mut entity = extract_entity(item_struct)?;
56 entity.imports = imports;
57 return Ok(entity);
58 }
59 }
60 }
61
62 Err("No struct with sqlx::FromRow derive found".to_string())
63}
64
65fn has_from_row_derive(item: &syn::ItemStruct) -> bool {
67 for attr in &item.attrs {
68 if attr.path().is_ident("derive") {
69 let tokens = attr.meta.to_token_stream().to_string();
70 if tokens.contains("FromRow") {
71 return true;
72 }
73 }
74 }
75 false
76}
77
78fn extract_use_imports(file: &syn::File) -> Vec<String> {
81 file.items
82 .iter()
83 .filter_map(|item| {
84 if let syn::Item::Use(use_item) = item {
85 let text = use_item.to_token_stream().to_string();
86 if text.contains("serde") || text.contains("sqlx") {
88 return None;
89 }
90 let normalized = normalize_use_statement(&text);
92 Some(normalized)
93 } else {
94 None
95 }
96 })
97 .collect()
98}
99
/// Tidies a token-stream rendering of a `use` statement into conventional
/// Rust spacing, e.g. `use a :: { B , C } ;` becomes `use a::{B, C};`.
///
/// The replacements are applied sequentially over the whole string, in the
/// listed order.
fn normalize_use_statement(s: &str) -> String {
    const FIXUPS: [(&str, &str); 7] = [
        (" :: ", "::"),
        (":: ", "::"),
        (" ::", "::"),
        ("{ ", "{"),
        (" }", "}"),
        (" ,", ","),
        (" ;", ";"),
    ];

    FIXUPS
        .iter()
        .fold(s.to_string(), |text, &(from, to)| text.replace(from, to))
}
110
111fn extract_entity(item: &syn::ItemStruct) -> Result<ParsedEntity, String> {
113 let struct_name = item.ident.to_string();
114
115 let (kind, table_name) = parse_sqlx_gen_struct_attrs(&item.attrs);
116 let is_view = kind.as_deref() == Some("view");
117
118 let table_name = table_name.unwrap_or_else(|| struct_name.clone());
120
121 let fields = match &item.fields {
122 syn::Fields::Named(named) => {
123 named
124 .named
125 .iter()
126 .map(extract_field)
127 .collect::<Result<Vec<_>, _>>()?
128 }
129 _ => return Err("Expected named fields".to_string()),
130 };
131
132 Ok(ParsedEntity {
133 struct_name,
134 table_name,
135 is_view,
136 fields,
137 imports: Vec::new(), })
139}
140
141fn parse_sqlx_gen_struct_attrs(attrs: &[syn::Attribute]) -> (Option<String>, Option<String>) {
144 let mut kind = None;
145 let mut table_name = None;
146
147 for attr in attrs {
148 if attr.path().is_ident("sqlx_gen") {
149 let tokens = attr.meta.to_token_stream().to_string();
150 if let Some(k) = extract_attr_value(&tokens, "kind") {
151 kind = Some(k);
152 }
153 if let Some(t) = extract_attr_value(&tokens, "table") {
154 table_name = Some(t);
155 }
156 }
157 }
158
159 (kind, table_name)
160}
161
/// Extracts the string value of `key = "..."` from a rendered attribute token
/// string such as `sqlx_gen (kind = "table" , table = "users")`.
///
/// The key must start at an identifier boundary: the character before it must
/// not be alphanumeric or `_`. This prevents `table` from matching inside
/// `foo_table = "..."`. The first boundary-respecting occurrence is used.
/// Returns `None` when the key is absent or its value has no closing quote.
/// Escaped quotes inside the value are not handled; the value ends at the next `"`.
fn extract_attr_value(tokens: &str, key: &str) -> Option<String> {
    let pattern = format!("{} = \"", key);

    let (key_start, _) = tokens.match_indices(pattern.as_str()).find(|&(idx, _)| {
        !matches!(
            tokens[..idx].chars().next_back(),
            Some(prev) if prev.is_alphanumeric() || prev == '_'
        )
    })?;

    let value_and_rest = &tokens[key_start + pattern.len()..];
    let value_end = value_and_rest.find('"')?;
    Some(value_and_rest[..value_end].to_string())
}
171
172fn extract_field(field: &syn::Field) -> Result<ParsedField, String> {
174 let rust_name = field
175 .ident
176 .as_ref()
177 .ok_or("Unnamed field")?
178 .to_string();
179
180 let column_name = get_sqlx_rename(&field.attrs).unwrap_or_else(|| rust_name.clone());
181 let is_primary_key = has_sqlx_gen_primary_key(&field.attrs);
182
183 let rust_type = field.ty.to_token_stream().to_string();
184 let (is_nullable, inner_type) = extract_option_type(&field.ty);
185 let inner_type = if is_nullable {
186 inner_type
187 } else {
188 rust_type.clone()
189 };
190
191 Ok(ParsedField {
192 rust_name,
193 column_name,
194 rust_type,
195 is_nullable,
196 inner_type,
197 is_primary_key,
198 })
199}
200
201fn has_sqlx_gen_primary_key(attrs: &[syn::Attribute]) -> bool {
203 for attr in attrs {
204 if attr.path().is_ident("sqlx_gen") {
205 let tokens = attr.meta.to_token_stream().to_string();
206 if tokens.contains("primary_key") {
207 return true;
208 }
209 }
210 }
211 false
212}
213
214fn get_sqlx_rename(attrs: &[syn::Attribute]) -> Option<String> {
216 for attr in attrs {
217 if attr.path().is_ident("sqlx") {
218 let tokens = attr.meta.to_token_stream().to_string();
219 return extract_attr_value(&tokens, "rename");
220 }
221 }
222 None
223}
224
225fn extract_option_type(ty: &syn::Type) -> (bool, String) {
227 if let syn::Type::Path(type_path) = ty {
228 if let Some(segment) = type_path.path.segments.last() {
229 if segment.ident == "Option" {
230 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
231 if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
232 return (true, inner.to_token_stream().to_string());
233 }
234 }
235 }
236 }
237 }
238 (false, String::new())
239}
240
// Unit tests for the entity parser. Each test feeds an inline Rust snippet to
// `parse_entity_source` and checks the extracted metadata.
#[cfg(test)]
mod tests {
    use super::*;

    // A basic table entity: struct name, explicit table name, non-view kind,
    // and field count are all picked up.
    #[test]
    fn test_parse_simple_table() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "users")]
        pub struct Users {
            pub id: i32,
            pub name: String,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.struct_name, "Users");
        assert_eq!(entity.table_name, "users");
        assert!(!entity.is_view);
        assert_eq!(entity.fields.len(), 2);
    }

    // `kind = "view"` sets `is_view`.
    #[test]
    fn test_parse_view() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "view", table = "active_users")]
        pub struct ActiveUsers {
            pub id: i32,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.is_view);
        assert_eq!(entity.table_name, "active_users");
    }

    // `kind = "table"` leaves `is_view` unset.
    #[test]
    fn test_parse_table_not_view() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "users")]
        pub struct Users {
            pub id: i32,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.is_view);
    }

    // Only the field annotated `#[sqlx_gen(primary_key)]` is flagged.
    #[test]
    fn test_parse_primary_key() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "users")]
        pub struct Users {
            #[sqlx_gen(primary_key)]
            pub id: i32,
            pub name: String,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_primary_key);
        assert!(!entity.fields[1].is_primary_key);
    }

    // Several fields may each be part of the primary key.
    #[test]
    fn test_composite_primary_key() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "user_roles")]
        pub struct UserRoles {
            #[sqlx_gen(primary_key)]
            pub user_id: i32,
            #[sqlx_gen(primary_key)]
            pub role_id: i32,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_primary_key);
        assert!(entity.fields[1].is_primary_key);
    }

    // Without any `primary_key` annotation no field is flagged.
    #[test]
    fn test_no_primary_key() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "logs")]
        pub struct Logs {
            pub message: String,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.fields[0].is_primary_key);
    }

    // `#[sqlx(rename = "...")]` changes the column name but not the Rust name.
    #[test]
    fn test_sqlx_rename() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "connector")]
        pub struct Connector {
            #[sqlx(rename = "type")]
            pub connector_type: String,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.fields[0].rust_name, "connector_type");
        assert_eq!(entity.fields[0].column_name, "type");
    }

    // Without a rename the column name equals the field name.
    #[test]
    fn test_no_rename_uses_field_name() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "users")]
        pub struct Users {
            pub name: String,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.fields[0].rust_name, "name");
        assert_eq!(entity.fields[0].column_name, "name");
    }

    // `Option<T>` marks the field nullable and exposes `T` as `inner_type`.
    #[test]
    fn test_option_field_nullable() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "users")]
        pub struct Users {
            pub email: Option<String>,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_nullable);
        assert_eq!(entity.fields[0].inner_type, "String");
    }

    // Non-Option fields are not nullable; `inner_type` repeats the full type.
    #[test]
    fn test_non_option_not_nullable() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "users")]
        pub struct Users {
            pub id: i32,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(!entity.fields[0].is_nullable);
        assert_eq!(entity.fields[0].inner_type, "i32");
    }

    // Path-qualified inner types are preserved. `contains` is used because
    // the token-stream rendering may insert spaces around `::`.
    #[test]
    fn test_option_complex_type() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "users")]
        pub struct Users {
            pub created_at: Option<chrono::NaiveDateTime>,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].is_nullable);
        assert!(entity.fields[0].inner_type.contains("NaiveDateTime"));
    }

    // The full Rust type is kept in `rust_type`.
    #[test]
    fn test_rust_type_preserved() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "users")]
        pub struct Users {
            pub id: uuid::Uuid,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.fields[0].rust_type.contains("Uuid"));
    }

    // A file with no `FromRow` struct is an error.
    #[test]
    fn test_no_from_row_struct() {
        let source = r#"
        pub struct NotAnEntity {
            pub id: i32,
        }
        "#;
        let result = parse_entity_source(source);
        assert!(result.is_err());
    }

    // Empty input parses as an empty file, which has no entity: error.
    #[test]
    fn test_empty_source() {
        let result = parse_entity_source("");
        assert!(result.is_err());
    }

    // Without `#[sqlx_gen(table = ...)]` the struct name is used verbatim.
    #[test]
    fn test_fallback_table_name_to_struct_name() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        pub struct Users {
            pub id: i32,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.table_name, "Users");
    }

    // Primary-key and rename annotations combine on the same field.
    #[test]
    fn test_pk_with_rename() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "items")]
        pub struct Items {
            #[sqlx_gen(primary_key)]
            #[sqlx(rename = "itemID")]
            pub item_id: i32,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        let f = &entity.fields[0];
        assert!(f.is_primary_key);
        assert_eq!(f.column_name, "itemID");
        assert_eq!(f.rust_name, "item_id");
    }

    // End-to-end check across all supported annotations in one struct.
    #[test]
    fn test_full_entity() {
        let source = r#"
        #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "users")]
        pub struct Users {
            #[sqlx_gen(primary_key)]
            pub id: i32,
            pub name: String,
            pub email: Option<String>,
            #[sqlx(rename = "createdAt")]
            pub created_at: chrono::NaiveDateTime,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.struct_name, "Users");
        assert_eq!(entity.table_name, "users");
        assert!(!entity.is_view);
        assert_eq!(entity.fields.len(), 4);

        assert!(entity.fields[0].is_primary_key);
        assert_eq!(entity.fields[0].rust_name, "id");

        assert!(!entity.fields[1].is_primary_key);
        assert_eq!(entity.fields[1].rust_type, "String");

        assert!(entity.fields[2].is_nullable);
        assert_eq!(entity.fields[2].inner_type, "String");

        assert_eq!(entity.fields[3].column_name, "createdAt");
        assert_eq!(entity.fields[3].rust_name, "created_at");
    }

    // Non-serde `use` statements are carried over; serde imports are dropped.
    #[test]
    fn test_imports_extracted() {
        let source = r#"
        use chrono::{DateTime, Utc};
        use uuid::Uuid;
        use serde::{Serialize, Deserialize};

        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "users")]
        pub struct Users {
            pub id: Uuid,
            pub created_at: DateTime<Utc>,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 2);
        assert!(entity.imports.iter().any(|i| i.contains("chrono")));
        assert!(entity.imports.iter().any(|i| i.contains("uuid")));
        assert!(!entity.imports.iter().any(|i| i.contains("serde")));
    }

    // A file with no `use` statements yields no imports.
    #[test]
    fn test_imports_empty_when_none() {
        let source = r#"
        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "users")]
        pub struct Users {
            pub id: i32,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert!(entity.imports.is_empty());
    }

    // sqlx imports are dropped just like serde ones.
    #[test]
    fn test_imports_exclude_sqlx() {
        let source = r#"
        use sqlx::types::Uuid;
        use chrono::NaiveDateTime;

        #[derive(Debug, Clone, sqlx::FromRow)]
        #[sqlx_gen(kind = "table", table = "users")]
        pub struct Users {
            pub id: i32,
        }
        "#;
        let entity = parse_entity_source(source).unwrap();
        assert_eq!(entity.imports.len(), 1);
        assert!(entity.imports[0].contains("chrono"));
    }
}