Skip to main content

sqlx_gen/
cli.rs

1use clap::{Parser, Subcommand};
2use std::collections::HashMap;
3use std::path::PathBuf;
4
// Top-level command-line interface. `name` and `about` are set explicitly in
// the `#[command]` attribute below.
// NOTE: on clap-derived items, `///` doc comments become CLI help text, so
// maintainer notes in these structs are written as plain `//` comments.
#[derive(Parser, Debug)]
#[command(name = "sqlx-gen", about = "Generate Rust structs from database schema")]
pub struct Cli {
    // The required top-level subcommand (see `Command`).
    #[command(subcommand)]
    pub command: Command,
}
11
// Top-level subcommands. The `Generate` variant nests a second level of
// subcommands described by `GenerateCommand`.
#[derive(Subcommand, Debug)]
pub enum Command {
    /// Generate code from database schema
    Generate {
        #[command(subcommand)]
        subcommand: GenerateCommand,
    },
}
20
// Second-level subcommands nested under `Command::Generate`; each variant
// carries its own argument struct.
#[derive(Subcommand, Debug)]
pub enum GenerateCommand {
    /// Generate entity structs, enums, composites, and domains
    Entities(EntitiesArgs),
    /// Generate CRUD repository for a table or view
    Crud(CrudArgs),
}
28
// Connection and schema options for subcommands that introspect a live
// database; `EntitiesArgs` flattens these into its own flags.
#[derive(Parser, Debug)]
pub struct DatabaseArgs {
    // Falls back to the `DATABASE_URL` environment variable when the flag is
    // absent. Its scheme prefix is what `DatabaseArgs::database_kind` inspects.
    /// Database connection URL
    #[arg(short = 'u', long, env = "DATABASE_URL")]
    pub database_url: String,

    // NOTE(review): the "public" default is applied regardless of backend, not
    // only for Postgres — confirm how MySQL/SQLite code paths treat it.
    /// Schemas to introspect (comma-separated, PG default: public)
    #[arg(short = 's', long, value_delimiter = ',', default_value = "public")]
    pub schemas: Vec<String>,
}
39
40impl DatabaseArgs {
41    pub fn database_kind(&self) -> crate::error::Result<DatabaseKind> {
42        let url = &self.database_url;
43        if url.starts_with("postgres://") || url.starts_with("postgresql://") {
44            Ok(DatabaseKind::Postgres)
45        } else if url.starts_with("mysql://") {
46            Ok(DatabaseKind::Mysql)
47        } else if url.starts_with("sqlite://") || url.starts_with("sqlite:") {
48            Ok(DatabaseKind::Sqlite)
49        } else {
50            Err(crate::error::Error::Config(
51                "Cannot detect database type from URL. Expected postgres://, mysql://, or sqlite:// prefix.".to_string(),
52            ))
53        }
54    }
55}
56
// Arguments for the `GenerateCommand::Entities` subcommand.
#[derive(Parser, Debug)]
pub struct EntitiesArgs {
    // Shared connection/schema options, flattened so their flags appear
    // directly on this subcommand.
    #[command(flatten)]
    pub db: DatabaseArgs,

    /// Output directory for generated files
    #[arg(short = 'o', long, default_value = "src/models")]
    pub output_dir: PathBuf,

    // NOTE(review): clap's `value_delimiter` does not trim, so
    // `-D "Serialize, Deserialize"` yields an entry with a leading space —
    // confirm the generator trims these names.
    /// Additional derives (e.g. Serialize,Deserialize,PartialEq)
    #[arg(short = 'D', long, value_delimiter = ',')]
    pub derives: Vec<String>,

    // Kept as raw `key=value` strings; `EntitiesArgs::parse_type_overrides`
    // turns them into a map.
    /// Type overrides (e.g. jsonb=MyJsonType,uuid=MyUuid)
    #[arg(short = 'T', long, value_delimiter = ',')]
    pub type_overrides: Vec<String>,

