1use core::time;
2use sea_orm_codegen::{
3 DateTimeCrate as CodegenDateTimeCrate, EntityTransformer, EntityWriterContext, OutputFile,
4 WithPrelude, WithSerde,
5};
6use std::{error::Error, fs, io::Write, path::Path, process::Command, str::FromStr};
7use tracing_subscriber::{prelude::*, EnvFilter};
8use url::Url;
9
10use crate::{DateTimeCrate, GenerateSubcommands};
11
12pub async fn run_generate_command(
13 command: GenerateSubcommands,
14 verbose: bool,
15) -> Result<(), Box<dyn Error>> {
16 match command {
17 GenerateSubcommands::Entity {
18 compact_format: _,
19 expanded_format,
20 frontend_format,
21 include_hidden_tables,
22 tables,
23 ignore_tables,
24 max_connections,
25 acquire_timeout,
26 output_dir,
27 database_schema,
28 database_url,
29 with_prelude,
30 with_serde,
31 serde_skip_deserializing_primary_key,
32 serde_skip_hidden_column,
33 with_copy_enums,
34 date_time_crate,
35 lib,
36 model_extra_derives,
37 model_extra_attributes,
38 enum_extra_derives,
39 enum_extra_attributes,
40 seaography,
41 impl_active_model_behavior,
42 } => {
43 if verbose {
44 let _ = tracing_subscriber::fmt()
45 .with_max_level(tracing::Level::DEBUG)
46 .with_test_writer()
47 .try_init();
48 } else {
49 let filter_layer = EnvFilter::try_new("sea_orm_codegen=info").unwrap();
50 let fmt_layer = tracing_subscriber::fmt::layer()
51 .with_target(false)
52 .with_level(false)
53 .without_time();
54
55 let _ = tracing_subscriber::registry()
56 .with(filter_layer)
57 .with(fmt_layer)
58 .try_init();
59 }
60
61 let url = Url::parse(&database_url)?;
64
65 let is_sqlite = url.scheme() == "sqlite";
70
71 let filter_tables =
73 |table: &String| -> bool { tables.is_empty() || tables.contains(table) };
74
75 let filter_hidden_tables = |table: &str| -> bool {
76 if include_hidden_tables {
77 true
78 } else {
79 !table.starts_with('_')
80 }
81 };
82
83 let filter_skip_tables = |table: &String| -> bool { !ignore_tables.contains(table) };
84
85 let _database_name = if !is_sqlite {
86 let database_name = url
92 .path_segments()
93 .unwrap_or_else(|| {
94 panic!(
95 "There is no database name as part of the url path: {}",
96 url.as_str()
97 )
98 })
99 .next()
100 .unwrap();
101
102 if database_name.is_empty() {
104 panic!(
105 "There is no database name as part of the url path: {}",
106 url.as_str()
107 );
108 }
109
110 database_name
111 } else {
112 Default::default()
113 };
114
115 let (schema_name, table_stmts) = match url.scheme() {
116 "mysql" => {
117 #[cfg(not(feature = "sqlx-mysql"))]
118 {
119 panic!("mysql feature is off")
120 }
121 #[cfg(feature = "sqlx-mysql")]
122 {
123 use sea_schema::mysql::discovery::SchemaDiscovery;
124 use sqlx::MySql;
125
126 println!("Connecting to MySQL ...");
127 let connection = sqlx_connect::<MySql>(
128 max_connections,
129 acquire_timeout,
130 url.as_str(),
131 None,
132 )
133 .await?;
134 println!("Discovering schema ...");
135 let schema_discovery = SchemaDiscovery::new(connection, _database_name);
136 let schema = schema_discovery.discover().await?;
137 let table_stmts = schema
138 .tables
139 .into_iter()
140 .filter(|schema| filter_tables(&schema.info.name))
141 .filter(|schema| filter_hidden_tables(&schema.info.name))
142 .filter(|schema| filter_skip_tables(&schema.info.name))
143 .map(|schema| schema.write())
144 .collect();
145 (None, table_stmts)
146 }
147 }
148 "sqlite" => {
149 #[cfg(not(feature = "sqlx-sqlite"))]
150 {
151 panic!("sqlite feature is off")
152 }
153 #[cfg(feature = "sqlx-sqlite")]
154 {
155 use sea_schema::sqlite::discovery::SchemaDiscovery;
156 use sqlx::Sqlite;
157
158 println!("Connecting to SQLite ...");
159 let connection = sqlx_connect::<Sqlite>(
160 max_connections,
161 acquire_timeout,
162 url.as_str(),
163 None,
164 )
165 .await?;
166 println!("Discovering schema ...");
167 let schema_discovery = SchemaDiscovery::new(connection);
168 let schema = schema_discovery
169 .discover()
170 .await?
171 .merge_indexes_into_table();
172 let table_stmts = schema
173 .tables
174 .into_iter()
175 .filter(|schema| filter_tables(&schema.name))
176 .filter(|schema| filter_hidden_tables(&schema.name))
177 .filter(|schema| filter_skip_tables(&schema.name))
178 .map(|schema| schema.write())
179 .collect();
180 (None, table_stmts)
181 }
182 }
183 "postgres" | "postgresql" => {
184 #[cfg(not(feature = "sqlx-postgres"))]
185 {
186 panic!("postgres feature is off")
187 }
188 #[cfg(feature = "sqlx-postgres")]
189 {
190 use sea_schema::postgres::discovery::SchemaDiscovery;
191 use sqlx::Postgres;
192
193 println!("Connecting to Postgres ...");
194 let schema = database_schema.as_deref().unwrap_or("public");
195 let connection = sqlx_connect::<Postgres>(
196 max_connections,
197 acquire_timeout,
198 url.as_str(),
199 Some(schema),
200 )
201 .await?;
202 println!("Discovering schema ...");
203 let schema_discovery = SchemaDiscovery::new(connection, schema);
204 let schema = schema_discovery.discover().await?;
205 let table_stmts = schema
206 .tables
207 .into_iter()
208 .filter(|schema| filter_tables(&schema.info.name))
209 .filter(|schema| filter_hidden_tables(&schema.info.name))
210 .filter(|schema| filter_skip_tables(&schema.info.name))
211 .map(|schema| schema.write())
212 .collect();
213 (database_schema, table_stmts)
214 }
215 }
216 _ => unimplemented!("{} is not supported", url.scheme()),
217 };
218 println!("... discovered.");
219
220 let writer_context = EntityWriterContext::new(
221 expanded_format,
222 frontend_format,
223 WithPrelude::from_str(&with_prelude).expect("Invalid prelude option"),
224 WithSerde::from_str(&with_serde).expect("Invalid serde derive option"),
225 with_copy_enums,
226 date_time_crate.into(),
227 schema_name,
228 lib,
229 serde_skip_deserializing_primary_key,
230 serde_skip_hidden_column,
231 model_extra_derives,
232 model_extra_attributes,
233 enum_extra_derives,
234 enum_extra_attributes,
235 seaography,
236 impl_active_model_behavior,
237 );
238 let output = EntityTransformer::transform(table_stmts)?.generate(&writer_context);
239
240 let dir = Path::new(&output_dir);
241 fs::create_dir_all(dir)?;
242
243 for OutputFile { name, content } in output.files.iter() {
244 let file_path = dir.join(name);
245 println!("Writing {}", file_path.display());
246 let mut file = fs::File::create(file_path)?;
247 file.write_all(content.as_bytes())?;
248 }
249
250 for OutputFile { name, .. } in output.files.iter() {
252 let exit_status = Command::new("rustfmt").arg(dir.join(name)).status()?; if !exit_status.success() {
254 return Err(format!("Fail to format file `{name}`").into());
256 }
257 }
258
259 println!("... Done.");
260 }
261 }
262
263 Ok(())
264}
265
266async fn sqlx_connect<DB>(
267 max_connections: u32,
268 acquire_timeout: u64,
269 url: &str,
270 schema: Option<&str>,
271) -> Result<sqlx::Pool<DB>, Box<dyn Error>>
272where
273 DB: sqlx::Database,
274 for<'a> &'a mut <DB as sqlx::Database>::Connection: sqlx::Executor<'a>,
275{
276 let mut pool_options = sqlx::pool::PoolOptions::<DB>::new()
277 .max_connections(max_connections)
278 .acquire_timeout(time::Duration::from_secs(acquire_timeout));
279 if let Some(schema) = schema {
282 let sql = format!("SET search_path = '{schema}'");
283 pool_options = pool_options.after_connect(move |conn, _| {
284 let sql = sql.clone();
285 Box::pin(async move {
286 sqlx::Executor::execute(conn, sql.as_str())
287 .await
288 .map(|_| ())
289 })
290 });
291 }
292 pool_options.connect(url).await.map_err(Into::into)
293}
294
295impl From<DateTimeCrate> for CodegenDateTimeCrate {
296 fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate {
297 match date_time_crate {
298 DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono,
299 DateTimeCrate::Time => CodegenDateTimeCrate::Time,
300 }
301 }
302}
303
#[cfg(test)]
mod tests {
    use clap::Parser;

