1use clap::{Parser, Subcommand};
2use std::collections::HashMap;
3use std::path::PathBuf;
4
5#[derive(Parser, Debug)]
6#[command(name = "sqlx-gen", about = "Generate Rust structs from database schema")]
7pub struct Cli {
8 #[command(subcommand)]
9 pub command: Command,
10}
11
12#[derive(Subcommand, Debug)]
13pub enum Command {
14 Generate {
16 #[command(subcommand)]
17 subcommand: GenerateCommand,
18 },
19}
20
21#[derive(Subcommand, Debug)]
22pub enum GenerateCommand {
23 Entities(EntitiesArgs),
25 Crud(CrudArgs),
27}
28
29#[derive(Parser, Debug)]
30pub struct DatabaseArgs {
31 #[arg(short = 'u', long, env = "DATABASE_URL")]
33 pub database_url: String,
34
35 #[arg(short = 's', long, value_delimiter = ',', default_value = "public")]
37 pub schemas: Vec<String>,
38}
39
40impl DatabaseArgs {
41 pub fn database_kind(&self) -> crate::error::Result<DatabaseKind> {
42 let url = &self.database_url;
43 if url.starts_with("postgres://") || url.starts_with("postgresql://") {
44 Ok(DatabaseKind::Postgres)
45 } else if url.starts_with("mysql://") {
46 Ok(DatabaseKind::Mysql)
47 } else if url.starts_with("sqlite://") || url.starts_with("sqlite:") {
48 Ok(DatabaseKind::Sqlite)
49 } else {
50 Err(crate::error::Error::Config(
51 "Cannot detect database type from URL. Expected postgres://, mysql://, or sqlite:// prefix.".to_string(),
52 ))
53 }
54 }
55}
56
57#[derive(Parser, Debug)]
58pub struct EntitiesArgs {
59 #[command(flatten)]
60 pub db: DatabaseArgs,
61
62 #[arg(short = 'o', long, default_value = "src/models")]
64 pub output_dir: PathBuf,
65
66 #[arg(short = 'D', long, value_delimiter = ',')]
68 pub derives: Vec<String>,
69
70 #[arg(short = 'T', long, value_delimiter = ',')]
72 pub type_overrides: Vec<String>,
73
74 #[arg(short = 'S', long)]
76 pub single_file: bool,
77
78 #[arg(short = 't', long, value_delimiter = ',')]
80 pub tables: Option<Vec<String>>,
81
82 #[arg(short = 'x', long, value_delimiter = ',')]
84 pub exclude_tables: Option<Vec<String>>,
85
86 #[arg(short = 'v', long)]
88 pub views: bool,
89
90 #[arg(long, default_value = "chrono")]
92 pub time_crate: TimeCrate,
93
94 #[arg(short = 'n', long)]
96 pub dry_run: bool,
97}
98
99impl EntitiesArgs {
100 pub fn parse_type_overrides(&self) -> HashMap<String, String> {
101 self.type_overrides
102 .iter()
103 .filter_map(|s| {
104 let (k, v) = s.split_once('=')?;
105 Some((k.to_string(), v.to_string()))
106 })
107 .collect()
108 }
109}
110
111#[derive(Parser, Debug)]
112pub struct CrudArgs {
113 #[arg(short = 'f', long)]
115 pub entity_file: PathBuf,
116
117 #[arg(short = 'd', long)]
119 pub db_kind: String,
120
121 #[arg(short = 'e', long)]
124 pub entities_module: Option<String>,
125
126 #[arg(short = 'o', long, default_value = "src/crud")]
128 pub output_dir: PathBuf,
129
130 #[arg(short = 'm', long, value_delimiter = ',')]
132 pub methods: Vec<String>,
133
134
135 #[arg(short = 'q', long)]
137 pub query_macro: bool,
138
139 #[arg(short = 'p', long, default_value = "private")]
141 pub pool_visibility: PoolVisibility,
142
143 #[arg(short = 'n', long)]
145 pub dry_run: bool,
146}
147
148impl CrudArgs {
149 pub fn database_kind(&self) -> crate::error::Result<DatabaseKind> {
150 match self.db_kind.to_lowercase().as_str() {
151 "postgres" | "postgresql" | "pg" => Ok(DatabaseKind::Postgres),
152 "mysql" => Ok(DatabaseKind::Mysql),
153 "sqlite" => Ok(DatabaseKind::Sqlite),
154 other => Err(crate::error::Error::Config(format!(
155 "Unknown database kind '{}'. Expected: postgres, mysql, sqlite",
156 other
157 ))),
158 }
159 }
160
161 pub fn resolve_entities_module(&self) -> crate::error::Result<String> {
164 match &self.entities_module {
165 Some(m) => Ok(m.clone()),
166 None => module_path_from_file(&self.entity_file),
167 }
168 }
169}
170
171fn module_path_from_file(path: &std::path::Path) -> crate::error::Result<String> {
175 let path_str = path.to_string_lossy().replace('\\', "/");
176
177 let after_src = match path_str.rfind("/src/") {
178 Some(pos) => &path_str[pos + 5..],
179 None if path_str.starts_with("src/") => &path_str[4..],
180 _ => {
181 return Err(crate::error::Error::Config(format!(
182 "Cannot derive module path from '{}': no 'src/' found. Use --entities-module explicitly.",
183 path.display()
184 )));
185 }
186 };
187
188 let without_ext = after_src.strip_suffix(".rs").unwrap_or(after_src);
189 let module = without_ext.strip_suffix("/mod").unwrap_or(without_ext);
190
191 let module_path = format!("crate::{}", module.replace('/', "::"));
192 Ok(module_path)
193}
194
/// Database backend targeted by the generator.
///
/// Produced by `DatabaseArgs::database_kind` (from a URL prefix) and by
/// `CrudArgs::database_kind` (from a backend name).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DatabaseKind {
    /// PostgreSQL.
    Postgres,
    /// MySQL.
    Mysql,
    /// SQLite.
    Sqlite,
}
201
/// Date/time crate selected with `--time-crate`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum TimeCrate {
    /// The `chrono` crate; this is the default.
    #[default]
    Chrono,
    /// The `time` crate.
    Time,
}