    /// Generate everything into a single file instead of one file per table
    #[arg(short = 'S', long)]
    pub single_file: bool,

    // `Option` distinguishes "flag not given" (`None`) from an explicit list;
    // presumably `None` means "all tables" — confirm at the call site.
    /// Only generate for these tables (comma-separated)
    #[arg(short = 't', long, value_delimiter = ',')]
    pub tables: Option<Vec<String>>,

    // `None` when the flag is absent.
    /// Exclude these tables/views from generation (comma-separated)
    #[arg(short = 'x', long, value_delimiter = ',')]
    pub exclude_tables: Option<Vec<String>>,

    /// Also generate structs for SQL views
    #[arg(short = 'v', long)]
    pub views: bool,

    /// Print to stdout without writing files
    #[arg(short = 'n', long)]
    pub dry_run: bool,
}
94
95impl EntitiesArgs {
96    pub fn parse_type_overrides(&self) -> HashMap<String, String> {
97        self.type_overrides
98            .iter()
99            .filter_map(|s| {
100                let (k, v) = s.split_once('=')?;
101                Some((k.to_string(), v.to_string()))
102            })
103            .collect()
104    }
105}
106
// Arguments for the `GenerateCommand::Crud` subcommand.
#[derive(Parser, Debug)]
pub struct CrudArgs {
    /// Path to the generated entity .rs file
    #[arg(short = 'f', long)]
    pub entity_file: PathBuf,

    // A free-form string rather than a clap value enum: it is validated only
    // when `CrudArgs::database_kind` is called, which matches case-insensitively
    // and also accepts `postgresql` and `pg`.
    /// Database kind (postgres, mysql, sqlite)
    #[arg(short = 'd', long)]
    pub db_kind: String,

    // When `None`, `CrudArgs::resolve_entities_module` derives the path from
    // `entity_file`.
    /// Module path of generated entities (e.g. "crate::models::users").
    /// If omitted, derived from --entity-file by finding `src/` and converting the path.
    #[arg(short = 'e', long)]
    pub entities_module: Option<String>,

    /// Output directory for generated repository files
    #[arg(short = 'o', long, default_value = "src/crud")]
    pub output_dir: PathBuf,

    // Raw names; `Methods::from_list` in this file accepts the same vocabulary.
    // Empty when `-m` is not given (no `default_value`).
    /// Methods to generate (comma-separated): *, get_all, paginate, get, insert, update, delete
    #[arg(short = 'm', long, value_delimiter = ',')]
    pub methods: Vec<String>,

    /// Use sqlx::query_as!() compile-time checked macros instead of query_as::<_, T>() functions
    #[arg(short = 'q', long)]
    pub query_macro: bool,

