1use clap::{Parser, Subcommand};
2use std::collections::HashMap;
3use std::path::PathBuf;
4
5#[derive(Parser, Debug)]
6#[command(name = "sqlx-gen", about = "Generate Rust structs from database schema")]
7pub struct Cli {
8 #[command(subcommand)]
9 pub command: Command,
10}
11
12#[derive(Subcommand, Debug)]
13pub enum Command {
14 Generate {
16 #[command(subcommand)]
17 subcommand: GenerateCommand,
18 },
19}
20
21#[derive(Subcommand, Debug)]
22pub enum GenerateCommand {
23 Entities(EntitiesArgs),
25 Crud(CrudArgs),
27}
28
29#[derive(Parser, Debug)]
30pub struct DatabaseArgs {
31 #[arg(short = 'u', long, env = "DATABASE_URL")]
33 pub database_url: String,
34
35 #[arg(short = 's', long, value_delimiter = ',', default_value = "public")]
37 pub schemas: Vec<String>,
38}
39
40impl DatabaseArgs {
41 pub fn database_kind(&self) -> crate::error::Result<DatabaseKind> {
42 let url = &self.database_url;
43 if url.starts_with("postgres://") || url.starts_with("postgresql://") {
44 Ok(DatabaseKind::Postgres)
45 } else if url.starts_with("mysql://") {
46 Ok(DatabaseKind::Mysql)
47 } else if url.starts_with("sqlite://") || url.starts_with("sqlite:") {
48 Ok(DatabaseKind::Sqlite)
49 } else {
50 Err(crate::error::Error::Config(
51 "Cannot detect database type from URL. Expected postgres://, mysql://, or sqlite:// prefix.".to_string(),
52 ))
53 }
54 }
55}
56
57#[derive(Parser, Debug)]
58pub struct EntitiesArgs {
59 #[command(flatten)]
60 pub db: DatabaseArgs,
61
62 #[arg(short = 'o', long, default_value = "src/models")]
64 pub output_dir: PathBuf,
65
66 #[arg(short = 'D', long, value_delimiter = ',')]
68 pub derives: Vec<String>,
69
70 #[arg(short = 'T', long, value_delimiter = ',')]
72 pub type_overrides: Vec<String>,
73
74 #[arg(short = 'S', long)]
76 pub single_file: bool,
77
78 #[arg(short = 't', long, value_delimiter = ',')]
80 pub tables: Option<Vec<String>>,
81
82 #[arg(short = 'x', long, value_delimiter = ',')]
84 pub exclude_tables: Option<Vec<String>>,
85
86 #[arg(short = 'v', long)]
88 pub views: bool,
89
90 #[arg(long, default_value = "chrono")]
92 pub time_crate: TimeCrate,
93
94 #[arg(short = 'n', long)]
96 pub dry_run: bool,
97}
98
99impl EntitiesArgs {
100 pub fn parse_type_overrides(&self) -> HashMap<String, String> {
101 self.type_overrides
102 .iter()
103 .filter_map(|s| {
104 let (k, v) = s.split_once('=')?;
105 Some((k.to_string(), v.to_string()))
106 })
107 .collect()
108 }
109}
110
111#[derive(Parser, Debug)]
112pub struct CrudArgs {
113 #[arg(short = 'f', long)]
115 pub entity_file: PathBuf,
116
117 #[arg(short = 'd', long)]
119 pub db_kind: String,
120
121 #[arg(short = 'e', long)]
124 pub entities_module: Option<String>,
125
126 #[arg(short = 'o', long, default_value = "src/crud")]
128 pub output_dir: PathBuf,
129
130 #[arg(short = 'm', long, value_delimiter = ',')]
132 pub methods: Vec<String>,
133
134
135 #[arg(short = 'q', long)]
137 pub query_macro: bool,
138
139 #[arg(short = 'p', long, default_value = "private")]
141 pub pool_visibility: PoolVisibility,
142
143 #[arg(short = 'n', long)]
145 pub dry_run: bool,
146}
147
148impl CrudArgs {
149 pub fn database_kind(&self) -> crate::error::Result<DatabaseKind> {
150 match self.db_kind.to_lowercase().as_str() {
151 "postgres" | "postgresql" | "pg" => Ok(DatabaseKind::Postgres),
152 "mysql" => Ok(DatabaseKind::Mysql),
153 "sqlite" => Ok(DatabaseKind::Sqlite),
154 other => Err(crate::error::Error::Config(format!(
155 "Unknown database kind '{}'. Expected: postgres, mysql, sqlite",
156 other
157 ))),
158 }
159 }
160
161 pub fn resolve_entities_module(&self) -> crate::error::Result<String> {
164 match &self.entities_module {
165 Some(m) => Ok(m.clone()),
166 None => module_path_from_file(&self.entity_file),
167 }
168 }
169}
170
171fn module_path_from_file(path: &std::path::Path) -> crate::error::Result<String> {
175 let path_str = path.to_string_lossy().replace('\\', "/");
176
177 let after_src = match path_str.rfind("/src/") {
178 Some(pos) => &path_str[pos + 5..],
179 None if path_str.starts_with("src/") => &path_str[4..],
180 _ => {
181 return Err(crate::error::Error::Config(format!(
182 "Cannot derive module path from '{}': no 'src/' found. Use --entities-module explicitly.",
183 path.display()
184 )));
185 }
186 };
187
188 let without_ext = after_src.strip_suffix(".rs").unwrap_or(after_src);
189 let module = without_ext.strip_suffix("/mod").unwrap_or(without_ext);
190
191 let module_path = format!("crate::{}", module.replace('/', "::"));
192 Ok(module_path)
193}
194
/// Database backend targeted by the generator.
///
/// Obtained either from a connection URL (`DatabaseArgs::database_kind`) or
/// from the `--db-kind` option (`CrudArgs::database_kind`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DatabaseKind {
    /// PostgreSQL: `postgres://`/`postgresql://` URLs, or `postgres`/`postgresql`/`pg`.
    Postgres,
    /// MySQL: `mysql://` URLs, or `mysql`.
    Mysql,
    /// SQLite: `sqlite:` URLs, or `sqlite`.
    Sqlite,
}
201
/// Date/time crate selected with `--time-crate`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TimeCrate {
    /// The `chrono` crate (the default).
    #[default]
    Chrono,
    /// The `time` crate.
    Time,
}