    use super::*;
    use crate::{Cli, Commands};

    /// Parses `sea-orm-cli generate entity --database-url <database_url>`,
    /// runs the command to completion and unwraps the result, so any error
    /// surfaces as a panic the `should_panic` tests can match on.
    fn generate_entity_with_url(database_url: &str) {
        let cli = Cli::parse_from([
            "sea-orm-cli",
            "generate",
            "entity",
            "--database-url",
            database_url,
        ]);

        match cli.command {
            Commands::Generate { command } => {
                smol::block_on(run_generate_command(command, cli.verbose)).unwrap();
            }
            _ => unreachable!(),
        }
    }

    #[test]
    #[should_panic(
        expected = "called `Result::unwrap()` on an `Err` value: RelativeUrlWithoutBase"
    )]
    fn test_generate_entity_no_protocol() {
        generate_entity_with_url("://root:root@localhost:3306/database");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: postgresql://root:root@localhost:3306"
    )]
    fn test_generate_entity_no_database_section() {
        generate_entity_with_url("postgresql://root:root@localhost:3306");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: mysql://root:root@localhost:3306/"
    )]
    fn test_generate_entity_no_database_path() {
        generate_entity_with_url("mysql://root:root@localhost:3306/");
    }

    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: PoolTimedOut")]
    fn test_generate_entity_no_password() {
        generate_entity_with_url("mysql://root:@localhost:3306/database");
    }

    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: EmptyHost")]
    fn test_generate_entity_no_host() {
        generate_entity_with_url("postgres://root:root@/database");
    }
}