1use clap::{Parser, Subcommand};
2use std::collections::HashMap;
3use std::path::PathBuf;
4
// Top-level command-line definition for the tool named `sqlx-gen`.
//
// NOTE: plain `//` comments are used on purpose. clap's derive macros turn
// `///` doc comments into help text, which would change the CLI output.
#[derive(Parser, Debug)]
#[command(name = "sqlx-gen", about = "Generate Rust structs from database schema")]
pub struct Cli {
    // The top-level subcommand selected on the command line (see `Command`).
    #[command(subcommand)]
    pub command: Command,
}
11
// First-level subcommands of the CLI.
//
// NOTE: plain `//` comments are used on purpose. clap's derive macros turn
// `///` doc comments into help text, which would change the CLI output.
#[derive(Subcommand, Debug)]
pub enum Command {
    // `generate <...>`: only a container; the actual work is selected by the
    // nested `GenerateCommand` subcommand.
    Generate {
        #[command(subcommand)]
        subcommand: GenerateCommand,
    },
}
20
// Second-level subcommands nested under `Command::Generate`.
//
// NOTE: plain `//` comments are used on purpose. clap's derive macros turn
// `///` doc comments into help text, which would change the CLI output.
#[derive(Subcommand, Debug)]
pub enum GenerateCommand {
    // Carries the options declared in `EntitiesArgs` (database connection,
    // output directory, table filters, ...).
    Entities(EntitiesArgs),
    // Carries the options declared in `CrudArgs` (entity file, database kind,
    // method selection, ...).
    Crud(CrudArgs),
}
28
// Database connection options; flattened into `EntitiesArgs`.
//
// NOTE: plain `//` comments are used on purpose. clap's derive macros turn
// `///` doc comments into help text, which would change the CLI output.
#[derive(Parser, Debug)]
pub struct DatabaseArgs {
    // Connection URL, given as `-u` / `--database-url` or read from the
    // `DATABASE_URL` environment variable. It has no default, so one of the
    // two must be provided. `database_kind()` inspects its prefix.
    #[arg(short = 'u', long, env = "DATABASE_URL")]
    pub database_url: String,

    // Schema names, given as `-s` / `--schemas` in a comma-separated list.
    // Defaults to the single entry "public".
    #[arg(short = 's', long, value_delimiter = ',', default_value = "public")]
    pub schemas: Vec<String>,
}
39
40impl DatabaseArgs {
41 pub fn database_kind(&self) -> crate::error::Result<DatabaseKind> {
42 let url = &self.database_url;
43 if url.starts_with("postgres://") || url.starts_with("postgresql://") {
44 Ok(DatabaseKind::Postgres)
45 } else if url.starts_with("mysql://") {
46 Ok(DatabaseKind::Mysql)
47 } else if url.starts_with("sqlite://") || url.starts_with("sqlite:") {
48 Ok(DatabaseKind::Sqlite)
49 } else {
50 Err(crate::error::Error::Config(
51 "Cannot detect database type from URL. Expected postgres://, mysql://, or sqlite:// prefix.".to_string(),
52 ))
53 }
54 }
55}
56
// Options of the `generate entities` subcommand.
//
// NOTE: plain `//` comments are used on purpose. clap's derive macros turn
// `///` doc comments into help text, which would change the CLI output.
// Descriptions of what the generator does with each option are hedged,
// because the consuming code is not part of this file.
#[derive(Parser, Debug)]
pub struct EntitiesArgs {
    // Connection options (`--database-url`, `--schemas`), merged into this
    // command's own flag list.
    #[command(flatten)]
    pub db: DatabaseArgs,

    // `-o` / `--output-dir`; defaults to "src/models".
    #[arg(short = 'o', long, default_value = "src/models")]
    pub output_dir: PathBuf,

    // `-D` / `--derives`; comma-separated list, empty when the flag is absent.
    // Presumably extra derive macros for the generated structs - TODO confirm.
    #[arg(short = 'D', long, value_delimiter = ',')]
    pub derives: Vec<String>,

