1use core::time;
2use sea_orm_codegen::{
3 DateTimeCrate as CodegenDateTimeCrate, EntityTransformer, EntityWriterContext, OutputFile,
4 WithPrelude, WithSerde,
5};
6use std::{error::Error, fs, io::Write, path::Path, process::Command, str::FromStr};
7use tracing_subscriber::{prelude::*, EnvFilter};
8use url::Url;
9
10use crate::{DateTimeCrate, GenerateSubcommands};
11
12pub async fn run_generate_command(
13 command: GenerateSubcommands,
14 verbose: bool,
15) -> Result<(), Box<dyn Error>> {
16 match command {
17 GenerateSubcommands::Entity {
18 compact_format: _,
19 expanded_format,
20 include_hidden_tables,
21 tables,
22 ignore_tables,
23 max_connections,
24 acquire_timeout,
25 output_dir,
26 database_schema,
27 database_url,
28 with_prelude,
29 with_serde,
30 serde_skip_deserializing_primary_key,
31 serde_skip_hidden_column,
32 with_copy_enums,
33 date_time_crate,
34 lib,
35 model_extra_derives,
36 model_extra_attributes,
37 enum_extra_derives,
38 enum_extra_attributes,
39 seaography,
40 impl_active_model_behavior,
41 } => {
42 if verbose {
43 let _ = tracing_subscriber::fmt()
44 .with_max_level(tracing::Level::DEBUG)
45 .with_test_writer()
46 .try_init();
47 } else {
48 let filter_layer = EnvFilter::try_new("sea_orm_codegen=info").unwrap();
49 let fmt_layer = tracing_subscriber::fmt::layer()
50 .with_target(false)
51 .with_level(false)
52 .without_time();
53
54 let _ = tracing_subscriber::registry()
55 .with(filter_layer)
56 .with(fmt_layer)
57 .try_init();
58 }
59
60 let url = Url::parse(&database_url)?;
63
64 let is_sqlite = url.scheme() == "sqlite";
69
70 let filter_tables =
72 |table: &String| -> bool { tables.is_empty() || tables.contains(table) };
73
74 let filter_hidden_tables = |table: &str| -> bool {
75 if include_hidden_tables {
76 true
77 } else {
78 !table.starts_with('_')
79 }
80 };
81
82 let filter_skip_tables = |table: &String| -> bool { !ignore_tables.contains(table) };
83
84 let _database_name = if !is_sqlite {
85 let database_name = url
91 .path_segments()
92 .unwrap_or_else(|| {
93 panic!(
94 "There is no database name as part of the url path: {}",
95 url.as_str()
96 )
97 })
98 .next()
99 .unwrap();
100
101 if database_name.is_empty() {
103 panic!(
104 "There is no database name as part of the url path: {}",
105 url.as_str()
106 );
107 }
108
109 database_name
110 } else {
111 Default::default()
112 };
113
114 let (schema_name, table_stmts) = match url.scheme() {
115 "mysql" => {
116 #[cfg(not(feature = "sqlx-mysql"))]
117 {
118 panic!("mysql feature is off")
119 }
120 #[cfg(feature = "sqlx-mysql")]
121 {
122 use sea_schema::mysql::discovery::SchemaDiscovery;
123 use sqlx::MySql;
124
125 println!("Connecting to MySQL ...");
126 let connection = sqlx_connect::<MySql>(
127 max_connections,
128 acquire_timeout,
129 url.as_str(),
130 None,
131 )
132 .await?;
133 println!("Discovering schema ...");
134 let schema_discovery = SchemaDiscovery::new(connection, _database_name);
135 let schema = schema_discovery.discover().await?;
136 let table_stmts = schema
137 .tables
138 .into_iter()
139 .filter(|schema| filter_tables(&schema.info.name))
140 .filter(|schema| filter_hidden_tables(&schema.info.name))
141 .filter(|schema| filter_skip_tables(&schema.info.name))
142 .map(|schema| schema.write())
143 .collect();
144 (None, table_stmts)
145 }
146 }
147 "sqlite" => {
148 #[cfg(not(feature = "sqlx-sqlite"))]
149 {
150 panic!("sqlite feature is off")
151 }
152 #[cfg(feature = "sqlx-sqlite")]
153 {
154 use sea_schema::sqlite::discovery::SchemaDiscovery;
155 use sqlx::Sqlite;
156
157 println!("Connecting to SQLite ...");
158 let connection = sqlx_connect::<Sqlite>(
159 max_connections,
160 acquire_timeout,
161 url.as_str(),
162 None,
163 )
164 .await?;
165 println!("Discovering schema ...");
166 let schema_discovery = SchemaDiscovery::new(connection);
167 let schema = schema_discovery
168 .discover()
169 .await?
170 .merge_indexes_into_table();
171 let table_stmts = schema
172 .tables
173 .into_iter()
174 .filter(|schema| filter_tables(&schema.name))
175 .filter(|schema| filter_hidden_tables(&schema.name))
176 .filter(|schema| filter_skip_tables(&schema.name))
177 .map(|schema| schema.write())
178 .collect();
179 (None, table_stmts)
180 }
181 }
182 "postgres" | "postgresql" => {
183 #[cfg(not(feature = "sqlx-postgres"))]
184 {
185 panic!("postgres feature is off")
186 }
187 #[cfg(feature = "sqlx-postgres")]
188 {
189 use sea_schema::postgres::discovery::SchemaDiscovery;
190 use sqlx::Postgres;
191
192 println!("Connecting to Postgres ...");
193 let schema = database_schema.as_deref().unwrap_or("public");
194 let connection = sqlx_connect::<Postgres>(
195 max_connections,
196 acquire_timeout,
197 url.as_str(),
198 Some(schema),
199 )
200 .await?;
201 println!("Discovering schema ...");
202 let schema_discovery = SchemaDiscovery::new(connection, schema);
203 let schema = schema_discovery.discover().await?;
204 let table_stmts = schema
205 .tables
206 .into_iter()
207 .filter(|schema| filter_tables(&schema.info.name))
208 .filter(|schema| filter_hidden_tables(&schema.info.name))
209 .filter(|schema| filter_skip_tables(&schema.info.name))
210 .map(|schema| schema.write())
211 .collect();
212 (database_schema, table_stmts)
213 }
214 }
215 _ => unimplemented!("{} is not supported", url.scheme()),
216 };
217 println!("... discovered.");
218
219 let writer_context = EntityWriterContext::new(
220 expanded_format,
221 WithPrelude::from_str(&with_prelude).expect("Invalid prelude option"),
222 WithSerde::from_str(&with_serde).expect("Invalid serde derive option"),
223 with_copy_enums,
224 date_time_crate.into(),
225 schema_name,
226 lib,
227 serde_skip_deserializing_primary_key,
228 serde_skip_hidden_column,
229 model_extra_derives,
230 model_extra_attributes,
231 enum_extra_derives,
232 enum_extra_attributes,
233 seaography,
234 impl_active_model_behavior,
235 );
236 let output = EntityTransformer::transform(table_stmts)?.generate(&writer_context);
237
238 let dir = Path::new(&output_dir);
239 fs::create_dir_all(dir)?;
240
241 for OutputFile { name, content } in output.files.iter() {
242 let file_path = dir.join(name);
243 println!("Writing {}", file_path.display());
244 let mut file = fs::File::create(file_path)?;
245 file.write_all(content.as_bytes())?;
246 }
247
248 for OutputFile { name, .. } in output.files.iter() {
250 let exit_status = Command::new("rustfmt").arg(dir.join(name)).status()?; if !exit_status.success() {
252 return Err(format!("Fail to format file `{name}`").into());
254 }
255 }
256
257 println!("... Done.");
258 }
259 }
260
261 Ok(())
262}
263
264async fn sqlx_connect<DB>(
265 max_connections: u32,
266 acquire_timeout: u64,
267 url: &str,
268 schema: Option<&str>,
269) -> Result<sqlx::Pool<DB>, Box<dyn Error>>
270where
271 DB: sqlx::Database,
272 for<'a> &'a mut <DB as sqlx::Database>::Connection: sqlx::Executor<'a>,
273{
274 let mut pool_options = sqlx::pool::PoolOptions::<DB>::new()
275 .max_connections(max_connections)
276 .acquire_timeout(time::Duration::from_secs(acquire_timeout));
277 if let Some(schema) = schema {
280 let sql = format!("SET search_path = '{schema}'");
281 pool_options = pool_options.after_connect(move |conn, _| {
282 let sql = sql.clone();
283 Box::pin(async move {
284 sqlx::Executor::execute(conn, sql.as_str())
285 .await
286 .map(|_| ())
287 })
288 });
289 }
290 pool_options.connect(url).await.map_err(Into::into)
291}
292
293impl From<DateTimeCrate> for CodegenDateTimeCrate {
294 fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate {
295 match date_time_crate {
296 DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono,
297 DateTimeCrate::Time => CodegenDateTimeCrate::Time,
298 }
299 }
300}
301
#[cfg(test)]
mod tests {
    use clap::Parser;

