Skip to main content

sqlx_gen/
cli.rs

1use clap::{Parser, Subcommand};
2use std::collections::HashMap;
3use std::path::PathBuf;
4
5#[derive(Parser, Debug)]
6#[command(name = "sqlx-gen", about = "Generate Rust structs from database schema")]
7pub struct Cli {
8    #[command(subcommand)]
9    pub command: Command,
10}
11
12#[derive(Subcommand, Debug)]
13pub enum Command {
14    /// Generate code from database schema
15    Generate {
16        #[command(subcommand)]
17        subcommand: GenerateCommand,
18    },
19}
20
21#[derive(Subcommand, Debug)]
22pub enum GenerateCommand {
23    /// Generate entity structs, enums, composites, and domains
24    Entities(EntitiesArgs),
25    /// Generate CRUD repository for a table or view
26    Crud(CrudArgs),
27}
28
29#[derive(Parser, Debug)]
30pub struct DatabaseArgs {
31    /// Database connection URL
32    #[arg(short = 'u', long, env = "DATABASE_URL")]
33    pub database_url: String,
34
35    /// Schemas to introspect (comma-separated, PG default: public)
36    #[arg(short = 's', long, value_delimiter = ',', default_value = "public")]
37    pub schemas: Vec<String>,
38}
39
40impl DatabaseArgs {
41    pub fn database_kind(&self) -> crate::error::Result<DatabaseKind> {
42        let url = &self.database_url;
43        if url.starts_with("postgres://") || url.starts_with("postgresql://") {
44            Ok(DatabaseKind::Postgres)
45        } else if url.starts_with("mysql://") {
46            Ok(DatabaseKind::Mysql)
47        } else if url.starts_with("sqlite://") || url.starts_with("sqlite:") {
48            Ok(DatabaseKind::Sqlite)
49        } else {
50            Err(crate::error::Error::Config(
51                "Cannot detect database type from URL. Expected postgres://, mysql://, or sqlite:// prefix.".to_string(),
52            ))
53        }
54    }
55}
56
57#[derive(Parser, Debug)]
58pub struct EntitiesArgs {
59    #[command(flatten)]
60    pub db: DatabaseArgs,
61
62    /// Output directory for generated files
63    #[arg(short = 'o', long, default_value = "src/models")]
64    pub output_dir: PathBuf,
65
66    /// Additional derives (e.g. Serialize,Deserialize,PartialEq)
67    #[arg(short = 'D', long, value_delimiter = ',')]
68    pub derives: Vec<String>,
69
70    /// Type overrides (e.g. jsonb=MyJsonType,uuid=MyUuid)
71    #[arg(short = 'T', long, value_delimiter = ',')]
72    pub type_overrides: Vec<String>,
73
74    /// Generate everything into a single file instead of one file per table
75    #[arg(short = 'S', long)]
76    pub single_file: bool,
77
78    /// Only generate for these tables (comma-separated)
79    #[arg(short = 't', long, value_delimiter = ',')]
80    pub tables: Option<Vec<String>>,
81
82    /// Exclude these tables/views from generation (comma-separated)
83    #[arg(short = 'x', long, value_delimiter = ',')]
84    pub exclude_tables: Option<Vec<String>>,
85
86    /// Also generate structs for SQL views
87    #[arg(short = 'v', long)]
88    pub views: bool,
89
90    /// Time crate to use for date/time types: chrono (default) or time
91    #[arg(long, default_value = "chrono")]
92    pub time_crate: TimeCrate,
93
94    /// Print to stdout without writing files
95    #[arg(short = 'n', long)]
96    pub dry_run: bool,
97}
98
99impl EntitiesArgs {
100    pub fn parse_type_overrides(&self) -> HashMap<String, String> {
101        self.type_overrides
102            .iter()
103            .filter_map(|s| {
104                let (k, v) = s.split_once('=')?;
105                Some((k.to_string(), v.to_string()))
106            })
107            .collect()
108    }
109}
110
111#[derive(Parser, Debug)]
112pub struct CrudArgs {
113    /// Path to the generated entity .rs file
114    #[arg(short = 'f', long)]
115    pub entity_file: PathBuf,
116
117    /// Database kind (postgres, mysql, sqlite)
118    #[arg(short = 'd', long)]
119    pub db_kind: String,
120
121    /// Module path of generated entities (e.g. "crate::models::users").
122    /// If omitted, derived from --entity-file by finding `src/` and converting the path.
123    #[arg(short = 'e', long)]
124    pub entities_module: Option<String>,
125
126    /// Output directory for generated repository files
127    #[arg(short = 'o', long, default_value = "src/crud")]
128    pub output_dir: PathBuf,
129
130    /// Methods to generate (comma-separated): *, get_all, paginate, get, insert, update, delete
131    #[arg(short = 'm', long, value_delimiter = ',')]
132    pub methods: Vec<String>,
133
134
135    /// Use sqlx::query_as!() compile-time checked macros instead of query_as::<_, T>() functions
136    #[arg(short = 'q', long)]
137    pub query_macro: bool,
138
139    /// Visibility of the pool field in generated repository structs: private, pub, pub(crate)
140    #[arg(short = 'p', long, default_value = "private")]
141    pub pool_visibility: PoolVisibility,
142
143    /// Print to stdout without writing files
144    #[arg(short = 'n', long)]
145    pub dry_run: bool,
146}
147
148impl CrudArgs {
149    pub fn database_kind(&self) -> crate::error::Result<DatabaseKind> {
150        match self.db_kind.to_lowercase().as_str() {
151            "postgres" | "postgresql" | "pg" => Ok(DatabaseKind::Postgres),
152            "mysql" => Ok(DatabaseKind::Mysql),
153            "sqlite" => Ok(DatabaseKind::Sqlite),
154            other => Err(crate::error::Error::Config(format!(
155                "Unknown database kind '{}'. Expected: postgres, mysql, sqlite",
156                other
157            ))),
158        }
159    }
160
161    /// Resolve the entities module path: use the explicit value if provided,
162    /// otherwise derive it from the entity file path.
163    pub fn resolve_entities_module(&self) -> crate::error::Result<String> {
164        match &self.entities_module {
165            Some(m) => Ok(m.clone()),
166            None => module_path_from_file(&self.entity_file),
167        }
168    }
169}
170
171/// Derive a Rust module path from a file path by finding `src/` and converting.
172/// e.g. `some/project/src/models/users.rs` → `crate::models::users`
173/// e.g. `src/db/entities/mod.rs` → `crate::db::entities`
174fn module_path_from_file(path: &std::path::Path) -> crate::error::Result<String> {
175    let path_str = path.to_string_lossy().replace('\\', "/");
176
177    let after_src = match path_str.rfind("/src/") {
178        Some(pos) => &path_str[pos + 5..],
179        None if path_str.starts_with("src/") => &path_str[4..],
180        _ => {
181            return Err(crate::error::Error::Config(format!(
182                "Cannot derive module path from '{}': no 'src/' found. Use --entities-module explicitly.",
183                path.display()
184            )));
185        }
186    };
187
188    let without_ext = after_src.strip_suffix(".rs").unwrap_or(after_src);
189    let module = without_ext.strip_suffix("/mod").unwrap_or(without_ext);
190
191    let module_path = format!("crate::{}", module.replace('/', "::"));
192    Ok(module_path)
193}
194
/// The database backend that code is generated for.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DatabaseKind {
    /// PostgreSQL: `postgres://` / `postgresql://` URLs, or `--db-kind postgres|postgresql|pg`.
    Postgres,
    /// MySQL: `mysql://` URLs, or `--db-kind mysql`.
    Mysql,
    /// SQLite: `sqlite:` URLs, or `--db-kind sqlite`.
    Sqlite,
}
201
/// Which crate generated code uses for date/time column types.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum TimeCrate {
    /// The `chrono` crate (the default).
    #[default]
    Chrono,
    /// The `time` crate.
    Time,
}