    // `-T` / `--type-overrides`; comma-separated `key=value` entries.
    // `parse_type_overrides()` turns them into a map.
    #[arg(short = 'T', long, value_delimiter = ',')]
    pub type_overrides: Vec<String>,

    // `-S` / `--single-file`; boolean switch, false unless given.
    // Presumably emits everything into one file - TODO confirm.
    #[arg(short = 'S', long)]
    pub single_file: bool,

    // `-t` / `--tables`; optional comma-separated list, `None` when absent.
    // Presumably restricts generation to these tables - TODO confirm.
    #[arg(short = 't', long, value_delimiter = ',')]
    pub tables: Option<Vec<String>>,

    // `-x` / `--exclude-tables`; optional comma-separated list, `None` when
    // absent. Presumably tables to skip - TODO confirm.
    #[arg(short = 'x', long, value_delimiter = ',')]
    pub exclude_tables: Option<Vec<String>>,

    // `-v` / `--views`; boolean switch, false unless given.
    // Presumably also generates entities for views - TODO confirm.
    #[arg(short = 'v', long)]
    pub views: bool,

    // `-n` / `--dry-run`; boolean switch, false unless given.
    // Presumably previews output without writing files - TODO confirm.
    #[arg(short = 'n', long)]
    pub dry_run: bool,
}
94
95impl EntitiesArgs {
96 pub fn parse_type_overrides(&self) -> HashMap<String, String> {
97 self.type_overrides
98 .iter()
99 .filter_map(|s| {
100 let (k, v) = s.split_once('=')?;
101 Some((k.to_string(), v.to_string()))
102 })
103 .collect()
104 }
105}
106
// Options of the `generate crud` subcommand.
//
// NOTE: plain `//` comments are used on purpose. clap's derive macros turn
// `///` doc comments into help text, which would change the CLI output.
// Descriptions of what the generator does with each option are hedged,
// because the consuming code is not part of this file.
#[derive(Parser, Debug)]
pub struct CrudArgs {
    // `-f` / `--entity-file`; required. Also used by
    // `resolve_entities_module()` to derive a module path when
    // `--entities-module` is not given.
    #[arg(short = 'f', long)]
    pub entity_file: PathBuf,

    // `-d` / `--db-kind`; required. Interpreted by `database_kind()`, which
    // accepts postgres / postgresql / pg, mysql and sqlite (any letter case).
    #[arg(short = 'd', long)]
    pub db_kind: String,

    // `-e` / `--entities-module`; optional explicit module path. When absent,
    // `resolve_entities_module()` derives one from `entity_file`.
    #[arg(short = 'e', long)]
    pub entities_module: Option<String>,

    // `-o` / `--output-dir`; defaults to "src/crud".
    #[arg(short = 'o', long, default_value = "src/crud")]
    pub output_dir: PathBuf,

    // `-m` / `--methods`; comma-separated list, empty when the flag is absent.
    // Presumably fed to `Methods::from_list` by the caller - TODO confirm.
    #[arg(short = 'm', long, value_delimiter = ',')]
    pub methods: Vec<String>,


    // `-q` / `--query-macro`; boolean switch, false unless given.
    // Presumably selects sqlx's `query!`-style macros - TODO confirm.
    #[arg(short = 'q', long)]
    pub query_macro: bool,

