1use clap::{Parser, Subcommand};
2use std::collections::HashMap;
3use std::path::PathBuf;
4
5#[derive(Parser, Debug)]
6#[command(name = "sqlx-gen", about = "Generate Rust structs from database schema")]
7pub struct Cli {
8 #[command(subcommand)]
9 pub command: Command,
10}
11
12#[derive(Subcommand, Debug)]
13pub enum Command {
14 Generate {
16 #[command(subcommand)]
17 subcommand: GenerateCommand,
18 },
19}
20
21#[derive(Subcommand, Debug)]
22pub enum GenerateCommand {
23 Entities(EntitiesArgs),
25 Crud(CrudArgs),
27}
28
29#[derive(Parser, Debug)]
30pub struct DatabaseArgs {
31 #[arg(short = 'u', long, env = "DATABASE_URL")]
33 pub database_url: String,
34
35 #[arg(short = 's', long, value_delimiter = ',', default_value = "public")]
37 pub schemas: Vec<String>,
38}
39
40impl DatabaseArgs {
41 pub fn database_kind(&self) -> crate::error::Result<DatabaseKind> {
42 let url = &self.database_url;
43 if url.starts_with("postgres://") || url.starts_with("postgresql://") {
44 Ok(DatabaseKind::Postgres)
45 } else if url.starts_with("mysql://") {
46 Ok(DatabaseKind::Mysql)
47 } else if url.starts_with("sqlite://") || url.starts_with("sqlite:") {
48 Ok(DatabaseKind::Sqlite)
49 } else {
50 Err(crate::error::Error::Config(
51 "Cannot detect database type from URL. Expected postgres://, mysql://, or sqlite:// prefix.".to_string(),
52 ))
53 }
54 }
55}
56
57#[derive(Parser, Debug)]
58pub struct EntitiesArgs {
59 #[command(flatten)]
60 pub db: DatabaseArgs,
61
62 #[arg(short = 'o', long, default_value = "src/models")]
64 pub output_dir: PathBuf,
65
66 #[arg(short = 'D', long, value_delimiter = ',')]
68 pub derives: Vec<String>,
69
70 #[arg(short = 'T', long, value_delimiter = ',')]
72 pub type_overrides: Vec<String>,
73
74 #[arg(short = 'S', long)]
76 pub single_file: bool,
77
78 #[arg(short = 't', long, value_delimiter = ',')]
80 pub tables: Option<Vec<String>>,
81
82 #[arg(short = 'x', long, value_delimiter = ',')]
84 pub exclude_tables: Option<Vec<String>>,
85
86 #[arg(short = 'v', long)]
88 pub views: bool,
89
90 #[arg(long, default_value = "chrono")]
92 pub time_crate: TimeCrate,
93
94 #[arg(short = 'n', long)]
96 pub dry_run: bool,
97}
98
99impl EntitiesArgs {
100 pub fn parse_type_overrides(&self) -> HashMap<String, String> {
101 self.type_overrides
102 .iter()
103 .filter_map(|s| {
104 let (k, v) = s.split_once('=')?;
105 Some((k.to_string(), v.to_string()))
106 })
107 .collect()
108 }
109}
110
111#[derive(Parser, Debug)]
112pub struct CrudArgs {
113 #[arg(short = 'f', long)]
115 pub entity_file: PathBuf,
116
117 #[arg(short = 'd', long)]
119 pub db_kind: String,
120
121 #[arg(short = 'e', long)]
124 pub entities_module: Option<String>,
125
126 #[arg(short = 'o', long, default_value = "src/crud")]
128 pub output_dir: PathBuf,
129
130 #[arg(short = 'm', long, value_delimiter = ',')]
132 pub methods: Vec<String>,
133
134
135 #[arg(short = 'q', long)]
137 pub query_macro: bool,
138
139 #[arg(short = 'p', long, default_value = "private")]
141 pub pool_visibility: PoolVisibility,
142
143 #[arg(short = 'n', long)]
145 pub dry_run: bool,
146}
147
148impl CrudArgs {
149 pub fn database_kind(&self) -> crate::error::Result<DatabaseKind> {
150 match self.db_kind.to_lowercase().as_str() {
151 "postgres" | "postgresql" | "pg" => Ok(DatabaseKind::Postgres),
152 "mysql" => Ok(DatabaseKind::Mysql),
153 "sqlite" => Ok(DatabaseKind::Sqlite),
154 other => Err(crate::error::Error::Config(format!(
155 "Unknown database kind '{}'. Expected: postgres, mysql, sqlite",
156 other
157 ))),
158 }
159 }
160
161 pub fn resolve_entities_module(&self) -> crate::error::Result<String> {
164 match &self.entities_module {
165 Some(m) => Ok(m.clone()),
166 None => module_path_from_file(&self.entity_file),
167 }
168 }
169}
170
171fn module_path_from_file(path: &std::path::Path) -> crate::error::Result<String> {
175 let path_str = path.to_string_lossy().replace('\\', "/");
176
177 let after_src = match path_str.rfind("/src/") {
178 Some(pos) => &path_str[pos + 5..],
179 None if path_str.starts_with("src/") => &path_str[4..],
180 _ => {
181 return Err(crate::error::Error::Config(format!(
182 "Cannot derive module path from '{}': no 'src/' found. Use --entities-module explicitly.",
183 path.display()
184 )));
185 }
186 };
187
188 let without_ext = after_src.strip_suffix(".rs").unwrap_or(after_src);
189 let module = without_ext.strip_suffix("/mod").unwrap_or(without_ext);
190
191 let module_path = format!("crate::{}", module.replace('/', "::"));
192 Ok(module_path)
193}
194
/// Database kinds distinguished by this tool.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DatabaseKind {
    /// Selected by `postgres://` / `postgresql://` URLs or the `postgres`, `postgresql`, `pg` names.
    Postgres,
    /// Selected by `mysql://` URLs or the `mysql` name.
    Mysql,
    /// Selected by `sqlite:` URLs or the `sqlite` name.
    Sqlite,
}
201
/// Time crate choice; the default is [`TimeCrate::Chrono`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum TimeCrate {
    /// Written `chrono` in text form.
    #[default]
    Chrono,
    /// Written `time` in text form.
    Time,
}