    /// Print to stdout without writing files
    #[arg(short = 'n', long)]
    pub dry_run: bool,
}
139
140impl CrudArgs {
141    pub fn database_kind(&self) -> crate::error::Result<DatabaseKind> {
142        match self.db_kind.to_lowercase().as_str() {
143            "postgres" | "postgresql" | "pg" => Ok(DatabaseKind::Postgres),
144            "mysql" => Ok(DatabaseKind::Mysql),
145            "sqlite" => Ok(DatabaseKind::Sqlite),
146            other => Err(crate::error::Error::Config(format!(
147                "Unknown database kind '{}'. Expected: postgres, mysql, sqlite",
148                other
149            ))),
150        }
151    }
152
153    /// Resolve the entities module path: use the explicit value if provided,
154    /// otherwise derive it from the entity file path.
155    pub fn resolve_entities_module(&self) -> crate::error::Result<String> {
156        match &self.entities_module {
157            Some(m) => Ok(m.clone()),
158            None => module_path_from_file(&self.entity_file),
159        }
160    }
161}
162
163/// Derive a Rust module path from a file path by finding `src/` and converting.
164/// e.g. `some/project/src/models/users.rs` → `crate::models::users`
165/// e.g. `src/db/entities/mod.rs` → `crate::db::entities`
166fn module_path_from_file(path: &std::path::Path) -> crate::error::Result<String> {
167    let path_str = path.to_string_lossy().replace('\\', "/");
168
169    let after_src = match path_str.rfind("/src/") {
170        Some(pos) => &path_str[pos + 5..],
171        None if path_str.starts_with("src/") => &path_str[4..],
172        _ => {
173            return Err(crate::error::Error::Config(format!(
174                "Cannot derive module path from '{}': no 'src/' found. Use --entities-module explicitly.",
175                path.display()
176            )));
177        }
178    };
179
180    let without_ext = after_src.strip_suffix(".rs").unwrap_or(after_src);
181    let module = without_ext.strip_suffix("/mod").unwrap_or(without_ext);
182
183    let module_path = format!("crate::{}", module.replace('/', "::"));
184    Ok(module_path)
185}
186
/// Database backend that code generation targets.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DatabaseKind {
    /// PostgreSQL: `postgres://` / `postgresql://` URLs, or the `postgres` / `postgresql` / `pg` kinds.
    Postgres,
    /// MySQL: `mysql://` URLs or the `mysql` kind.
    Mysql,
    /// SQLite: `sqlite:` URLs or the `sqlite` kind.
    Sqlite,
}
193
/// Which CRUD methods to generate. All fields default to `false`.
/// Use [`Methods::from_list`] to parse from CLI input.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Methods {
    pub get_all: bool,
    pub paginate: bool,
    pub get: bool,
    pub insert: bool,
    pub update: bool,
    pub delete: bool,
}

/// Individual method names accepted by [`Methods::from_list`] (in addition to `*`),
/// used to build the "valid values" hint in error messages.
const ALL_METHODS: &[&str] = &["get_all", "paginate", "get", "insert", "update", "delete"];

impl Methods {
    /// Parse a list of method names into a selection.
    ///
    /// `"*"` enables every method. Each name is trimmed of surrounding
    /// whitespace, because clap's comma delimiter does not trim
    /// (`-m "get, insert"` yields `" insert"`). Every entry is validated
    /// regardless of position, so an unknown name is reported even when `"*"`
    /// appears earlier in the list. An empty list yields a selection with every
    /// method disabled.
    ///
    /// # Errors
    ///
    /// Returns a message listing the valid values when any entry is not a
    /// recognised method name.
    pub fn from_list(names: &[String]) -> Result<Self, String> {
        let mut m = Self::default();
        for name in names {
            match name.trim() {
                // Do not return early: later entries must still be validated.
                "*" => m = Self::all(),
                "get_all" => m.get_all = true,
                "paginate" => m.paginate = true,
                "get" => m.get = true,
                "insert" => m.insert = true,
                "update" => m.update = true,
                "delete" => m.delete = true,
                other => {
                    return Err(format!(
                        "Unknown method '{}'. Valid values: *, {}",
                        other,
                        ALL_METHODS.join(", ")
                    ));
                }
            }
        }
        Ok(m)
    }

    /// A selection with every CRUD method enabled.
    pub fn all() -> Self {
        Self {
            get_all: true,
            paginate: true,
            get: true,
            insert: true,
            update: true,
            delete: true,
        }
    }
}
244
// Unit tests for the argument helpers in this file. The argument structs are
// built directly as struct literals rather than parsed through clap.
#[cfg(test)]
mod tests {
    use super::*;

    // Build `DatabaseArgs` for the given URL with the default `public` schema.
    fn make_db_args(url: &str) -> DatabaseArgs {
        DatabaseArgs {
            database_url: url.to_string(),
            schemas: vec!["public".into()],
        }
    }