/// Parses the CLI spelling (`chrono` or `time`); matching is case-sensitive.
impl std::str::FromStr for TimeCrate {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let parsed = match s {
            "chrono" => Self::Chrono,
            "time" => Self::Time,
            unknown => {
                return Err(format!(
                    "Unknown time crate '{}'. Expected: chrono, time",
                    unknown
                ));
            }
        };
        Ok(parsed)
    }
}

/// Renders the same spelling that `FromStr` accepts.
impl std::fmt::Display for TimeCrate {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Chrono => "chrono",
            Self::Time => "time",
        };
        f.write_str(name)
    }
}
232
/// Visibility of the pool field in generated repository structs.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum PoolVisibility {
    /// Parsed from `private` (the default).
    #[default]
    Private,
    /// Parsed from `pub`.
    Pub,
    /// Parsed from `pub(crate)`.
    PubCrate,
}

/// Parses the exact CLI spellings `private`, `pub` and `pub(crate)`.
impl std::str::FromStr for PoolVisibility {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const ACCEPTED: [(&str, PoolVisibility); 3] = [
            ("private", PoolVisibility::Private),
            ("pub", PoolVisibility::Pub),
            ("pub(crate)", PoolVisibility::PubCrate),
        ];

        ACCEPTED
            .iter()
            .find(|(spelling, _)| *spelling == s)
            .map(|&(_, visibility)| visibility)
            .ok_or_else(|| {
                format!(
                    "Unknown pool visibility '{}'. Expected: private, pub, pub(crate)",
                    s
                )
            })
    }
}
256
/// Which CRUD methods to generate. All fields default to `false`.
/// Use `Methods::from_list` to parse from CLI input.
#[derive(Debug, Clone, Default)]
pub struct Methods {
    pub get_all: bool,
    pub paginate: bool,
    pub get: bool,
    pub insert: bool,
    pub update: bool,
    pub overwrite: bool,
    pub delete: bool,
}

