Skip to main content

sqlx_gen/
cli.rs

1use clap::{Parser, Subcommand};
2use std::collections::HashMap;
3use std::path::PathBuf;
4
// NOTE: on clap-derived items, `///` doc comments may be turned into `--help`
// text, so explanatory comments on the CLI types use plain `//` to leave help
// output untouched.
//
// Root command of the `sqlx-gen` binary; all work is dispatched through `command`.
#[derive(Parser, Debug)]
#[command(name = "sqlx-gen", about = "Generate Rust structs from database schema")]
pub struct Cli {
    // The top-level subcommand chosen on the command line.
    #[command(subcommand)]
    pub command: Command,
}
11
// Top-level subcommands. Only `generate` exists; it nests a second level
// (`GenerateCommand`) that selects what to generate.
#[derive(Subcommand, Debug)]
pub enum Command {
    /// Generate code from database schema
    Generate {
        // Which generator to run (`entities` or `crud`).
        #[command(subcommand)]
        subcommand: GenerateCommand,
    },
}
20
// Second-level subcommands under `generate`; each variant carries its own
// argument struct.
#[derive(Subcommand, Debug)]
pub enum GenerateCommand {
    /// Generate entity structs, enums, composites, and domains
    Entities(EntitiesArgs),
    /// Generate CRUD repository for a table or view
    Crud(CrudArgs),
}
28
// Connection options for commands that introspect a live database
// (flattened into `EntitiesArgs`). Plain `//` comments are used so clap's
// generated help text is unaffected.
#[derive(Parser, Debug)]
pub struct DatabaseArgs {
    // Taken from `-u/--database-url`, falling back to the DATABASE_URL env var.
    /// Database connection URL
    #[arg(short = 'u', long, env = "DATABASE_URL")]
    pub database_url: String,

    // NOTE(review): the "public" default is applied regardless of backend;
    // confirm the MySQL/SQLite code paths ignore or override it.
    /// Schemas to introspect (comma-separated, PG default: public)
    #[arg(short = 's', long, value_delimiter = ',', default_value = "public")]
    pub schemas: Vec<String>,
}
39
40impl DatabaseArgs {
41    pub fn database_kind(&self) -> crate::error::Result<DatabaseKind> {
42        let url = &self.database_url;
43        if url.starts_with("postgres://") || url.starts_with("postgresql://") {
44            Ok(DatabaseKind::Postgres)
45        } else if url.starts_with("mysql://") {
46            Ok(DatabaseKind::Mysql)
47        } else if url.starts_with("sqlite://") || url.starts_with("sqlite:") {
48            Ok(DatabaseKind::Sqlite)
49        } else {
50            Err(crate::error::Error::Config(
51                "Cannot detect database type from URL. Expected postgres://, mysql://, or sqlite:// prefix.".to_string(),
52            ))
53        }
54    }
55}
56
// Arguments for `sqlx-gen generate entities`. Plain `//` comments are used so
// clap's generated help text is unaffected.
#[derive(Parser, Debug)]
pub struct EntitiesArgs {
    // Connection URL and schema list (`-u`, `-s`), shared via flatten.
    #[command(flatten)]
    pub db: DatabaseArgs,

    /// Output directory for generated files
    #[arg(short = 'o', long, default_value = "src/models")]
    pub output_dir: PathBuf,

    // Split on ',' by clap; empty when the flag is omitted.
    /// Additional derives (e.g. Serialize,Deserialize,PartialEq)
    #[arg(short = 'D', long, value_delimiter = ',')]
    pub derives: Vec<String>,

    // Raw `key=value` strings; turned into a map by `parse_type_overrides`.
    /// Type overrides (e.g. jsonb=MyJsonType,uuid=MyUuid)
    #[arg(short = 'T', long, value_delimiter = ',')]
    pub type_overrides: Vec<String>,

    /// Generate everything into a single file instead of one file per table
    #[arg(short = 'S', long)]
    pub single_file: bool,