/// Parses the exact, case-sensitive names `chrono` and `time`.
impl std::str::FromStr for TimeCrate {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s == "chrono" {
            return Ok(Self::Chrono);
        }
        if s == "time" {
            return Ok(Self::Time);
        }
        Err(format!("Unknown time crate '{s}'. Expected: chrono, time"))
    }
}

/// Writes the same name that `FromStr` accepts for the variant.
impl std::fmt::Display for TimeCrate {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Chrono => "chrono",
            Self::Time => "time",
        };
        f.write_str(name)
    }
}
232
/// Pool visibility choice; the default is [`PoolVisibility::Private`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum PoolVisibility {
    /// Parsed from `private`.
    #[default]
    Private,
    /// Parsed from `pub`.
    Pub,
    /// Parsed from `pub(crate)`.
    PubCrate,
}

/// Parses the exact, case-sensitive spellings `private`, `pub` and `pub(crate)`.
impl std::str::FromStr for PoolVisibility {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s == "private" {
            Ok(Self::Private)
        } else if s == "pub" {
            Ok(Self::Pub)
        } else if s == "pub(crate)" {
            Ok(Self::PubCrate)
        } else {
            Err(format!(
                "Unknown pool visibility '{s}'. Expected: private, pub, pub(crate)"
            ))
        }
    }
}
256
/// Set of method flags, one boolean per method name.
///
/// `Default` yields every flag cleared; [`Methods::all`] yields every flag set.
#[derive(Clone, Debug, Default)]
pub struct Methods {
    pub get_all: bool,
    pub paginate: bool,
    pub get: bool,
    pub insert: bool,
    pub insert_many: bool,
    pub update: bool,
    pub overwrite: bool,
    pub delete: bool,
}

/// Method names listed in the "unknown method" error message.
const ALL_METHODS: &[&str] = &[
    "get_all",
    "paginate",
    "get",
    "insert",
    "insert_many",
    "update",
    "overwrite",
    "delete",
];

impl Methods {
    /// Builds the flag set from a list of method names.
    ///
    /// Each recognised name sets its flag. A `*` entry returns
    /// [`Methods::all`] immediately, so entries after it are not inspected.
    /// An empty list yields every flag cleared.
    ///
    /// # Errors
    ///
    /// Returns a message naming the first unrecognised entry together with the
    /// list of valid values.
    pub fn from_list(names: &[String]) -> Result<Self, String> {
        let mut selected = Self::default();
        for name in names {
            // Resolve the name to the flag it controls, then switch it on.
            let flag = match name.as_str() {
                "*" => return Ok(Self::all()),
                "get_all" => &mut selected.get_all,
                "paginate" => &mut selected.paginate,
                "get" => &mut selected.get,
                "insert" => &mut selected.insert,
                "insert_many" => &mut selected.insert_many,
                "update" => &mut selected.update,
                "overwrite" => &mut selected.overwrite,
                "delete" => &mut selected.delete,
                unknown => {
                    let valid = ALL_METHODS.join(", ");
                    return Err(format!(
                        "Unknown method '{unknown}'. Valid values: *, {valid}"
                    ));
                }
            };
            *flag = true;
        }
        Ok(selected)
    }