impl std::str::FromStr for TimeCrate {
    type Err = String;

    /// Accepts exactly `chrono` or `time` (case-sensitive).
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s == "chrono" {
            Ok(TimeCrate::Chrono)
        } else if s == "time" {
            Ok(TimeCrate::Time)
        } else {
            Err(format!("Unknown time crate '{}'. Expected: chrono, time", s))
        }
    }
}

impl std::fmt::Display for TimeCrate {
    /// Writes the same lowercase name that `from_str` accepts.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            TimeCrate::Chrono => "chrono",
            TimeCrate::Time => "time",
        };
        f.write_str(name)
    }
}
232
/// Visibility selected with `--pool-visibility`.
///
/// NOTE(review): presumably applied to the pool field of the generated CRUD
/// code — confirm in the generator.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum PoolVisibility {
    /// Spelled `private`; this is the default.
    #[default]
    Private,
    /// Spelled `pub`.
    Pub,
    /// Spelled `pub(crate)`.
    PubCrate,
}

impl std::str::FromStr for PoolVisibility {
    type Err = String;

    /// Accepts exactly `private`, `pub` or `pub(crate)` (case-sensitive).
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const SPELLINGS: [(&str, PoolVisibility); 3] = [
            ("private", PoolVisibility::Private),
            ("pub", PoolVisibility::Pub),
            ("pub(crate)", PoolVisibility::PubCrate),
        ];
        SPELLINGS
            .iter()
            .find(|(name, _)| *name == s)
            .map(|&(_, vis)| vis)
            .ok_or_else(|| {
                format!(
                    "Unknown pool visibility '{}'. Expected: private, pub, pub(crate)",
                    s
                )
            })
    }
}
256
/// Which CRUD methods to generate.
///
/// Every flag is `false` by default. Build a selection with
/// [`Methods::from_list`], or enable everything with [`Methods::all`].
#[derive(Debug, Clone, Default)]
pub struct Methods {
    pub get_all: bool,
    pub paginate: bool,
    pub get: bool,
    pub insert: bool,
    pub update: bool,
    pub overwrite: bool,
    pub delete: bool,
}

/// Method names accepted by [`Methods::from_list`] in addition to `*`.
/// They are listed verbatim in its error message.
const ALL_METHODS: &[&str] = &[
    "get_all",
    "paginate",
    "get",
    "insert",
    "update",
    "overwrite",
    "delete",
];

impl Methods {
    /// Builds a selection from a list of method names.
    ///
    /// Each entry is trimmed of surrounding whitespace first, so a list split
    /// from `"get, insert"` is accepted. `*` enables every method.
    ///
    /// Every entry is validated, including entries that follow a `*`, so a
    /// typo is reported instead of being silently ignored. An empty list
    /// enables nothing.
    ///
    /// # Errors
    /// Returns a message naming the first unrecognised (trimmed) entry and
    /// listing the valid values.
    pub fn from_list(names: &[String]) -> Result<Self, String> {
        let mut m = Self::default();
        for name in names {
            match name.trim() {
                // Keep looping after `*`: later entries still need validating.
                "*" => m = Self::all(),
                "get_all" => m.get_all = true,
                "paginate" => m.paginate = true,
                "get" => m.get = true,
                "insert" => m.insert = true,
                "update" => m.update = true,
                "overwrite" => m.overwrite = true,
                "delete" => m.delete = true,
                other => {
                    return Err(format!(
                        "Unknown method '{}'. Valid values: *, {}",
                        other,
                        ALL_METHODS.join(", ")
                    ));
                }
            }
        }
        Ok(m)
    }