impl std::str::FromStr for TimeCrate {
    type Err = String;

    /// Accepts exactly `chrono` or `time`; matching is case-sensitive.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s == "chrono" {
            Ok(Self::Chrono)
        } else if s == "time" {
            Ok(Self::Time)
        } else {
            Err(format!(
                "Unknown time crate '{}'. Expected: chrono, time",
                s
            ))
        }
    }
}

impl std::fmt::Display for TimeCrate {
    /// Writes the crate name using the same spelling `FromStr` accepts.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Chrono => "chrono",
            Self::Time => "time",
        };
        f.write_str(name)
    }
}
232
/// Visibility of the generated pool, selected with `--pool-visibility`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PoolVisibility {
    /// Private: no visibility keyword (the default).
    #[default]
    Private,
    /// `pub`.
    Pub,
    /// `pub(crate)`.
    PubCrate,
}

impl std::str::FromStr for PoolVisibility {
    type Err = String;

    /// Accepts `private`, `pub` or `pub(crate)` (case-sensitive).
    ///
    /// `pub_crate` and `pub-crate` are also accepted for `pub(crate)`. The
    /// parentheses would otherwise have to be quoted on the command line:
    /// `-p pub(crate)` is a syntax error in bash and other POSIX shells.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "private" => Ok(Self::Private),
            "pub" => Ok(Self::Pub),
            "pub(crate)" | "pub_crate" | "pub-crate" => Ok(Self::PubCrate),
            other => Err(format!(
                "Unknown pool visibility '{}'. Expected: private, pub, pub(crate)",
                other
            )),
        }
    }
}
256
/// Set of CRUD methods to generate, one flag per method.
///
/// `Methods::default()` enables nothing; [`Methods::all`] enables everything.
#[derive(Debug, Clone, Default)]
pub struct Methods {
    /// Generate `get_all`.
    pub get_all: bool,
    /// Generate `paginate`.
    pub paginate: bool,
    /// Generate `get`.
    pub get: bool,
    /// Generate `insert`.
    pub insert: bool,
    /// Generate `update`.
    pub update: bool,
    /// Generate `delete`.
    pub delete: bool,
}

/// Method names accepted by [`Methods::from_list`], in addition to the `*` wildcard.
const ALL_METHODS: &[&str] = &["get_all", "paginate", "get", "insert", "update", "delete"];