    // `-n` / `--dry-run`; boolean switch, false unless given.
    // Presumably previews output without writing files - TODO confirm.
    #[arg(short = 'n', long)]
    pub dry_run: bool,
}
139
140impl CrudArgs {
141 pub fn database_kind(&self) -> crate::error::Result<DatabaseKind> {
142 match self.db_kind.to_lowercase().as_str() {
143 "postgres" | "postgresql" | "pg" => Ok(DatabaseKind::Postgres),
144 "mysql" => Ok(DatabaseKind::Mysql),
145 "sqlite" => Ok(DatabaseKind::Sqlite),
146 other => Err(crate::error::Error::Config(format!(
147 "Unknown database kind '{}'. Expected: postgres, mysql, sqlite",
148 other
149 ))),
150 }
151 }
152
153 pub fn resolve_entities_module(&self) -> crate::error::Result<String> {
156 match &self.entities_module {
157 Some(m) => Ok(m.clone()),
158 None => module_path_from_file(&self.entity_file),
159 }
160 }
161}
162
163fn module_path_from_file(path: &std::path::Path) -> crate::error::Result<String> {
167 let path_str = path.to_string_lossy().replace('\\', "/");
168
169 let after_src = match path_str.rfind("/src/") {
170 Some(pos) => &path_str[pos + 5..],
171 None if path_str.starts_with("src/") => &path_str[4..],
172 _ => {
173 return Err(crate::error::Error::Config(format!(
174 "Cannot derive module path from '{}': no 'src/' found. Use --entities-module explicitly.",
175 path.display()
176 )));
177 }
178 };
179
180 let without_ext = after_src.strip_suffix(".rs").unwrap_or(after_src);
181 let module = without_ext.strip_suffix("/mod").unwrap_or(without_ext);
182
183 let module_path = format!("crate::{}", module.replace('/', "::"));
184 Ok(module_path)
185}
186
/// Database backends recognised by this module.
///
/// Produced by `DatabaseArgs::database_kind` (from a connection URL prefix)
/// and by `CrudArgs::database_kind` (from the `db_kind` option).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DatabaseKind {
    /// Chosen for `postgres://` / `postgresql://` URLs and for the names
    /// `postgres`, `postgresql` and `pg`.
    Postgres,
    /// Chosen for `mysql://` URLs and for the name `mysql`.
    Mysql,
    /// Chosen for URLs starting with `sqlite:` and for the name `sqlite`.
    Sqlite,
}
193
/// A set of boolean switches, one per method name accepted by
/// [`Methods::from_list`].
///
/// `Methods::default()` has every switch off; [`Methods::all`] has every
/// switch on.
#[derive(Debug, Clone, Default)]
pub struct Methods {
    /// Set by the name `get_all`.
    pub get_all: bool,
    /// Set by the name `paginate`.
    pub paginate: bool,
    /// Set by the name `get`.
    pub get: bool,
    /// Set by the name `insert`.
    pub insert: bool,
    /// Set by the name `update`.
    pub update: bool,
    /// Set by the name `delete`.
    pub delete: bool,
}

/// Every individual method name, used to list the valid values in the error
/// message of [`Methods::from_list`].
const ALL_METHODS: &[&str] = &["get_all", "paginate", "get", "insert", "update", "delete"];

impl Methods {
    /// Builds a selection from a list of method names.
    ///
    /// Each recognised name switches on the corresponding field. The
    /// wildcard `*` switches on every field. An empty list yields a
    /// selection with everything off.
    ///
    /// Every entry is validated, including entries that follow a `*`, so a
    /// misspelled name is reported wherever it appears in the list.
    ///
    /// # Errors
    ///
    /// Returns a message naming the first unrecognised entry together with
    /// the list of valid values.
    pub fn from_list(names: &[String]) -> Result<Self, String> {
        let mut m = Self::default();
        // Remember the wildcard instead of returning immediately, so that the
        // remaining names are still checked for typos.
        let mut wildcard = false;
        for name in names {
            match name.as_str() {
                "*" => wildcard = true,
                "get_all" => m.get_all = true,
                "paginate" => m.paginate = true,
                "get" => m.get = true,
                "insert" => m.insert = true,
                "update" => m.update = true,
                "delete" => m.delete = true,
                other => {
                    return Err(format!(
                        "Unknown method '{}'. Valid values: *, {}",
                        other,
                        ALL_METHODS.join(", ")
                    ))
                }
            }
        }
        Ok(if wildcard { Self::all() } else { m })
    }

    /// Returns a selection with every method switched on.
    pub fn all() -> Self {
        Self {
            get_all: true,
            paginate: true,
            get: true,
            insert: true,
            update: true,
            delete: true,
        }
    }
}
244
// Unit tests for URL-based backend detection, type-override parsing,
// the `exclude_tables` field, `Methods`, and `module_path_from_file`.
#[cfg(test)]
mod tests {
    use super::*;