/// Every method name accepted by `Methods::from_list` (besides `*`), in the
/// order they are listed in its error message.
const ALL_METHODS: &[&str] = &[
    "get_all",
    "paginate",
    "get",
    "insert",
    "update",
    "overwrite",
    "delete",
];

impl Methods {
    /// Parse a list of method names. `"*"` enables all methods.
    ///
    /// Whitespace around each name is ignored. clap's `value_delimiter`
    /// keeps the space in `-m "get, insert"`, so without trimming `" insert"`
    /// would be rejected. Every name is validated, including names listed
    /// after `*`.
    ///
    /// # Errors
    ///
    /// Returns a message listing the valid values if any name is unknown.
    pub fn from_list(names: &[String]) -> Result<Self, String> {
        let mut m = Self::default();
        let mut wildcard = false;
        for name in names {
            match name.trim() {
                // Don't return early: later entries must still be validated,
                // so `*,typo` is rejected just like `typo,*`.
                "*" => wildcard = true,
                "get_all" => m.get_all = true,
                "paginate" => m.paginate = true,
                "get" => m.get = true,
                "insert" => m.insert = true,
                "update" => m.update = true,
                "overwrite" => m.overwrite = true,
                "delete" => m.delete = true,
                other => {
                    return Err(format!(
                        "Unknown method '{}'. Valid values: *, {}",
                        other,
                        ALL_METHODS.join(", ")
                    ))
                }
            }
        }
        Ok(if wildcard { Self::all() } else { m })
    }

    /// Returns a `Methods` with every method enabled.
    pub fn all() -> Self {
        Self {
            get_all: true,
            paginate: true,
            get: true,
            insert: true,
            update: true,
            overwrite: true,
            delete: true,
        }
    }
}
310
#[cfg(test)]
mod tests {
    use super::*;

    // ---------- fixtures & shorthands ----------

    /// Builds `DatabaseArgs` for `url` with the `public` schema.
    fn make_db_args(url: &str) -> DatabaseArgs {
        DatabaseArgs {
            database_url: String::from(url),
            schemas: vec![String::from("public")],
        }
    }