impl Methods {
    /// Builds a method set from a list of method names.
    ///
    /// Names are trimmed before matching, so a comma-split list such as
    /// `"get, insert"` works. `*` enables every method. Every entry is
    /// validated, including entries after `*`, so an unknown name is rejected
    /// regardless of its position. An empty list enables nothing.
    ///
    /// # Errors
    /// Returns a message listing the valid names for the first unknown entry.
    pub fn from_list(names: &[String]) -> Result<Self, String> {
        let mut m = Self::default();
        for name in names {
            match name.trim() {
                // Do not return early: later entries must still be validated.
                "*" => m = Self::all(),
                "get_all" => m.get_all = true,
                "paginate" => m.paginate = true,
                "get" => m.get = true,
                "insert" => m.insert = true,
                "update" => m.update = true,
                "delete" => m.delete = true,
                other => {
                    return Err(format!(
                        "Unknown method '{}'. Valid values: *, {}",
                        other,
                        ALL_METHODS.join(", ")
                    ))
                }
            }
        }
        Ok(m)
    }

    /// Returns a set with every method enabled.
    pub fn all() -> Self {
        Self {
            get_all: true,
            paginate: true,
            get: true,
            insert: true,
            update: true,
            delete: true,
        }
    }
}
307
#[cfg(test)]
mod tests {
    use super::*;

    // ------------------------------------------------------------------ helpers

    /// Builds `DatabaseArgs` for `url` with the single schema `public`.
    fn db(url: &str) -> DatabaseArgs {
        DatabaseArgs {
            database_url: url.to_owned(),
            schemas: vec![String::from("public")],
        }
    }

    /// Runs URL-based backend detection, mapping any error to `None`.
    fn kind_of(url: &str) -> Option<DatabaseKind> {
        db(url).database_kind().ok()
    }

    /// Builds `EntitiesArgs` with fixed defaults and the given raw type overrides.
    fn entities_with(overrides: &[&str]) -> EntitiesArgs {
        EntitiesArgs {
            db: db("postgres://localhost/db"),
            output_dir: PathBuf::from("out"),
            derives: Vec::new(),
            type_overrides: overrides.iter().map(|o| o.to_string()).collect(),
            single_file: false,
            tables: None,
            exclude_tables: None,
            views: false,
            time_crate: TimeCrate::Chrono,
            dry_run: false,
        }
    }

    /// Parses raw `key=value` overrides through `EntitiesArgs`.
    fn overrides(entries: &[&str]) -> HashMap<String, String> {
        entities_with(entries).parse_type_overrides()
    }

    /// Owned method-name list suitable for `Methods::from_list`.
    fn names(list: &[&str]) -> Vec<String> {
        list.iter().map(|n| n.to_string()).collect()
    }

    /// All method flags in declaration order.
    fn flags(m: &Methods) -> [bool; 6] {
        [m.get_all, m.paginate, m.get, m.insert, m.update, m.delete]
    }

    /// Module path derived from a path given as a string.
    fn module_of(path: &str) -> crate::error::Result<String> {
        module_path_from_file(&PathBuf::from(path))
    }

    // ---------------------------------------------------- DatabaseArgs::database_kind

    #[test]
    fn test_postgres_url() {
        assert_eq!(kind_of("postgres://localhost/db"), Some(DatabaseKind::Postgres));
    }

    #[test]
    fn test_postgresql_url() {
        assert_eq!(kind_of("postgresql://localhost/db"), Some(DatabaseKind::Postgres));
    }

    #[test]
    fn test_postgres_full_url() {
        assert_eq!(
            kind_of("postgres://user:pass@host:5432/db"),
            Some(DatabaseKind::Postgres)
        );
    }

    #[test]
    fn test_mysql_url() {
        assert_eq!(kind_of("mysql://localhost/db"), Some(DatabaseKind::Mysql));
    }

    #[test]
    fn test_mysql_full_url() {
        assert_eq!(kind_of("mysql://user:pass@host:3306/db"), Some(DatabaseKind::Mysql));
    }

    #[test]
    fn test_sqlite_url() {
        assert_eq!(kind_of("sqlite://path.db"), Some(DatabaseKind::Sqlite));
    }

    #[test]
    fn test_sqlite_colon() {
        assert_eq!(kind_of("sqlite:path.db"), Some(DatabaseKind::Sqlite));
    }

    #[test]
    fn test_sqlite_memory() {
        assert_eq!(kind_of("sqlite::memory:"), Some(DatabaseKind::Sqlite));
    }

    #[test]
    fn test_http_url_fails() {
        assert_eq!(kind_of("http://example.com"), None);
    }

    #[test]
    fn test_empty_url_fails() {
        assert_eq!(kind_of(""), None);
    }

    #[test]
    fn test_mongo_url_fails() {
        assert_eq!(kind_of("mongo://localhost"), None);
    }

    #[test]
    fn test_uppercase_postgres_fails() {
        assert_eq!(kind_of("POSTGRES://localhost"), None);
    }

