1use clap::Parser;
2use std::collections::HashMap;
3use std::path::PathBuf;
4
5#[derive(Parser, Debug)]
6#[command(name = "sqlx-gen", about = "Generate Rust structs from database schema")]
7pub struct Args {
8 #[arg(short = 'u', long, env = "DATABASE_URL")]
10 pub database_url: String,
11
12 #[arg(short = 'o', long, default_value = "src/models")]
14 pub output_dir: PathBuf,
15
16 #[arg(short = 's', long, value_delimiter = ',', default_value = "public")]
18 pub schemas: Vec<String>,
19
20 #[arg(long, value_delimiter = ',')]
22 pub derives: Vec<String>,
23
24 #[arg(long, value_delimiter = ',')]
26 pub type_overrides: Vec<String>,
27
28 #[arg(long)]
30 pub single_file: bool,
31
32 #[arg(long, value_delimiter = ',')]
34 pub tables: Option<Vec<String>>,
35
36 #[arg(long)]
38 pub views: bool,
39
40 #[arg(long)]
42 pub dry_run: bool,
43}
44
/// Database backend, as detected from the connection URL's scheme prefix by
/// `Args::database_kind`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DatabaseKind {
    /// URLs starting with `postgres://` or `postgresql://`.
    Postgres,
    /// URLs starting with `mysql://`.
    Mysql,
    /// URLs starting with `sqlite:` (this includes `sqlite://…` and `sqlite::memory:`).
    Sqlite,
}
51
52impl Args {
53 pub fn database_kind(&self) -> anyhow::Result<DatabaseKind> {
54 let url = &self.database_url;
55 if url.starts_with("postgres://") || url.starts_with("postgresql://") {
56 Ok(DatabaseKind::Postgres)
57 } else if url.starts_with("mysql://") {
58 Ok(DatabaseKind::Mysql)
59 } else if url.starts_with("sqlite://") || url.starts_with("sqlite:") {
60 Ok(DatabaseKind::Sqlite)
61 } else {
62 anyhow::bail!(
63 "Cannot detect database type from URL. Expected postgres://, mysql://, or sqlite:// prefix."
64 )
65 }
66 }
67
68 pub fn parse_type_overrides(&self) -> HashMap<String, String> {
69 self.type_overrides
70 .iter()
71 .filter_map(|s| {
72 let (k, v) = s.split_once('=')?;
73 Some((k.to_string(), v.to_string()))
74 })
75 .collect()
76 }
77}
78
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `Args` with the given URL and neutral values for every other field.
    fn make_args(url: &str) -> Args {
        Args {
            database_url: url.to_string(),
            output_dir: PathBuf::from("out"),
            schemas: vec!["public".into()],
            derives: vec![],
            type_overrides: vec![],
            single_file: false,
            tables: None,
            views: false,
            dry_run: false,
        }
    }

    /// Builds a Postgres-URL `Args` carrying the given raw `--type-overrides` entries.
    ///
    /// Reuses `make_args` through struct-update syntax so that a new `Args`
    /// field only has to be filled in once, in `make_args`.
    fn make_args_with_overrides(overrides: Vec<&str>) -> Args {
        Args {
            type_overrides: overrides.into_iter().map(String::from).collect(),
            ..make_args("postgres://localhost/db")
        }
    }

    // ---- database_kind ----

    #[test]
    fn test_postgres_url() {
        let args = make_args("postgres://localhost/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Postgres);
    }

    #[test]
    fn test_postgresql_url() {
        let args = make_args("postgresql://localhost/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Postgres);
    }

    #[test]
    fn test_postgres_full_url() {
        let args = make_args("postgres://user:pass@host:5432/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Postgres);
    }

    #[test]
    fn test_mysql_url() {
        let args = make_args("mysql://localhost/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Mysql);
    }

    #[test]
    fn test_mysql_full_url() {
        let args = make_args("mysql://user:pass@host:3306/db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Mysql);
    }

    #[test]
    fn test_sqlite_url() {
        let args = make_args("sqlite://path.db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Sqlite);
    }

    #[test]
    fn test_sqlite_colon() {
        let args = make_args("sqlite:path.db");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Sqlite);
    }

    #[test]
    fn test_sqlite_memory() {
        let args = make_args("sqlite::memory:");
        assert_eq!(args.database_kind().unwrap(), DatabaseKind::Sqlite);
    }

    #[test]
    fn test_http_url_fails() {
        let args = make_args("http://example.com");
        assert!(args.database_kind().is_err());
    }

    #[test]
    fn test_empty_url_fails() {
        let args = make_args("");
        assert!(args.database_kind().is_err());
    }

    #[test]
    fn test_mongo_url_fails() {
        let args = make_args("mongo://localhost");
        assert!(args.database_kind().is_err());
    }

    #[test]
    fn test_uppercase_postgres_fails() {
        let args = make_args("POSTGRES://localhost");
        assert!(args.database_kind().is_err());
    }

    // ---- parse_type_overrides ----

    #[test]
    fn test_overrides_empty() {
        let args = make_args_with_overrides(vec![]);
        assert!(args.parse_type_overrides().is_empty());
    }

    #[test]
    fn test_overrides_single() {
        let args = make_args_with_overrides(vec!["jsonb=MyJson"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.get("jsonb").unwrap(), "MyJson");
    }

    #[test]
    fn test_overrides_multiple() {
        let args = make_args_with_overrides(vec!["jsonb=MyJson", "uuid=MyUuid"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.len(), 2);
        assert_eq!(map.get("jsonb").unwrap(), "MyJson");
        assert_eq!(map.get("uuid").unwrap(), "MyUuid");
    }

    #[test]
    fn test_overrides_malformed_skipped() {
        let args = make_args_with_overrides(vec!["noequals"]);
        assert!(args.parse_type_overrides().is_empty());
    }

    #[test]
    fn test_overrides_mixed_valid_invalid() {
        let args = make_args_with_overrides(vec!["good=val", "bad"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.len(), 1);
        assert_eq!(map.get("good").unwrap(), "val");
    }

    #[test]
    fn test_overrides_equals_in_value() {
        let args = make_args_with_overrides(vec!["key=val=ue"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.get("key").unwrap(), "val=ue");
    }

    #[test]
    fn test_overrides_empty_key() {
        let args = make_args_with_overrides(vec!["=value"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.get("").unwrap(), "value");
    }

    #[test]
    fn test_overrides_empty_value() {
        let args = make_args_with_overrides(vec!["key="]);
        let map = args.parse_type_overrides();
        assert_eq!(map.get("key").unwrap(), "");
    }

    #[test]
    fn test_overrides_duplicate_key_last_wins() {
        let args = make_args_with_overrides(vec!["key=first", "key=second"]);
        let map = args.parse_type_overrides();
        assert_eq!(map.len(), 1);
        assert_eq!(map.get("key").unwrap(), "second");
    }
}
243}