    // Build `EntitiesArgs` with neutral values for every field except the raw
    // `--type-overrides` entries under test.
    fn make_entities_args_with_overrides(overrides: Vec<&str>) -> EntitiesArgs {
        EntitiesArgs {
            db: make_db_args("postgres://localhost/db"),
            output_dir: PathBuf::from("out"),
            derives: vec![],
            type_overrides: overrides.into_iter().map(|s| s.to_string()).collect(),
            single_file: false,
            tables: None,
            exclude_tables: None,
            views: false,
            dry_run: false,
        }
    }

    // ========== database_kind ==========

    #[test]
    fn test_postgres_url() {
        let args = make_db_args("postgres://localhost/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Postgres);
    }

    #[test]
    fn test_postgresql_url() {
        let args = make_db_args("postgresql://localhost/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Postgres);
    }

    #[test]
    fn test_postgres_full_url() {
        let args = make_db_args("postgres://user:pass@host:5432/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Postgres);
    }

    #[test]
    fn test_mysql_url() {
        let args = make_db_args("mysql://localhost/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Mysql);
    }

    #[test]
    fn test_mysql_full_url() {
        let args = make_db_args("mysql://user:pass@host:3306/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Mysql);
    }

    #[test]
    fn test_sqlite_url() {
        let args = make_db_args("sqlite://path.db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Sqlite);
    }

    #[test]
    fn test_sqlite_colon() {
        let args = make_db_args("sqlite:path.db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Sqlite);
    }

    #[test]
    fn test_sqlite_memory() {
        let args = make_db_args("sqlite::memory:");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Sqlite);
    }

    #[test]
    fn test_http_url_fails() {
        let args = make_db_args("http://example.com");
        assert!(args.database_kind().is_err());
    }

    #[test]
    fn test_empty_url_fails() {
        let args = make_db_args("");
        assert!(args.database_kind().is_err());
    }

    #[test]
    fn test_mongo_url_fails() {
        let args = make_db_args("mongo://localhost");
        assert!(args.database_kind().is_err());
    }

    // Prefix matching in `DatabaseArgs::database_kind` is case-sensitive; this
    // test pins that an upper-case scheme is rejected.
    #[test]
    fn test_uppercase_postgres_fails() {
        let args = make_db_args("POSTGRES://localhost");
        assert!(args.database_kind().is_err());
    }

    // ========== parse_type_overrides ==========

    #[test]
    fn test_overrides_empty() {
        let args = make_entities_args_with_overrides(vec![]);
        assert!(args.parse_type_overrides().is_empty());
    }

    #[test]
    fn test_overrides_single() {
        let args = make_entities_args_with_overrides(vec!["jsonb=MyJson"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.get("jsonb").unwrap(), "MyJson");
    }

    #[test]
    fn test_overrides_multiple() {
        let args = make_entities_args_with_overrides(vec!["jsonb=MyJson", "uuid=MyUuid"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.len(), 2);
        assert_eq!(map.get("jsonb").unwrap(), "MyJson");
        assert_eq!(map.get("uuid").unwrap(), "MyUuid");
    }

    // Entries without `=` are dropped silently rather than reported.
    #[test]
    fn test_overrides_malformed_skipped() {
        let args = make_entities_args_with_overrides(vec!["noequals"]);
        assert!(args.parse_type_overrides().is_empty());
    }

    #[test]
    fn test_overrides_mixed_valid_invalid() {
        let args = make_entities_args_with_overrides(vec!["good=val", "bad"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.len(), 1);
        assert_eq!(map.get("good").unwrap(), "val");
    }

    // Only the first `=` separates key from value.
    #[test]
    fn test_overrides_equals_in_value() {
        let args = make_entities_args_with_overrides(vec!["key=val=ue"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.get("key").unwrap(), "val=ue");
    }

    // Pins current behaviour: an empty key is accepted, not rejected.
    #[test]
    fn test_overrides_empty_key() {
        let args = make_entities_args_with_overrides(vec!["=value"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.get("").unwrap(), "value");
    }

    // Pins current behaviour: an empty value is accepted, not rejected.
    #[test]
    fn test_overrides_empty_value() {
        let args = make_entities_args_with_overrides(vec!["key="]);
        let map = args.parse_type_overrides();
        assert_eq!(map.get("key").unwrap(), "");
    }

    // ========== exclude_tables ==========
    // These only check that the field round-trips through the struct; no
    // filtering logic in this file is exercised.

    #[test]
    fn test_exclude_tables_default_none() {
        let args = make_entities_args_with_overrides(vec![]);
        assert!(args.exclude_tables.is_none());
    }

    #[test]
    fn test_exclude_tables_set() {
        let mut args = make_entities_args_with_overrides(vec![]);
        args.exclude_tables = Some(vec!["_migrations".to_string(), "schema_versions".to_string()]);
        assert_eq!(args.exclude_tables.as_ref().unwrap().len(), 2);
        assert!(args.exclude_tables.as_ref().unwrap().contains(&"_migrations".to_string()));
    }

    // ========== methods ==========

    #[test]
    fn test_methods_default_all_false() {
        let m = Methods::default();
        assert!(!m.get_all);
        assert!(!m.paginate);
        assert!(!m.get);
        assert!(!m.insert);
        assert!(!m.update);
        assert!(!m.delete);
    }

    // `*` enables every method.
    #[test]
    fn test_methods_star() {
        let m = Methods::from_list(&["*".to_string()]).unwrap();
        assert!(m.get_all);
        assert!(m.paginate);
        assert!(m.get);
        assert!(m.insert);
        assert!(m.update);
        assert!(m.delete);
    }

    #[test]
    fn test_methods_single() {
        let m = Methods::from_list(&["get".to_string()]).unwrap();
        assert!(m.get);
        assert!(!m.get_all);
        assert!(!m.insert);
    }

    #[test]
    fn test_methods_multiple() {
        let m = Methods::from_list(&["get_all".to_string(), "delete".to_string()]).unwrap();
        assert!(m.get_all);
        assert!(m.delete);
        assert!(!m.insert);
        assert!(!m.paginate);
    }

    #[test]
    fn test_methods_unknown_fails() {
        let result = Methods::from_list(&["unknown".to_string()]);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Unknown method"));
    }

    #[test]
    fn test_methods_all() {
        let m = Methods::all();
        assert!(m.get_all);
        assert!(m.paginate);
        assert!(m.get);
        assert!(m.insert);
        assert!(m.update);
        assert!(m.delete);
    }

    // ========== module_path_from_file ==========

    #[test]
    fn test_module_path_simple() {
        let p = PathBuf::from("src/models/users.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::models::users");
    }

    // A trailing `mod.rs` maps to its parent directory's module.
    #[test]
    fn test_module_path_mod_rs() {
        let p = PathBuf::from("src/models/mod.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::models");
    }

    #[test]
    fn test_module_path_nested() {
        let p = PathBuf::from("src/db/entities/agent.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::db::entities::agent");
    }

    // Everything up to and including `/src/` is discarded, so absolute and
    // relative prefixes do not affect the result.
    #[test]
    fn test_module_path_absolute_with_src() {
        let p = PathBuf::from("/home/user/project/src/models/users.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::models::users");
    }

    #[test]
    fn test_module_path_relative_with_src() {
        let p = PathBuf::from("../other_project/src/models/users.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::models::users");
    }

    #[test]
    fn test_module_path_no_src_fails() {
        let p = PathBuf::from("models/users.rs");
        assert!(module_path_from_file(&p).is_err());
    }

    #[test]
    fn test_module_path_deeply_nested_mod() {
        let p = PathBuf::from("src/a/b/c/mod.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::a::b::c");
    }

    // NOTE(review): `src/lib.rs` is a crate root, whose module path is `crate`
    // itself. This test pins the current `crate::lib` output; revisit if
    // entities are ever generated into a crate-root file.
    #[test]
    fn test_module_path_src_root_file() {
        let p = PathBuf::from("src/lib.rs");
        assert_eq!(module_path_from_file(&p).unwrap(), "crate::lib");
    }
}