    /// Builds `EntitiesArgs` with every option "off" except `type_overrides`.
    fn make_entities_args_with_overrides(overrides: Vec<&str>) -> EntitiesArgs {
        let type_overrides: Vec<String> = overrides.into_iter().map(String::from).collect();
        EntitiesArgs {
            db: make_db_args("postgres://localhost/db"),
            output_dir: PathBuf::from("out"),
            derives: Vec::new(),
            type_overrides,
            single_file: false,
            tables: None,
            exclude_tables: None,
            views: false,
            time_crate: TimeCrate::Chrono,
            dry_run: false,
        }
    }

    /// `DatabaseArgs::database_kind` for a bare URL.
    fn kind_of(url: &str) -> crate::error::Result<DatabaseKind> {
        make_db_args(url).database_kind()
    }

    /// `EntitiesArgs::parse_type_overrides` for a list of raw entries.
    fn overrides_of(entries: Vec<&str>) -> HashMap<String, String> {
        make_entities_args_with_overrides(entries).parse_type_overrides()
    }

    /// `Methods::from_list` for string literals.
    fn methods_of(names: &[&str]) -> Result<Methods, String> {
        let owned: Vec<String> = names.iter().map(|name| name.to_string()).collect();
        Methods::from_list(&owned)
    }

    /// Every flag of `m`, in field declaration order.
    fn flags(m: &Methods) -> [bool; 7] {
        [m.get_all, m.paginate, m.get, m.insert, m.update, m.overwrite, m.delete]
    }

    /// `module_path_from_file` for a path literal.
    fn module_of(path: &str) -> crate::error::Result<String> {
        module_path_from_file(&PathBuf::from(path))
    }

    // ---------- database_kind ----------

    #[test]
    fn test_postgres_url() {
        assert_eq!(kind_of("postgres://localhost/db").unwrap(), DatabaseKind::Postgres);
    }

    #[test]
    fn test_postgresql_url() {
        assert_eq!(kind_of("postgresql://localhost/db").unwrap(), DatabaseKind::Postgres);
    }

    #[test]
    fn test_postgres_full_url() {
        assert_eq!(
            kind_of("postgres://user:pass@host:5432/db").unwrap(),
            DatabaseKind::Postgres
        );
    }

    #[test]
    fn test_mysql_url() {
        assert_eq!(kind_of("mysql://localhost/db").unwrap(), DatabaseKind::Mysql);
    }

    #[test]
    fn test_mysql_full_url() {
        assert_eq!(kind_of("mysql://user:pass@host:3306/db").unwrap(), DatabaseKind::Mysql);
    }

    #[test]
    fn test_sqlite_url() {
        assert_eq!(kind_of("sqlite://path.db").unwrap(), DatabaseKind::Sqlite);
    }

    #[test]
    fn test_sqlite_colon() {
        assert_eq!(kind_of("sqlite:path.db").unwrap(), DatabaseKind::Sqlite);
    }

    #[test]
    fn test_sqlite_memory() {
        assert_eq!(kind_of("sqlite::memory:").unwrap(), DatabaseKind::Sqlite);
    }

    #[test]
    fn test_http_url_fails() {
        assert!(kind_of("http://example.com").is_err());
    }

    #[test]
    fn test_empty_url_fails() {
        assert!(kind_of("").is_err());
    }

    #[test]
    fn test_mongo_url_fails() {
        assert!(kind_of("mongo://localhost").is_err());
    }

    #[test]
    fn test_uppercase_postgres_fails() {
        assert!(kind_of("POSTGRES://localhost").is_err());
    }

    // ---------- parse_type_overrides ----------

    #[test]
    fn test_overrides_empty() {
        assert!(overrides_of(vec![]).is_empty());
    }

    #[test]
    fn test_overrides_single() {
        let map = overrides_of(vec!["jsonb=MyJson"]);
        assert_eq!(map.get("jsonb").map(String::as_str), Some("MyJson"));
    }

    #[test]
    fn test_overrides_multiple() {
        let map = overrides_of(vec!["jsonb=MyJson", "uuid=MyUuid"]);
        assert_eq!(map.len(), 2);
        assert_eq!(map.get("jsonb").map(String::as_str), Some("MyJson"));
        assert_eq!(map.get("uuid").map(String::as_str), Some("MyUuid"));
    }