    // Builds `DatabaseArgs` for the given URL with the default "public" schema.
    fn make_db_args(url: &str) -> DatabaseArgs {
        DatabaseArgs {
            database_url: url.to_string(),
            schemas: vec!["public".into()],
        }
    }

    // Builds `EntitiesArgs` whose only interesting field is `type_overrides`;
    // everything else is set to an empty or "off" value.
    fn make_entities_args_with_overrides(overrides: Vec<&str>) -> EntitiesArgs {
        EntitiesArgs {
            db: make_db_args("postgres://localhost/db"),
            output_dir: PathBuf::from("out"),
            derives: vec![],
            type_overrides: overrides.into_iter().map(|s| s.to_string()).collect(),
            single_file: false,
            tables: None,
            exclude_tables: None,
            views: false,
            dry_run: false,
        }
    }

    // ---- DatabaseArgs::database_kind ----

    #[test]
    fn test_postgres_url() {
        let args = make_db_args("postgres://localhost/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Postgres);
    }

    #[test]
    fn test_postgresql_url() {
        let args = make_db_args("postgresql://localhost/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Postgres);
    }

    #[test]
    fn test_postgres_full_url() {
        let args = make_db_args("postgres://user:pass@host:5432/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Postgres);
    }

    #[test]
    fn test_mysql_url() {
        let args = make_db_args("mysql://localhost/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Mysql);
    }

    #[test]
    fn test_mysql_full_url() {
        let args = make_db_args("mysql://user:pass@host:3306/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Mysql);
    }

    #[test]
    fn test_sqlite_url() {
        let args = make_db_args("sqlite://path.db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Sqlite);
    }

    // The short `sqlite:` form without slashes is accepted as well.
    #[test]
    fn test_sqlite_colon() {
        let args = make_db_args("sqlite:path.db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Sqlite);
    }

    #[test]
    fn test_sqlite_memory() {
        let args = make_db_args("sqlite::memory:");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Sqlite);
    }

    #[test]
    fn test_http_url_fails() {
        let args = make_db_args("http://example.com");
        assert!(args.database_kind().is_err());
    }

    #[test]
    fn test_empty_url_fails() {
        let args = make_db_args("");
        assert!(args.database_kind().is_err());
    }

    #[test]
    fn test_mongo_url_fails() {
        let args = make_db_args("mongo://localhost");
        assert!(args.database_kind().is_err());
    }

    // Prefix matching is case-sensitive, so an upper-case scheme is rejected.
    #[test]
    fn test_uppercase_postgres_fails() {
        let args = make_db_args("POSTGRES://localhost");
        assert!(args.database_kind().is_err());
    }

    // ---- EntitiesArgs::parse_type_overrides ----

    #[test]
    fn test_overrides_empty() {
        let args = make_entities_args_with_overrides(vec![]);
        assert!(args.parse_type_overrides().is_empty());
    }

    #[test]
    fn test_overrides_single() {
        let args = make_entities_args_with_overrides(vec!["jsonb=MyJson"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.get("jsonb").unwrap(), "MyJson");
    }

    #[test]
    fn test_overrides_multiple() {
        let args = make_entities_args_with_overrides(vec!["jsonb=MyJson", "uuid=MyUuid"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.len(), 2);
        assert_eq!(map.get("jsonb").unwrap(), "MyJson");
        assert_eq!(map.get("uuid").unwrap(), "MyUuid");
    }

    // Entries without '=' are dropped silently instead of causing an error.
    #[test]
    fn test_overrides_malformed_skipped() {
        let args = make_entities_args_with_overrides(vec!["noequals"]);
        assert!(args.parse_type_overrides().is_empty());
    }

    #[test]
    fn test_overrides_mixed_valid_invalid() {
        let args = make_entities_args_with_overrides(vec!["good=val", "bad"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.len(), 1);
        assert_eq!(map.get("good").unwrap(), "val");
    }

    // Only the first '=' separates key from value; later ones stay in the value.
    #[test]
    fn test_overrides_equals_in_value() {
        let args = make_entities_args_with_overrides(vec!["key=val=ue"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.get("key").unwrap(), "val=ue");
    }

    // An empty key is currently accepted and stored under "".
    #[test]
    fn test_overrides_empty_key() {
        let args = make_entities_args_with_overrides(vec!["=value"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.get("").unwrap(), "value");
    }

    // An empty value is currently accepted and stored as "".
    #[test]
    fn test_overrides_empty_value() {
        let args = make_entities_args_with_overrides(vec!["key="]);
        let map = args.parse_type_overrides();
        assert_eq!(map.get("key").unwrap(), "");
    }

    // ---- EntitiesArgs::exclude_tables ----

    #[test]
    fn test_exclude_tables_default_none() {
        let args = make_entities_args_with_overrides(vec![]);
        assert!(args.exclude_tables.is_none());
    }

    #[test]
    fn test_exclude_tables_set() {
        let mut args = make_entities_args_with_overrides(vec![]);
        args.exclude_tables = Some(vec!["_migrations".to_string(), "schema_versions".to_string()]);
        assert_eq!(args.exclude_tables.as_ref().unwrap().len(), 2);
        assert!(args.exclude_tables.as_ref().unwrap().contains(&"_migrations".to_string()));
    }

    // ---- Methods ----

    #[test]
    fn test_methods_default_all_false() {
        let m = Methods::default();
        assert!(!m.get_all);
        assert!(!m.paginate);
        assert!(!m.get);
        assert!(!m.insert);
        assert!(!m.update);
        assert!(!m.delete);
    }

    // The "*" wildcard switches every method on.
    #[test]
    fn test_methods_star() {
        let m = Methods::from_list(&["*".to_string()]).unwrap();
        assert!(m.get_all);
        assert!(m.paginate);
        assert!(m.get);
        assert!(m.insert);
        assert!(m.update);
        assert!(m.delete);
    }

    // "get" must not be confused with "get_all".
    #[test]
    fn test_methods_single() {
        let m = Methods::from_list(&["get".to_string()]).unwrap();
        assert!(m.get);
        assert!(!m.get_all);
        assert!(!m.insert);
    }

    #[test]
    fn test_methods_multiple() {
        let m = Methods::from_list(&["get_all".to_string(), "delete".to_string()]).unwrap();
        assert!(m.get_all);
        assert!(m.delete);
        assert!(!m.insert);
        assert!(!m.paginate);
    }

    #[test]
    fn test_methods_unknown_fails() {
        let result = Methods::from_list(&["unknown".to_string()]);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Unknown method"));
    }

    #[test]
    fn test_methods_all() {
        let m = Methods::all();
        assert!(m.get_all);
        assert!(m.paginate);
        assert!(m.get);
        assert!(m.insert);
        assert!(m.update);
        assert!(m.delete);
    }

    // ---- module_path_from_file ----

    #[test]
    fn test_module_path_simple() {
        let p = PathBuf::from("src/models/users.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::models::users");
    }

    // A `mod.rs` file names its parent directory's module.
    #[test]
    fn test_module_path_mod_rs() {
        let p = PathBuf::from("src/models/mod.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::models");
    }

    #[test]
    fn test_module_path_nested() {
        let p = PathBuf::from("src/db/entities/agent.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::db::entities::agent");
    }

    // Everything up to and including the `/src/` component is discarded.
    #[test]
    fn test_module_path_absolute_with_src() {
        let p = PathBuf::from("/home/user/project/src/models/users.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::models::users");
    }

    #[test]
    fn test_module_path_relative_with_src() {
        let p = PathBuf::from("../other_project/src/models/users.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::models::users");
    }

    // Without any `src/` component there is nothing to anchor the path on.
    #[test]
    fn test_module_path_no_src_fails() {
        let p = PathBuf::from("models/users.rs");
        assert!(module_path_from_file(&p).is_err());
    }

    #[test]
    fn test_module_path_deeply_nested_mod() {
        let p = PathBuf::from("src/a/b/c/mod.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::a::b::c");
    }

    // Crate root files get no special treatment: `lib.rs` maps to `crate::lib`.
    #[test]
    fn test_module_path_src_root_file() {
        let p = PathBuf::from("src/lib.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::lib");
    }
}