    /// Returns a set with every flag enabled.
    pub fn all() -> Self {
        let on = true;
        Self {
            get_all: on,
            paginate: on,
            get: on,
            insert: on,
            insert_many: on,
            update: on,
            overwrite: on,
            delete: on,
        }
    }
}
313
#[cfg(test)]
mod tests {
    use super::*;

    // ========== fixtures and helpers ==========

    /// Builds `DatabaseArgs` for `url` with the single schema `public`.
    fn make_db_args(url: &str) -> DatabaseArgs {
        DatabaseArgs {
            database_url: String::from(url),
            schemas: vec![String::from("public")],
        }
    }

    /// Builds `EntitiesArgs` with fixed defaults and the given raw type overrides.
    fn make_entities_args_with_overrides(overrides: Vec<&str>) -> EntitiesArgs {
        EntitiesArgs {
            db: make_db_args("postgres://localhost/db"),
            output_dir: PathBuf::from("out"),
            derives: Vec::new(),
            type_overrides: overrides.into_iter().map(str::to_owned).collect(),
            single_file: false,
            tables: None,
            exclude_tables: None,
            views: false,
            time_crate: TimeCrate::Chrono,
            dry_run: false,
        }
    }

    /// Runs URL-based database detection for `url`.
    fn kind_for(url: &str) -> crate::error::Result<DatabaseKind> {
        make_db_args(url).database_kind()
    }

    /// Parses the given raw overrides into a map.
    fn overrides_map(raw: Vec<&str>) -> HashMap<String, String> {
        make_entities_args_with_overrides(raw).parse_type_overrides()
    }

    /// Calls `Methods::from_list` with borrowed names.
    fn methods_from(names: &[&str]) -> Result<Methods, String> {
        let owned: Vec<String> = names.iter().map(|n| (*n).to_owned()).collect();
        Methods::from_list(&owned)
    }

    /// Flags of `m` in declaration order.
    fn flags(m: &Methods) -> [bool; 8] {
        [
            m.get_all,
            m.paginate,
            m.get,
            m.insert,
            m.insert_many,
            m.update,
            m.overwrite,
            m.delete,
        ]
    }

    /// Derives the module path for `path`.
    fn module_for(path: &str) -> crate::error::Result<String> {
        module_path_from_file(&PathBuf::from(path))
    }

    // ========== DatabaseArgs::database_kind ==========

    #[test]
    fn test_postgres_url() {
        assert_eq!(kind_for("postgres://localhost/db").unwrap(), DatabaseKind::Postgres);
    }

    #[test]
    fn test_postgresql_url() {
        assert_eq!(kind_for("postgresql://localhost/db").unwrap(), DatabaseKind::Postgres);
    }

    #[test]
    fn test_postgres_full_url() {
        assert_eq!(
            kind_for("postgres://user:pass@host:5432/db").unwrap(),
            DatabaseKind::Postgres
        );
    }

    #[test]
    fn test_mysql_url() {
        assert_eq!(kind_for("mysql://localhost/db").unwrap(), DatabaseKind::Mysql);
    }

    #[test]
    fn test_mysql_full_url() {
        assert_eq!(
            kind_for("mysql://user:pass@host:3306/db").unwrap(),
            DatabaseKind::Mysql
        );
    }

    #[test]
    fn test_sqlite_url() {
        assert_eq!(kind_for("sqlite://path.db").unwrap(), DatabaseKind::Sqlite);
    }

    #[test]
    fn test_sqlite_colon() {
        assert_eq!(kind_for("sqlite:path.db").unwrap(), DatabaseKind::Sqlite);
    }

    #[test]
    fn test_sqlite_memory() {
        assert_eq!(kind_for("sqlite::memory:").unwrap(), DatabaseKind::Sqlite);
    }

    #[test]
    fn test_http_url_fails() {
        assert!(kind_for("http://example.com").is_err());
    }

    #[test]
    fn test_empty_url_fails() {
        assert!(kind_for("").is_err());
    }

    #[test]
    fn test_mongo_url_fails() {
        assert!(kind_for("mongo://localhost").is_err());
    }

    #[test]
    fn test_uppercase_postgres_fails() {
        assert!(kind_for("POSTGRES://localhost").is_err());
    }

    // ========== EntitiesArgs::parse_type_overrides ==========

    #[test]
    fn test_overrides_empty() {
        assert!(overrides_map(vec![]).is_empty());
    }

    #[test]
    fn test_overrides_single() {
        let map = overrides_map(vec!["jsonb=MyJson"]);
        assert_eq!(map.get("jsonb").unwrap(), "MyJson");
    }