    #[test]
    fn test_overrides_malformed_skipped() {
        assert!(overrides_of(vec!["noequals"]).is_empty());
    }

    #[test]
    fn test_overrides_mixed_valid_invalid() {
        let map = overrides_of(vec!["good=val", "bad"]);
        assert_eq!(map.len(), 1);
        assert_eq!(map.get("good").map(String::as_str), Some("val"));
    }

    #[test]
    fn test_overrides_equals_in_value() {
        let map = overrides_of(vec!["key=val=ue"]);
        assert_eq!(map.get("key").map(String::as_str), Some("val=ue"));
    }

    #[test]
    fn test_overrides_empty_key() {
        let map = overrides_of(vec!["=value"]);
        assert_eq!(map.get("").map(String::as_str), Some("value"));
    }

    #[test]
    fn test_overrides_empty_value() {
        let map = overrides_of(vec!["key="]);
        assert_eq!(map.get("key").map(String::as_str), Some(""));
    }

    // ---------- exclude_tables ----------

    #[test]
    fn test_exclude_tables_default_none() {
        assert!(make_entities_args_with_overrides(vec![]).exclude_tables.is_none());
    }

    #[test]
    fn test_exclude_tables_set() {
        let mut args = make_entities_args_with_overrides(vec![]);
        args.exclude_tables = Some(vec!["_migrations".to_string(), "schema_versions".to_string()]);
        let excluded = args.exclude_tables.as_deref().unwrap();
        assert_eq!(excluded.len(), 2);
        assert!(excluded.iter().any(|table| table == "_migrations"));
    }

    // ---------- methods ----------

    #[test]
    fn test_methods_default_all_false() {
        assert_eq!(flags(&Methods::default()), [false; 7]);
    }

    #[test]
    fn test_methods_star() {
        assert_eq!(flags(&methods_of(&["*"]).unwrap()), [true; 7]);
    }

    #[test]
    fn test_methods_single() {
        let m = methods_of(&["get"]).unwrap();
        assert!(m.get);
        assert!(!m.get_all);
        assert!(!m.insert);
    }

    #[test]
    fn test_methods_multiple() {
        let m = methods_of(&["get_all", "delete"]).unwrap();
        assert!(m.get_all && m.delete);
        assert!(!m.insert && !m.paginate);
    }

    #[test]
    fn test_methods_unknown_fails() {
        let err = methods_of(&["unknown"]).expect_err("unknown method names must be rejected");
        assert!(err.contains("Unknown method"));
    }

    #[test]
    fn test_methods_all() {
        assert_eq!(flags(&Methods::all()), [true; 7]);
    }

    #[test]
    fn test_parse_overwrite_method() {
        let m = methods_of(&["overwrite"]).unwrap();
        assert!(m.overwrite);
        assert!(!m.update);
    }

    // ---------- module_path_from_file ----------

    #[test]
    fn test_module_path_simple() {
        assert_eq!(module_of("src/models/users.rs").unwrap(), "crate::models::users");
    }

    #[test]
    fn test_module_path_mod_rs() {
        assert_eq!(module_of("src/models/mod.rs").unwrap(), "crate::models");
    }

    #[test]
    fn test_module_path_nested() {
        assert_eq!(
            module_of("src/db/entities/agent.rs").unwrap(),
            "crate::db::entities::agent"
        );
    }

    #[test]
    fn test_module_path_absolute_with_src() {
        assert_eq!(
            module_of("/home/user/project/src/models/users.rs").unwrap(),
            "crate::models::users"
        );
    }

    #[test]
    fn test_module_path_relative_with_src() {
        assert_eq!(
            module_of("../other_project/src/models/users.rs").unwrap(),
            "crate::models::users"
        );
    }

    #[test]
    fn test_module_path_no_src_fails() {
        assert!(module_of("models/users.rs").is_err());
    }

    #[test]
    fn test_module_path_deeply_nested_mod() {
        assert_eq!(module_of("src/a/b/c/mod.rs").unwrap(), "crate::a::b::c");
    }

    #[test]
    fn test_module_path_src_root_file() {
        assert_eq!(module_of("src/lib.rs").unwrap(), "crate::lib");
    }
}