    // `None` when the flag is absent; presumably means "all tables" — confirm in caller.
    /// Only generate for these tables (comma-separated)
    #[arg(short = 't', long, value_delimiter = ',')]
    pub tables: Option<Vec<String>>,

    // `None` when the flag is absent.
    /// Exclude these tables/views from generation (comma-separated)
    #[arg(short = 'x', long, value_delimiter = ',')]
    pub exclude_tables: Option<Vec<String>>,

    /// Also generate structs for SQL views
    #[arg(short = 'v', long)]
    pub views: bool,

    // Parsed through `TimeCrate`'s `FromStr` impl (only "chrono" or "time").
    /// Time crate to use for date/time types: chrono (default) or time
    #[arg(long, default_value = "chrono")]
    pub time_crate: TimeCrate,

    /// Print to stdout without writing files
    #[arg(short = 'n', long)]
    pub dry_run: bool,
}
98
99impl EntitiesArgs {
100    pub fn parse_type_overrides(&self) -> HashMap<String, String> {
101        self.type_overrides
102            .iter()
103            .filter_map(|s| {
104                let (k, v) = s.split_once('=')?;
105                Some((k.to_string(), v.to_string()))
106            })
107            .collect()
108    }
109}
110
// Arguments for `sqlx-gen generate crud`. Plain `//` comments are used so
// clap's generated help text is unaffected.
#[derive(Parser, Debug)]
pub struct CrudArgs {
    /// Path to the generated entity .rs file
    #[arg(short = 'f', long)]
    pub entity_file: PathBuf,

    // Free-form string; interpreted by `CrudArgs::database_kind`.
    /// Database kind (postgres, mysql, sqlite)
    #[arg(short = 'd', long)]
    pub db_kind: String,

    // When `None`, `resolve_entities_module` derives it from `entity_file`.
    /// Module path of generated entities (e.g. "crate::models::users").
    /// If omitted, derived from --entity-file by finding `src/` and converting the path.
    #[arg(short = 'e', long)]
    pub entities_module: Option<String>,

    /// Output directory for generated repository files
    #[arg(short = 'o', long, default_value = "src/crud")]
    pub output_dir: PathBuf,

    // Empty when the flag is omitted (no default is set). NOTE(review): presumably
    // validated via `Methods::from_list`; confirm how an empty list is handled.
    /// Methods to generate (comma-separated): *, get_all, paginate, get, insert, update, delete
    #[arg(short = 'm', long, value_delimiter = ',')]
    pub methods: Vec<String>,

    /// Use sqlx::query_as!() compile-time checked macros instead of query_as::<_, T>() functions
    #[arg(short = 'q', long)]
    pub query_macro: bool,

    // Parsed through `PoolVisibility`'s `FromStr` impl.
    /// Visibility of the pool field in generated repository structs: private, pub, pub(crate)
    #[arg(short = 'p', long, default_value = "private")]
    pub pool_visibility: PoolVisibility,