    use super::*;
    use crate::{Cli, Commands};

    /// Parses `sea-orm-cli generate entity --database-url <database_url>` and
    /// runs the resulting command to completion, unwrapping its result so any
    /// returned error surfaces as a panic.
    fn generate_entity_from(database_url: &str) {
        let cli = Cli::parse_from([
            "sea-orm-cli",
            "generate",
            "entity",
            "--database-url",
            database_url,
        ]);

        match cli.command {
            Commands::Generate { command } => {
                smol::block_on(run_generate_command(command, cli.verbose)).unwrap();
            }
            _ => unreachable!(),
        }
    }

    /// A URL without a scheme is rejected by the URL parser.
    #[test]
    #[should_panic(
        expected = "called `Result::unwrap()` on an `Err` value: RelativeUrlWithoutBase"
    )]
    fn test_generate_entity_no_protocol() {
        generate_entity_from("://root:root@localhost:3306/database");
    }

    /// A URL with no path at all has no database name.
    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: postgresql://root:root@localhost:3306"
    )]
    fn test_generate_entity_no_database_section() {
        generate_entity_from("postgresql://root:root@localhost:3306");
    }

    /// A URL whose path is just `/` has an empty database name.
    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: mysql://root:root@localhost:3306/"
    )]
    fn test_generate_entity_no_database_path() {
        generate_entity_from("mysql://root:root@localhost:3306/");
    }

    /// A URL with an empty password is expected to end in a pool timeout.
    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: PoolTimedOut")]
    fn test_generate_entity_no_password() {
        generate_entity_from("mysql://root:@localhost:3306/database");
    }

    /// A URL with credentials but no host is rejected by the URL parser.
    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: EmptyHost")]
    fn test_generate_entity_no_host() {
        generate_entity_from("postgres://root:root@/database");
    }
}