    // ------------------------------------------------ EntitiesArgs::parse_type_overrides

    #[test]
    fn test_overrides_empty() {
        assert!(overrides(&[]).is_empty());
    }

    #[test]
    fn test_overrides_single() {
        let map = overrides(&["jsonb=MyJson"]);
        assert_eq!(map["jsonb"], "MyJson");
    }

    #[test]
    fn test_overrides_multiple() {
        let map = overrides(&["jsonb=MyJson", "uuid=MyUuid"]);
        assert_eq!(map.len(), 2);
        assert_eq!(map["jsonb"], "MyJson");
        assert_eq!(map["uuid"], "MyUuid");
    }

    #[test]
    fn test_overrides_malformed_skipped() {
        assert!(overrides(&["noequals"]).is_empty());
    }

    #[test]
    fn test_overrides_mixed_valid_invalid() {
        let map = overrides(&["good=val", "bad"]);
        assert_eq!(map.len(), 1);
        assert_eq!(map["good"], "val");
    }

    #[test]
    fn test_overrides_equals_in_value() {
        let map = overrides(&["key=val=ue"]);
        assert_eq!(map["key"], "val=ue");
    }

    #[test]
    fn test_overrides_empty_key() {
        let map = overrides(&["=value"]);
        assert_eq!(map[""], "value");
    }

    #[test]
    fn test_overrides_empty_value() {
        let map = overrides(&["key="]);
        assert_eq!(map["key"], "");
    }

    // ------------------------------------------------------ EntitiesArgs::exclude_tables

    #[test]
    fn test_exclude_tables_default_none() {
        assert_eq!(entities_with(&[]).exclude_tables, None);
    }

    #[test]
    fn test_exclude_tables_set() {
        let mut args = entities_with(&[]);
        args.exclude_tables = Some(names(&["_migrations", "schema_versions"]));
        let excluded = args.exclude_tables.as_deref().unwrap();
        assert_eq!(excluded.len(), 2);
        assert!(excluded.iter().any(|t| t == "_migrations"));
    }

    // ------------------------------------------------------------------ Methods

    #[test]
    fn test_methods_default_all_false() {
        assert_eq!(flags(&Methods::default()), [false; 6]);
    }

    #[test]
    fn test_methods_star() {
        let m = Methods::from_list(&names(&["*"])).unwrap();
        assert_eq!(flags(&m), [true; 6]);
    }

    #[test]
    fn test_methods_single() {
        let m = Methods::from_list(&names(&["get"])).unwrap();
        assert!(m.get);
        assert!(!(m.get_all || m.insert));
    }

    #[test]
    fn test_methods_multiple() {
        let m = Methods::from_list(&names(&["get_all", "delete"])).unwrap();
        assert!(m.get_all && m.delete);
        assert!(!(m.insert || m.paginate));
    }

    #[test]
    fn test_methods_unknown_fails() {
        let err = Methods::from_list(&names(&["unknown"]))
            .expect_err("an unknown method name must be rejected");
        assert!(err.contains("Unknown method"));
    }

    #[test]
    fn test_methods_all() {
        assert_eq!(flags(&Methods::all()), [true; 6]);
    }

    // ---------------------------------------------------------- module_path_from_file

    #[test]
    fn test_module_path_simple() {
        assert_eq!(module_of("src/models/users.rs").unwrap(), "crate::models::users");
    }

    #[test]
    fn test_module_path_mod_rs() {
        assert_eq!(module_of("src/models/mod.rs").unwrap(), "crate::models");
    }

    #[test]
    fn test_module_path_nested() {
        assert_eq!(
            module_of("src/db/entities/agent.rs").unwrap(),
            "crate::db::entities::agent"
        );
    }

    #[test]
    fn test_module_path_absolute_with_src() {
        assert_eq!(
            module_of("/home/user/project/src/models/users.rs").unwrap(),
            "crate::models::users"
        );
    }

    #[test]
    fn test_module_path_relative_with_src() {
        assert_eq!(
            module_of("../other_project/src/models/users.rs").unwrap(),
            "crate::models::users"
        );
    }

    #[test]
    fn test_module_path_no_src_fails() {
        assert!(module_of("models/users.rs").is_err());
    }

    #[test]
    fn test_module_path_deeply_nested_mod() {
        assert_eq!(module_of("src/a/b/c/mod.rs").unwrap(), "crate::a::b::c");
    }

    #[test]
    fn test_module_path_src_root_file() {
        assert_eq!(module_of("src/lib.rs").unwrap(), "crate::lib");
    }
}