    /// Print to stdout without writing files
    #[arg(short = 'n', long)]
    pub dry_run: bool,
}
147
148impl CrudArgs {
149    pub fn database_kind(&self) -> crate::error::Result<DatabaseKind> {
150        match self.db_kind.to_lowercase().as_str() {
151            "postgres" | "postgresql" | "pg" => Ok(DatabaseKind::Postgres),
152            "mysql" => Ok(DatabaseKind::Mysql),
153            "sqlite" => Ok(DatabaseKind::Sqlite),
154            other => Err(crate::error::Error::Config(format!(
155                "Unknown database kind '{}'. Expected: postgres, mysql, sqlite",
156                other
157            ))),
158        }
159    }
160
161    /// Resolve the entities module path: use the explicit value if provided,
162    /// otherwise derive it from the entity file path.
163    pub fn resolve_entities_module(&self) -> crate::error::Result<String> {
164        match &self.entities_module {
165            Some(m) => Ok(m.clone()),
166            None => module_path_from_file(&self.entity_file),
167        }
168    }
169}
170
171/// Derive a Rust module path from a file path by finding `src/` and converting.
172/// e.g. `some/project/src/models/users.rs` → `crate::models::users`
173/// e.g. `src/db/entities/mod.rs` → `crate::db::entities`
174fn module_path_from_file(path: &std::path::Path) -> crate::error::Result<String> {
175    let path_str = path.to_string_lossy().replace('\\', "/");
176
177    let after_src = match path_str.rfind("/src/") {
178        Some(pos) => &path_str[pos + 5..],
179        None if path_str.starts_with("src/") => &path_str[4..],
180        _ => {
181            return Err(crate::error::Error::Config(format!(
182                "Cannot derive module path from '{}': no 'src/' found. Use --entities-module explicitly.",
183                path.display()
184            )));
185        }
186    };
187
188    let without_ext = after_src.strip_suffix(".rs").unwrap_or(after_src);
189    let module = without_ext.strip_suffix("/mod").unwrap_or(without_ext);
190
191    let module_path = format!("crate::{}", module.replace('/', "::"));
192    Ok(module_path)
193}
194
/// Database backend that code is generated for.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DatabaseKind {
    /// PostgreSQL.
    Postgres,
    /// MySQL.
    Mysql,
    /// SQLite.
    Sqlite,
}
201
/// Date/time crate used for temporal column types in generated code.
///
/// Parsed from `--time-crate` (exactly `chrono` or `time`, case-sensitive) and
/// displayed back with the same lowercase spelling.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum TimeCrate {
    /// The `chrono` crate (the default).
    #[default]
    Chrono,
    /// The `time` crate.
    Time,
}

impl std::str::FromStr for TimeCrate {
    type Err = String;

    fn from_str(value: &str) -> Result<Self, Self::Err> {
        let parsed = match value {
            "chrono" => Self::Chrono,
            "time" => Self::Time,
            unknown => {
                return Err(format!(
                    "Unknown time crate '{}'. Expected: chrono, time",
                    unknown
                ))
            }
        };
        Ok(parsed)
    }
}

impl std::fmt::Display for TimeCrate {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Chrono => "chrono",
            Self::Time => "time",
        };
        f.write_str(name)
    }
}
232
/// Visibility of the `pool` field in generated repository structs.
///
/// Parsed from `--pool-visibility`; the accepted spellings are exactly
/// `private`, `pub` and `pub(crate)`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum PoolVisibility {
    /// Private field (the default).
    #[default]
    Private,
    /// `pub`.
    Pub,
    /// `pub(crate)`.
    PubCrate,
}

impl std::str::FromStr for PoolVisibility {
    type Err = String;

    fn from_str(input: &str) -> Result<Self, Self::Err> {
        Ok(match input {
            "private" => Self::Private,
            "pub" => Self::Pub,
            "pub(crate)" => Self::PubCrate,
            unknown => {
                return Err(format!(
                    "Unknown pool visibility '{}'. Expected: private, pub, pub(crate)",
                    unknown
                ))
            }
        })
    }
}
256
/// Which CRUD methods to generate. All fields default to `false`.
/// Use [`Methods::from_list`] to parse from CLI input.
#[derive(Debug, Clone, Default)]
pub struct Methods {
    /// Generate `get_all`.
    pub get_all: bool,
    /// Generate `paginate`.
    pub paginate: bool,
    /// Generate `get`.
    pub get: bool,
    /// Generate `insert`.
    pub insert: bool,
    /// Generate `update`.
    pub update: bool,
    /// Generate `delete`.
    pub delete: bool,
}

/// Every individual method name, in the order listed in error messages.
const ALL_METHODS: &[&str] = &["get_all", "paginate", "get", "insert", "update", "delete"];