    /// A selection with every method enabled (what `*` expands to).
    pub fn all() -> Self {
        Self {
            get_all: true,
            paginate: true,
            get: true,
            insert: true,
            update: true,
            overwrite: true,
            delete: true,
        }
    }
}
310
#[cfg(test)]
mod tests {
    use super::*;

    // ---- fixtures & helpers --------------------------------------------

    /// `DatabaseArgs` for `url` with the single schema `public`.
    fn db_args(url: &str) -> DatabaseArgs {
        DatabaseArgs {
            database_url: String::from(url),
            schemas: vec![String::from("public")],
        }
    }

    /// `EntitiesArgs` with neutral settings and the given raw `--type-overrides` entries.
    fn entities_args(overrides: Vec<&str>) -> EntitiesArgs {
        EntitiesArgs {
            db: db_args("postgres://localhost/db"),
            output_dir: PathBuf::from("out"),
            derives: Vec::new(),
            type_overrides: overrides.into_iter().map(String::from).collect(),
            single_file: false,
            tables: None,
            exclude_tables: None,
            views: false,
            time_crate: TimeCrate::Chrono,
            dry_run: false,
        }
    }

    /// Detected backend for `url`; panics when detection fails.
    fn kind(url: &str) -> DatabaseKind {
        db_args(url).database_kind().unwrap()
    }

    /// Whether backend detection rejects `url`.
    fn kind_rejected(url: &str) -> bool {
        db_args(url).database_kind().is_err()
    }

    /// `Methods::from_list` over string slices.
    fn methods(names: &[&str]) -> Result<Methods, String> {
        let owned: Vec<String> = names.iter().map(|n| n.to_string()).collect();
        Methods::from_list(&owned)
    }

    /// The seven method flags in declaration order.
    fn flags(m: &Methods) -> [bool; 7] {
        [m.get_all, m.paginate, m.get, m.insert, m.update, m.overwrite, m.delete]
    }

    /// Module path derived from `path`; panics when derivation fails.
    fn module_of(path: &str) -> String {
        module_path_from_file(&PathBuf::from(path)).unwrap()
    }

    // ---- DatabaseArgs::database_kind -----------------------------------

    #[test]
    fn test_postgres_url() {
        assert_eq!(kind("postgres://localhost/db"), DatabaseKind::Postgres);
    }

    #[test]
    fn test_postgresql_url() {
        assert_eq!(kind("postgresql://localhost/db"), DatabaseKind::Postgres);
    }

    #[test]
    fn test_postgres_full_url() {
        assert_eq!(kind("postgres://user:pass@host:5432/db"), DatabaseKind::Postgres);
    }

    #[test]
    fn test_mysql_url() {
        assert_eq!(kind("mysql://localhost/db"), DatabaseKind::Mysql);
    }

    #[test]
    fn test_mysql_full_url() {
        assert_eq!(kind("mysql://user:pass@host:3306/db"), DatabaseKind::Mysql);
    }

    #[test]
    fn test_sqlite_url() {
        assert_eq!(kind("sqlite://path.db"), DatabaseKind::Sqlite);
    }

    #[test]
    fn test_sqlite_colon() {
        assert_eq!(kind("sqlite:path.db"), DatabaseKind::Sqlite);
    }

    #[test]
    fn test_sqlite_memory() {
        assert_eq!(kind("sqlite::memory:"), DatabaseKind::Sqlite);
    }

    #[test]
    fn test_http_url_fails() {
        assert!(kind_rejected("http://example.com"));
    }

    #[test]
    fn test_empty_url_fails() {
        assert!(kind_rejected(""));
    }

    #[test]
    fn test_mongo_url_fails() {
        assert!(kind_rejected("mongo://localhost"));
    }

    #[test]
    fn test_uppercase_postgres_fails() {
        assert!(kind_rejected("POSTGRES://localhost"));
    }

    // ---- EntitiesArgs::parse_type_overrides ----------------------------

    #[test]
    fn test_overrides_empty() {
        assert!(entities_args(vec![]).parse_type_overrides().is_empty());
    }