    #[test]
    fn test_overrides_multiple() {
        let map = overrides_map(vec!["jsonb=MyJson", "uuid=MyUuid"]);
        assert_eq!(map.len(), 2);
        assert_eq!(map.get("jsonb").unwrap(), "MyJson");
        assert_eq!(map.get("uuid").unwrap(), "MyUuid");
    }

    #[test]
    fn test_overrides_malformed_skipped() {
        assert!(overrides_map(vec!["noequals"]).is_empty());
    }

    #[test]
    fn test_overrides_mixed_valid_invalid() {
        let map = overrides_map(vec!["good=val", "bad"]);
        assert_eq!(map.len(), 1);
        assert_eq!(map.get("good").unwrap(), "val");
    }

    #[test]
    fn test_overrides_equals_in_value() {
        let map = overrides_map(vec!["key=val=ue"]);
        assert_eq!(map.get("key").unwrap(), "val=ue");
    }

    #[test]
    fn test_overrides_empty_key() {
        let map = overrides_map(vec!["=value"]);
        assert_eq!(map.get("").unwrap(), "value");
    }

    #[test]
    fn test_overrides_empty_value() {
        let map = overrides_map(vec!["key="]);
        assert_eq!(map.get("key").unwrap(), "");
    }

    // ========== exclude_tables ==========

    #[test]
    fn test_exclude_tables_default_none() {
        let args = make_entities_args_with_overrides(vec![]);
        assert!(args.exclude_tables.is_none());
    }

    #[test]
    fn test_exclude_tables_set() {
        let args = EntitiesArgs {
            exclude_tables: Some(vec![
                String::from("_migrations"),
                String::from("schema_versions"),
            ]),
            ..make_entities_args_with_overrides(vec![])
        };
        let excluded = args.exclude_tables.as_deref().unwrap();
        assert_eq!(excluded.len(), 2);
        assert!(excluded.iter().any(|table| table == "_migrations"));
    }

    // ========== Methods ==========

    #[test]
    fn test_methods_default_all_false() {
        assert_eq!(flags(&Methods::default()), [false; 8]);
    }

    #[test]
    fn test_methods_star() {
        let m = methods_from(&["*"]).unwrap();
        assert_eq!(flags(&m), [true; 8]);
    }

    #[test]
    fn test_methods_single() {
        let m = methods_from(&["get"]).unwrap();
        assert!(m.get);
        assert!(!m.get_all);
        assert!(!m.insert);
    }

    #[test]
    fn test_methods_multiple() {
        let m = methods_from(&["get_all", "delete"]).unwrap();
        assert!(m.get_all);
        assert!(m.delete);
        assert!(!m.insert);
        assert!(!m.paginate);
    }

    #[test]
    fn test_methods_unknown_fails() {
        let message = methods_from(&["unknown"]).unwrap_err();
        assert!(message.contains("Unknown method"));
    }

    #[test]
    fn test_methods_all() {
        assert_eq!(flags(&Methods::all()), [true; 8]);
    }

    #[test]
    fn test_parse_overwrite_method() {
        let m = methods_from(&["overwrite"]).unwrap();
        assert!(m.overwrite);
        assert!(!m.update);
    }

    #[test]
    fn test_parse_insert_many_method() {
        let m = methods_from(&["insert_many"]).unwrap();
        assert!(m.insert_many);
        assert!(!m.insert);
        assert!(!m.get);
    }

    // ========== module_path_from_file ==========

    #[test]
    fn test_module_path_simple() {
        assert_eq!(module_for("src/models/users.rs").unwrap(), "crate::models::users");
    }

    #[test]
    fn test_module_path_mod_rs() {
        assert_eq!(module_for("src/models/mod.rs").unwrap(), "crate::models");
    }

    #[test]
    fn test_module_path_nested() {
        assert_eq!(
            module_for("src/db/entities/agent.rs").unwrap(),
            "crate::db::entities::agent"
        );
    }

    #[test]
    fn test_module_path_absolute_with_src() {
        assert_eq!(
            module_for("/home/user/project/src/models/users.rs").unwrap(),
            "crate::models::users"
        );
    }

    #[test]
    fn test_module_path_relative_with_src() {
        assert_eq!(
            module_for("../other_project/src/models/users.rs").unwrap(),
            "crate::models::users"
        );
    }

    #[test]
    fn test_module_path_no_src_fails() {
        assert!(module_for("models/users.rs").is_err());
    }

    #[test]
    fn test_module_path_deeply_nested_mod() {
        assert_eq!(module_for("src/a/b/c/mod.rs").unwrap(), "crate::a::b::c");
    }

    #[test]
    fn test_module_path_src_root_file() {
        assert_eq!(module_for("src/lib.rs").unwrap(), "crate::lib");
    }
}