impl Methods {
    /// Parse a list of method names (as produced by `--methods`).
    ///
    /// Each entry is trimmed, so `-m "get, insert"` works even though clap does not
    /// trim after splitting on `,`. Empty entries, such as the one left by a trailing
    /// comma, are ignored. `"*"` enables every method. The other entries are still
    /// validated, so a typo is reported whether it comes before or after `*`.
    ///
    /// # Errors
    ///
    /// Returns a message naming the first unrecognised entry and listing valid values.
    pub fn from_list(names: &[String]) -> Result<Self, String> {
        let mut m = Self::default();
        for raw in names {
            match raw.trim() {
                "" => {}
                // Keep looping rather than returning early so later names are validated.
                "*" => m = Self::all(),
                "get_all" => m.get_all = true,
                "paginate" => m.paginate = true,
                "get" => m.get = true,
                "insert" => m.insert = true,
                "update" => m.update = true,
                "delete" => m.delete = true,
                other => {
                    return Err(format!(
                        "Unknown method '{}'. Valid values: *, {}",
                        other,
                        ALL_METHODS.join(", ")
                    ))
                }
            }
        }
        Ok(m)
    }

    /// A [`Methods`] with every method enabled.
    pub fn all() -> Self {
        Self {
            get_all: true,
            paginate: true,
            get: true,
            insert: true,
            update: true,
            delete: true,
        }
    }
}
307
#[cfg(test)]
mod tests {
    use super::*;

    fn make_db_args(url: &str) -> DatabaseArgs {
        DatabaseArgs {
            database_url: url.to_string(),
            schemas: vec!["public".into()],
        }
    }

    fn make_entities_args_with_overrides(overrides: Vec<&str>) -> EntitiesArgs {
        EntitiesArgs {
            db: make_db_args("postgres://localhost/db"),
            output_dir: PathBuf::from("out"),
            derives: vec![],
            type_overrides: overrides.into_iter().map(|s| s.to_string()).collect(),
            single_file: false,
            tables: None,
            exclude_tables: None,
            views: false,
            time_crate: TimeCrate::Chrono,
            dry_run: false,
        }
    }

    /// Build a `CrudArgs` with only the fields the tests care about varying.
    fn make_crud_args(db_kind: &str, entity_file: &str, entities_module: Option<&str>) -> CrudArgs {
        CrudArgs {
            entity_file: PathBuf::from(entity_file),
            db_kind: db_kind.to_string(),
            entities_module: entities_module.map(|s| s.to_string()),
            output_dir: PathBuf::from("out"),
            methods: vec![],
            query_macro: false,
            pool_visibility: PoolVisibility::Private,
            dry_run: false,
        }
    }

    // ========== database_kind ==========