    #[test]
    fn test_overrides_single() {
        let map = entities_args(vec!["jsonb=MyJson"]).parse_type_overrides();
        assert_eq!(map.get("jsonb").map(String::as_str), Some("MyJson"));
    }

    #[test]
    fn test_overrides_multiple() {
        let map = entities_args(vec!["jsonb=MyJson", "uuid=MyUuid"]).parse_type_overrides();
        assert_eq!(map.len(), 2);
        assert_eq!(map.get("jsonb").map(String::as_str), Some("MyJson"));
        assert_eq!(map.get("uuid").map(String::as_str), Some("MyUuid"));
    }

    #[test]
    fn test_overrides_malformed_skipped() {
        assert!(entities_args(vec!["noequals"]).parse_type_overrides().is_empty());
    }

    #[test]
    fn test_overrides_mixed_valid_invalid() {
        let map = entities_args(vec!["good=val", "bad"]).parse_type_overrides();
        assert_eq!(map.len(), 1);
        assert_eq!(map.get("good").map(String::as_str), Some("val"));
    }

    #[test]
    fn test_overrides_equals_in_value() {
        let map = entities_args(vec!["key=val=ue"]).parse_type_overrides();
        assert_eq!(map.get("key").map(String::as_str), Some("val=ue"));
    }

    #[test]
    fn test_overrides_empty_key() {
        let map = entities_args(vec!["=value"]).parse_type_overrides();
        assert_eq!(map.get("").map(String::as_str), Some("value"));
    }

    #[test]
    fn test_overrides_empty_value() {
        let map = entities_args(vec!["key="]).parse_type_overrides();
        assert_eq!(map.get("key").map(String::as_str), Some(""));
    }

    // ---- exclude_tables ------------------------------------------------

    #[test]
    fn test_exclude_tables_default_none() {
        assert!(entities_args(vec![]).exclude_tables.is_none());
    }

    #[test]
    fn test_exclude_tables_set() {
        let mut args = entities_args(vec![]);
        args.exclude_tables = Some(vec!["_migrations".into(), "schema_versions".into()]);
        let excluded = args.exclude_tables.as_deref().unwrap();
        assert_eq!(excluded.len(), 2);
        assert!(excluded.iter().any(|t| t == "_migrations"));
    }

    // ---- Methods -------------------------------------------------------

    #[test]
    fn test_methods_default_all_false() {
        assert_eq!(flags(&Methods::default()), [false; 7]);
    }

    #[test]
    fn test_methods_star() {
        assert_eq!(flags(&methods(&["*"]).unwrap()), [true; 7]);
    }

    #[test]
    fn test_methods_single() {
        let m = methods(&["get"]).unwrap();
        assert!(m.get);
        assert!(!m.get_all);
        assert!(!m.insert);
    }

    #[test]
    fn test_methods_multiple() {
        let m = methods(&["get_all", "delete"]).unwrap();
        assert!(m.get_all && m.delete);
        assert!(!m.insert);
        assert!(!m.paginate);
    }

    #[test]
    fn test_methods_unknown_fails() {
        let outcome = methods(&["unknown"]);
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().contains("Unknown method"));
    }

    #[test]
    fn test_methods_all() {
        assert_eq!(flags(&Methods::all()), [true; 7]);
    }

    #[test]
    fn test_parse_overwrite_method() {
        let m = methods(&["overwrite"]).unwrap();
        assert!(m.overwrite);
        assert!(!m.update);
    }

    // ---- module_path_from_file -----------------------------------------

    #[test]
    fn test_module_path_simple() {
        assert_eq!(module_of("src/models/users.rs"), "crate::models::users");
    }

    #[test]
    fn test_module_path_mod_rs() {
        assert_eq!(module_of("src/models/mod.rs"), "crate::models");
    }

    #[test]
    fn test_module_path_nested() {
        assert_eq!(module_of("src/db/entities/agent.rs"), "crate::db::entities::agent");
    }

    #[test]
    fn test_module_path_absolute_with_src() {
        assert_eq!(
            module_of("/home/user/project/src/models/users.rs"),
            "crate::models::users"
        );
    }

    #[test]
    fn test_module_path_relative_with_src() {
        assert_eq!(
            module_of("../other_project/src/models/users.rs"),
            "crate::models::users"
        );
    }

    #[test]
    fn test_module_path_no_src_fails() {
        assert!(module_path_from_file(&PathBuf::from("models/users.rs")).is_err());
    }

    #[test]
    fn test_module_path_deeply_nested_mod() {
        assert_eq!(module_of("src/a/b/c/mod.rs"), "crate::a::b::c");
    }

    #[test]
    fn test_module_path_src_root_file() {
        assert_eq!(module_of("src/lib.rs"), "crate::lib");
    }
}