1use crate::{BigIntegerType, DateTimeCrate, GenerateSubcommands};
2use core::time;
3use sea_orm_codegen::{
4 BigIntegerType as CodegenBigIntegerType, DateTimeCrate as CodegenDateTimeCrate, EntityFormat,
5 EntityTransformer, EntityWriterContext, MergeReport, OutputFile, WithPrelude, WithSerde,
6 merge_entity_files,
7};
8use std::{error::Error, fs, path::Path, process::Command, str::FromStr};
9use tracing_subscriber::{EnvFilter, prelude::*};
10use url::Url;
11
/// Run a `sea-orm-cli generate ...` subcommand.
///
/// Currently handles `generate entity`: parses `database_url`, discovers the
/// table schema of the matching backend (MySQL, SQLite or Postgres — each
/// compiled in behind its own `sqlx-*` cargo feature), generates entity files
/// into `output_dir` (optionally merging with user-modified files), and
/// finally formats every written file with `rustfmt`.
///
/// # Errors
/// Returns an error on connection/discovery failure, file I/O failure, a
/// failing `rustfmt` invocation, or when a merge fallback had to be applied.
///
/// # Panics
/// Panics when a non-SQLite URL carries no database name in its path, or when
/// the URL scheme's backend feature was not compiled in.
pub async fn run_generate_command(
    command: GenerateSubcommands,
    verbose: bool,
) -> Result<(), Box<dyn Error>> {
    match command {
        GenerateSubcommands::Entity {
            entity_format,
            // Deprecated flag; `entity_format` / `expanded_format` supersede it.
            compact_format: _,
            expanded_format,
            frontend_format,
            include_hidden_tables,
            tables,
            ignore_tables,
            max_connections,
            acquire_timeout,
            output_dir,
            database_schema,
            database_url,
            with_prelude,
            with_serde,
            serde_skip_deserializing_primary_key,
            serde_skip_hidden_column,
            with_copy_enums,
            date_time_crate,
            big_integer_type,
            lib,
            model_extra_derives,
            model_extra_attributes,
            enum_extra_derives,
            enum_extra_attributes,
            column_extra_derives,
            seaography,
            impl_active_model_behavior,
            preserve_user_modifications,
        } => {
            // Logging setup: verbose mode enables debug-level tracing;
            // otherwise show terse codegen info lines without target/level/time.
            // `try_init` results are ignored so re-initialization (e.g. in
            // tests) is harmless.
            if verbose {
                let _ = tracing_subscriber::fmt()
                    .with_max_level(tracing::Level::DEBUG)
                    .with_test_writer()
                    .try_init();
            } else {
                let filter_layer = EnvFilter::try_new("sea_orm_codegen=info").unwrap();
                let fmt_layer = tracing_subscriber::fmt::layer()
                    .with_target(false)
                    .with_level(false)
                    .without_time();

                let _ = tracing_subscriber::registry()
                    .with(filter_layer)
                    .with(fmt_layer)
                    .try_init();
            }

            let url = Url::parse(&database_url)?;

            let is_sqlite = url.scheme() == "sqlite";

            // Keep a table when no explicit `--tables` list was given,
            // or when the table is in that list.
            let filter_tables =
                |table: &String| -> bool { tables.is_empty() || tables.contains(table) };

            // Tables whose name starts with '_' are considered hidden
            // unless `--include-hidden-tables` was passed.
            let filter_hidden_tables = |table: &str| -> bool {
                if include_hidden_tables {
                    true
                } else {
                    !table.starts_with('_')
                }
            };

            let filter_skip_tables = |table: &String| -> bool { !ignore_tables.contains(table) };

            // Extract the database name from the URL path. MySQL discovery
            // needs it; SQLite URLs address a file, so an empty default is
            // used there (hence the `_`-prefixed binding).
            let _database_name = if !is_sqlite {
                let database_name = url
                    .path_segments()
                    .unwrap_or_else(|| {
                        panic!(
                            "There is no database name as part of the url path: {}",
                            url.as_str()
                        )
                    })
                    .next()
                    .unwrap();

                if database_name.is_empty() {
                    panic!(
                        "There is no database name as part of the url path: {}",
                        url.as_str()
                    );
                }

                database_name
            } else {
                Default::default()
            };

            // Discover the schema with the backend selected by the URL
            // scheme. Each arm panics at runtime if its cargo feature is off.
            // Yields an optional schema name plus the table statements.
            let (schema_name, table_stmts) = match url.scheme() {
                "mysql" => {
                    #[cfg(not(feature = "sqlx-mysql"))]
                    {
                        panic!("mysql feature is off")
                    }
                    #[cfg(feature = "sqlx-mysql")]
                    {
                        use sea_schema::mysql::discovery::SchemaDiscovery;
                        use sqlx::MySql;

                        println!("Connecting to MySQL ...");
                        let connection = sqlx_connect::<MySql>(
                            max_connections,
                            acquire_timeout,
                            url.as_str(),
                            None,
                        )
                        .await?;
                        println!("Discovering schema ...");
                        let schema_discovery = SchemaDiscovery::new(connection, _database_name);
                        let schema = schema_discovery.discover().await?;
                        let table_stmts = schema
                            .tables
                            .into_iter()
                            .filter(|schema| filter_tables(&schema.info.name))
                            .filter(|schema| filter_hidden_tables(&schema.info.name))
                            .filter(|schema| filter_skip_tables(&schema.info.name))
                            .map(|schema| schema.write())
                            .collect();
                        (None, table_stmts)
                    }
                }
                "sqlite" => {
                    #[cfg(not(feature = "sqlx-sqlite"))]
                    {
                        panic!("sqlite feature is off")
                    }
                    #[cfg(feature = "sqlx-sqlite")]
                    {
                        use sea_schema::sqlite::discovery::SchemaDiscovery;
                        use sqlx::Sqlite;

                        println!("Connecting to SQLite ...");
                        let connection = sqlx_connect::<Sqlite>(
                            max_connections,
                            acquire_timeout,
                            url.as_str(),
                            None,
                        )
                        .await?;
                        println!("Discovering schema ...");
                        let schema_discovery = SchemaDiscovery::new(connection);
                        // SQLite reports indexes separately; fold them back
                        // into their tables before filtering.
                        let schema = schema_discovery
                            .discover()
                            .await?
                            .merge_indexes_into_table();
                        let table_stmts = schema
                            .tables
                            .into_iter()
                            .filter(|schema| filter_tables(&schema.name))
                            .filter(|schema| filter_hidden_tables(&schema.name))
                            .filter(|schema| filter_skip_tables(&schema.name))
                            .map(|schema| schema.write())
                            .collect();
                        (None, table_stmts)
                    }
                }
                "postgres" | "postgresql" => {
                    #[cfg(not(feature = "sqlx-postgres"))]
                    {
                        panic!("postgres feature is off")
                    }
                    #[cfg(feature = "sqlx-postgres")]
                    {
                        use sea_schema::postgres::discovery::SchemaDiscovery;
                        use sqlx::Postgres;

                        println!("Connecting to Postgres ...");
                        // Default to the conventional `public` schema when
                        // `--database-schema` was not provided.
                        let schema = database_schema.as_deref().unwrap_or("public");
                        let connection = sqlx_connect::<Postgres>(
                            max_connections,
                            acquire_timeout,
                            url.as_str(),
                            Some(schema),
                        )
                        .await?;
                        println!("Discovering schema ...");
                        let schema_discovery = SchemaDiscovery::new(connection, schema);
                        let schema = schema_discovery.discover().await?;
                        let table_stmts = schema
                            .tables
                            .into_iter()
                            .filter(|schema| filter_tables(&schema.info.name))
                            .filter(|schema| filter_hidden_tables(&schema.info.name))
                            .filter(|schema| filter_skip_tables(&schema.info.name))
                            .map(|schema| schema.write())
                            .collect();
                        (database_schema, table_stmts)
                    }
                }
                _ => unimplemented!("{} is not supported", url.scheme()),
            };
            println!("... discovered.");

            // Assemble all codegen options. Format precedence: explicit
            // --expanded-format, then --frontend-format, then --entity-format,
            // then the default.
            let writer_context = EntityWriterContext::new(
                if expanded_format {
                    EntityFormat::Expanded
                } else if frontend_format {
                    EntityFormat::Frontend
                } else if let Some(entity_format) = entity_format {
                    EntityFormat::from_str(&entity_format).expect("Invalid entity-format option")
                } else {
                    EntityFormat::default()
                },
                WithPrelude::from_str(&with_prelude).expect("Invalid prelude option"),
                WithSerde::from_str(&with_serde).expect("Invalid serde derive option"),
                with_copy_enums,
                date_time_crate.into(),
                big_integer_type.into(),
                schema_name,
                lib,
                serde_skip_deserializing_primary_key,
                serde_skip_hidden_column,
                model_extra_derives,
                model_extra_attributes,
                enum_extra_derives,
                enum_extra_attributes,
                column_extra_derives,
                seaography,
                impl_active_model_behavior,
            );
            let output = EntityTransformer::transform(table_stmts)?.generate(&writer_context);

            let dir = Path::new(&output_dir);
            fs::create_dir_all(dir)?;

            // Entity files whose merge needed the fallback path; reported
            // as an error at the end so the user reviews them.
            let mut merge_fallback_files: Vec<String> = Vec::new();

            for OutputFile { name, content } in output.files.iter() {
                let file_path = dir.join(name);
                println!("Writing {}", file_path.display());

                // Shared/bookkeeping files are always overwritten. An
                // existing entity file is merged with the newly generated
                // content when `--preserve-user-modifications` is on.
                if !matches!(
                    name.as_str(),
                    "mod.rs" | "lib.rs" | "prelude.rs" | "sea_orm_active_enums.rs"
                ) && file_path.exists()
                    && preserve_user_modifications
                {
                    let prev_content = fs::read_to_string(&file_path)?;
                    match merge_entity_files(&prev_content, content) {
                        Ok(merged) => {
                            fs::write(file_path, merged)?;
                        }
                        Err(MergeReport {
                            output,
                            warnings,
                            fallback_applied,
                        }) => {
                            // Best-effort merge: write what the merger
                            // produced, surface its warnings, and remember
                            // files where the fallback was taken.
                            for message in warnings {
                                eprintln!("{message}");
                            }
                            fs::write(file_path, output)?;
                            if fallback_applied {
                                merge_fallback_files.push(name.clone());
                            }
                        }
                    }
                } else {
                    fs::write(file_path, content)?;
                };
            }

            // Format every generated file in place with the system rustfmt.
            for OutputFile { name, .. } in output.files.iter() {
                let exit_status = Command::new("rustfmt").arg(dir.join(name)).status()?;
                if !exit_status.success() {
                    return Err(format!("Fail to format file `{name}`").into());
                }
            }

            // Treat merge fallbacks as a hard error so they are not missed.
            if merge_fallback_files.is_empty() {
                println!("... Done.");
            } else {
                return Err(format!(
                    "Merge fallback applied for {} file(s): \n{}",
                    merge_fallback_files.len(),
                    merge_fallback_files.join("\n")
                )
                .into());
            }
        }
    }

    Ok(())
}
316
317async fn sqlx_connect<DB>(
318 max_connections: u32,
319 acquire_timeout: u64,
320 url: &str,
321 schema: Option<&str>,
322) -> Result<sqlx::Pool<DB>, Box<dyn Error>>
323where
324 DB: sqlx::Database,
325 for<'a> &'a mut <DB as sqlx::Database>::Connection: sqlx::Executor<'a>,
326{
327 let mut pool_options = sqlx::pool::PoolOptions::<DB>::new()
328 .max_connections(max_connections)
329 .acquire_timeout(time::Duration::from_secs(acquire_timeout));
330 if let Some(schema) = schema {
333 let sql = format!("SET search_path = '{schema}'");
334 pool_options = pool_options.after_connect(move |conn, _| {
335 let sql = sql.clone();
336 Box::pin(async move {
337 sqlx::Executor::execute(conn, sql.as_str())
338 .await
339 .map(|_| ())
340 })
341 });
342 }
343 pool_options.connect(url).await.map_err(Into::into)
344}
345
346impl From<DateTimeCrate> for CodegenDateTimeCrate {
347 fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate {
348 match date_time_crate {
349 DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono,
350 DateTimeCrate::Time => CodegenDateTimeCrate::Time,
351 }
352 }
353}
354
355impl From<BigIntegerType> for CodegenBigIntegerType {
356 fn from(date_time_crate: BigIntegerType) -> CodegenBigIntegerType {
357 match date_time_crate {
358 BigIntegerType::I64 => CodegenBigIntegerType::I64,
359 BigIntegerType::I32 => CodegenBigIntegerType::I32,
360 }
361 }
362}
363
364#[cfg(test)]
365mod tests {
366 use clap::Parser;
367
368 use super::*;
369 use crate::{Cli, Commands};
370
371 #[test]
372 #[should_panic(
373 expected = "called `Result::unwrap()` on an `Err` value: RelativeUrlWithoutBase"
374 )]
375 fn test_generate_entity_no_protocol() {
376 let cli = Cli::parse_from([
377 "sea-orm-cli",
378 "generate",
379 "entity",
380 "--database-url",
381 "://root:root@localhost:3306/database",
382 ]);
383
384 match cli.command {
385 Commands::Generate { command } => {
386 smol::block_on(run_generate_command(command, cli.verbose)).unwrap();
387 }
388 _ => unreachable!(),
389 }
390 }
391
392 #[test]
393 #[should_panic(
394 expected = "There is no database name as part of the url path: postgresql://root:root@localhost:3306"
395 )]
396 fn test_generate_entity_no_database_section() {
397 let cli = Cli::parse_from([
398 "sea-orm-cli",
399 "generate",
400 "entity",
401 "--database-url",
402 "postgresql://root:root@localhost:3306",
403 ]);
404
405 match cli.command {
406 Commands::Generate { command } => {
407 smol::block_on(run_generate_command(command, cli.verbose)).unwrap();
408 }
409 _ => unreachable!(),
410 }
411 }
412
413 #[test]
414 #[should_panic(
415 expected = "There is no database name as part of the url path: mysql://root:root@localhost:3306/"
416 )]
417 fn test_generate_entity_no_database_path() {
418 let cli = Cli::parse_from([
419 "sea-orm-cli",
420 "generate",
421 "entity",
422 "--database-url",
423 "mysql://root:root@localhost:3306/",
424 ]);
425
426 match cli.command {
427 Commands::Generate { command } => {
428 smol::block_on(run_generate_command(command, cli.verbose)).unwrap();
429 }
430 _ => unreachable!(),
431 }
432 }
433
434 #[test]
435 #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: EmptyHost")]
436 fn test_generate_entity_no_host() {
437 let cli = Cli::parse_from([
438 "sea-orm-cli",
439 "generate",
440 "entity",
441 "--database-url",
442 "postgres://root:root@/database",
443 ]);
444
445 match cli.command {
446 Commands::Generate { command } => {
447 smol::block_on(run_generate_command(command, cli.verbose)).unwrap();
448 }
449 _ => unreachable!(),
450 }
451 }
452}