    #[test]
    fn test_postgres_url() {
        let args = make_db_args("postgres://localhost/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Postgres);
    }

    #[test]
    fn test_postgresql_url() {
        let args = make_db_args("postgresql://localhost/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Postgres);
    }

    #[test]
    fn test_postgres_full_url() {
        let args = make_db_args("postgres://user:pass@host:5432/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Postgres);
    }

    #[test]
    fn test_mysql_url() {
        let args = make_db_args("mysql://localhost/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Mysql);
    }

    #[test]
    fn test_mysql_full_url() {
        let args = make_db_args("mysql://user:pass@host:3306/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Mysql);
    }

    #[test]
    fn test_sqlite_url() {
        let args = make_db_args("sqlite://path.db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Sqlite);
    }

    #[test]
    fn test_sqlite_colon() {
        let args = make_db_args("sqlite:path.db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Sqlite);
    }

    #[test]
    fn test_sqlite_memory() {
        let args = make_db_args("sqlite::memory:");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Sqlite);
    }

    #[test]
    fn test_http_url_fails() {
        let args = make_db_args("http://example.com");
        assert!(args.database_kind().is_err());
    }

    #[test]
    fn test_empty_url_fails() {
        let args = make_db_args("");
        assert!(args.database_kind().is_err());
    }

    #[test]
    fn test_mongo_url_fails() {
        let args = make_db_args("mongo://localhost");
        assert!(args.database_kind().is_err());
    }

    #[test]
    fn test_uppercase_postgres_fails() {
        let args = make_db_args("POSTGRES://localhost");
        assert!(args.database_kind().is_err());
    }

    // ========== parse_type_overrides ==========

    #[test]
    fn test_overrides_empty() {
        let args = make_entities_args_with_overrides(vec![]);
        assert!(args.parse_type_overrides().is_empty());
    }

    #[test]
    fn test_overrides_single() {
        let args = make_entities_args_with_overrides(vec!["jsonb=MyJson"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.get("jsonb").unwrap(), "MyJson");
    }

    #[test]
    fn test_overrides_multiple() {
        let args = make_entities_args_with_overrides(vec!["jsonb=MyJson", "uuid=MyUuid"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.len(), 2);
        assert_eq!(map.get("jsonb").unwrap(), "MyJson");
        assert_eq!(map.get("uuid").unwrap(), "MyUuid");
    }

    #[test]
    fn test_overrides_malformed_skipped() {
        let args = make_entities_args_with_overrides(vec!["noequals"]);
        assert!(args.parse_type_overrides().is_empty());
    }

    #[test]
    fn test_overrides_mixed_valid_invalid() {
        let args = make_entities_args_with_overrides(vec!["good=val", "bad"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.len(), 1);
        assert_eq!(map.get("good").unwrap(), "val");
    }

    #[test]
    fn test_overrides_equals_in_value() {
        let args = make_entities_args_with_overrides(vec!["key=val=ue"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.get("key").unwrap(), "val=ue");
    }

    #[test]
    fn test_overrides_empty_key() {
        let args = make_entities_args_with_overrides(vec!["=value"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.get("").unwrap(), "value");
    }

    #[test]
    fn test_overrides_empty_value() {
        let args = make_entities_args_with_overrides(vec!["key="]);
        let map = args.parse_type_overrides();
        assert_eq!(map.get("key").unwrap(), "");
    }

    // ========== exclude_tables ==========
    // These only check that the field stores what is assigned; filtering itself
    // happens outside this module.

    #[test]
    fn test_exclude_tables_default_none() {
        let args = make_entities_args_with_overrides(vec![]);
        assert!(args.exclude_tables.is_none());
    }

    #[test]
    fn test_exclude_tables_set() {
        let mut args = make_entities_args_with_overrides(vec![]);
        args.exclude_tables = Some(vec!["_migrations".to_string(), "schema_versions".to_string()]);
        assert_eq!(args.exclude_tables.as_ref().unwrap().len(), 2);
        assert!(args.exclude_tables.as_ref().unwrap().contains(&"_migrations".to_string()));
    }

    // ========== methods ==========

    #[test]
    fn test_methods_default_all_false() {
        let m = Methods::default();
        assert!(!m.get_all);
        assert!(!m.paginate);
        assert!(!m.get);
        assert!(!m.insert);
        assert!(!m.update);
        assert!(!m.delete);
    }

    #[test]
    fn test_methods_star() {
        let m = Methods::from_list(&["*".to_string()]).unwrap();
        assert!(m.get_all);
        assert!(m.paginate);
        assert!(m.get);
        assert!(m.insert);
        assert!(m.update);
        assert!(m.delete);
    }

    #[test]
    fn test_methods_single() {
        let m = Methods::from_list(&["get".to_string()]).unwrap();
        assert!(m.get);
        assert!(!m.get_all);
        assert!(!m.insert);
    }

    #[test]
    fn test_methods_multiple() {
        let m = Methods::from_list(&["get_all".to_string(), "delete".to_string()]).unwrap();
        assert!(m.get_all);
        assert!(m.delete);
        assert!(!m.insert);
        assert!(!m.paginate);
    }

    #[test]
    fn test_methods_unknown_fails() {
        let result = Methods::from_list(&["unknown".to_string()]);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Unknown method"));
    }

    #[test]
    fn test_methods_all() {
        let m = Methods::all();
        assert!(m.get_all);
        assert!(m.paginate);
        assert!(m.get);
        assert!(m.insert);
        assert!(m.update);
        assert!(m.delete);
    }

    // ========== module_path_from_file ==========

    #[test]
    fn test_module_path_simple() {
        let p = PathBuf::from("src/models/users.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::models::users");
    }

    #[test]
    fn test_module_path_mod_rs() {
        let p = PathBuf::from("src/models/mod.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::models");
    }

    #[test]
    fn test_module_path_nested() {
        let p = PathBuf::from("src/db/entities/agent.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::db::entities::agent");
    }

    #[test]
    fn test_module_path_absolute_with_src() {
        let p = PathBuf::from("/home/user/project/src/models/users.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::models::users");
    }

    #[test]
    fn test_module_path_relative_with_src() {
        let p = PathBuf::from("../other_project/src/models/users.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::models::users");
    }

    #[test]
    fn test_module_path_no_src_fails() {
        let p = PathBuf::from("models/users.rs");
        assert!(module_path_from_file(&p).is_err());
    }

    #[test]
    fn test_module_path_deeply_nested_mod() {
        let p = PathBuf::from("src/a/b/c/mod.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::a::b::c");
    }

    #[test]
    fn test_module_path_src_root_file() {
        let p = PathBuf::from("src/lib.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::lib");
    }

    // ========== CrudArgs ==========

    #[test]
    fn test_crud_db_kind_aliases() {
        for (input, expected) in [
            ("postgres", DatabaseKind::Postgres),
            ("postgresql", DatabaseKind::Postgres),
            ("pg", DatabaseKind::Postgres),
            ("PG", DatabaseKind::Postgres),
            ("mysql", DatabaseKind::Mysql),
            ("SQLite", DatabaseKind::Sqlite),
        ] {
            let args = make_crud_args(input, "src/models/users.rs", None);
            assert_eq!(args.database_kind().unwrap(), expected, "input: {}", input);
        }
    }

    #[test]
    fn test_crud_db_kind_unknown_fails() {
        let args = make_crud_args("oracle", "src/models/users.rs", None);
        assert!(args.database_kind().is_err());
    }

    #[test]
    fn test_resolve_entities_module_explicit_wins() {
        // The explicit module is used even when the file path has no `src/`.
        let args = make_crud_args("pg", "no/source/dir/users.rs", Some("crate::db::users"));
        assert_eq!(args.resolve_entities_module().unwrap(), "crate::db::users");
    }

    #[test]
    fn test_resolve_entities_module_derived_from_file() {
        let args = make_crud_args("pg", "src/models/users.rs", None);
        assert_eq!(args.resolve_entities_module().unwrap(), "crate::models::users");
    }

    #[test]
    fn test_resolve_entities_module_no_src_fails() {
        let args = make_crud_args("pg", "models/users.rs", None);
        assert!(args.resolve_entities_module().is_err());
    }

    // ========== TimeCrate ==========

    #[test]
    fn test_time_crate_round_trip() {
        for tc in [TimeCrate::Chrono, TimeCrate::Time] {
            assert_eq!(tc.to_string().parse::<TimeCrate>().unwrap(), tc);
        }
    }

    #[test]
    fn test_time_crate_default_is_chrono() {
        assert_eq!(TimeCrate::default(), TimeCrate::Chrono);
    }

    #[test]
    fn test_time_crate_is_case_sensitive() {
        assert!("Chrono".parse::<TimeCrate>().is_err());
    }

    // ========== PoolVisibility ==========

    #[test]
    fn test_pool_visibility_parse() {
        assert_eq!("private".parse::<PoolVisibility>().unwrap(), PoolVisibility::Private);
        assert_eq!("pub".parse::<PoolVisibility>().unwrap(), PoolVisibility::Pub);
        assert_eq!("pub(crate)".parse::<PoolVisibility>().unwrap(), PoolVisibility::PubCrate);
        assert!("public".parse::<PoolVisibility>().is_err());
    }

    #[test]
    fn test_pool_visibility_default_is_private() {
        assert_eq!(PoolVisibility::default(), PoolVisibility::Private);
    }
}