1use crate::{BannerVersion, BigIntegerType, DateTimeCrate, GenerateSubcommands};
2use core::time;
3use sea_orm_codegen::{
4 BannerVersion as CodegenBannerVersion, BigIntegerType as CodegenBigIntegerType,
5 DateTimeCrate as CodegenDateTimeCrate, EntityFormat, EntityTransformer, EntityWriterContext,
6 MergeReport, OutputFile, WithPrelude, WithSerde, merge_entity_files,
7};
8use std::{error::Error, fs, path::Path, process::Command, str::FromStr};
9use tracing_subscriber::{EnvFilter, prelude::*};
10use url::Url;
11
12pub async fn run_generate_command(
13 command: GenerateSubcommands,
14 verbose: bool,
15) -> Result<(), Box<dyn Error>> {
16 match command {
17 GenerateSubcommands::Entity {
18 entity_format,
19 compact_format: _,
20 expanded_format,
21 frontend_format,
22 include_hidden_tables,
23 tables,
24 ignore_tables,
25 max_connections,
26 acquire_timeout,
27 output_dir,
28 database_schema,
29 database_url,
30 with_prelude,
31 with_serde,
32 serde_skip_deserializing_primary_key,
33 serde_skip_hidden_column,
34 with_copy_enums,
35 date_time_crate,
36 big_integer_type,
37 lib,
38 model_extra_derives,
39 model_extra_attributes,
40 enum_extra_derives,
41 enum_extra_attributes,
42 column_extra_derives,
43 seaography,
44 impl_active_model_behavior,
45 preserve_user_modifications,
46 banner_version,
47 er_diagram,
48 } => {
49 if verbose {
50 let _ = tracing_subscriber::fmt()
51 .with_max_level(tracing::Level::DEBUG)
52 .with_test_writer()
53 .try_init();
54 } else {
55 let filter_layer = EnvFilter::try_new("sea_orm_codegen=info").unwrap();
56 let fmt_layer = tracing_subscriber::fmt::layer()
57 .with_target(false)
58 .with_level(false)
59 .without_time();
60
61 let _ = tracing_subscriber::registry()
62 .with(filter_layer)
63 .with(fmt_layer)
64 .try_init();
65 }
66
67 let url = Url::parse(&database_url)?;
70
71 let is_sqlite = url.scheme() == "sqlite";
76
77 let filter_tables =
79 |table: &String| -> bool { tables.is_empty() || tables.contains(table) };
80
81 let filter_hidden_tables = |table: &str| -> bool {
82 if include_hidden_tables {
83 true
84 } else {
85 !table.starts_with('_')
86 }
87 };
88
89 let filter_skip_tables = |table: &String| -> bool { !ignore_tables.contains(table) };
90
91 let _database_name = if !is_sqlite {
92 let database_name = url
98 .path_segments()
99 .unwrap_or_else(|| {
100 panic!(
101 "There is no database name as part of the url path: {}",
102 url.as_str()
103 )
104 })
105 .next()
106 .unwrap();
107
108 if database_name.is_empty() {
110 panic!(
111 "There is no database name as part of the url path: {}",
112 url.as_str()
113 );
114 }
115
116 database_name
117 } else {
118 Default::default()
119 };
120
121 let (schema_name, table_stmts) = match url.scheme() {
122 "mysql" => {
123 #[cfg(not(feature = "sqlx-mysql"))]
124 {
125 panic!("mysql feature is off")
126 }
127 #[cfg(feature = "sqlx-mysql")]
128 {
129 use sea_schema::mysql::discovery::SchemaDiscovery;
130 use sqlx::MySql;
131
132 println!("Connecting to MySQL ...");
133 let connection = sqlx_connect::<MySql>(
134 max_connections,
135 acquire_timeout,
136 url.as_str(),
137 None,
138 )
139 .await?;
140 println!("Discovering schema ...");
141 let schema_discovery = SchemaDiscovery::new(connection, _database_name);
142 let schema = schema_discovery.discover().await?;
143 let table_stmts = schema
144 .tables
145 .into_iter()
146 .filter(|schema| filter_tables(&schema.info.name))
147 .filter(|schema| filter_hidden_tables(&schema.info.name))
148 .filter(|schema| filter_skip_tables(&schema.info.name))
149 .map(|schema| schema.write())
150 .collect();
151 (None, table_stmts)
152 }
153 }
154 "sqlite" => {
155 #[cfg(not(feature = "sqlx-sqlite"))]
156 {
157 panic!("sqlite feature is off")
158 }
159 #[cfg(feature = "sqlx-sqlite")]
160 {
161 use sea_schema::sqlite::discovery::SchemaDiscovery;
162 use sqlx::Sqlite;
163
164 println!("Connecting to SQLite ...");
165 let connection = sqlx_connect::<Sqlite>(
166 max_connections,
167 acquire_timeout,
168 url.as_str(),
169 None,
170 )
171 .await?;
172 println!("Discovering schema ...");
173 let schema_discovery = SchemaDiscovery::new(connection);
174 let schema = schema_discovery
175 .discover()
176 .await?
177 .merge_indexes_into_table();
178 let table_stmts = schema
179 .tables
180 .into_iter()
181 .filter(|schema| filter_tables(&schema.name))
182 .filter(|schema| filter_hidden_tables(&schema.name))
183 .filter(|schema| filter_skip_tables(&schema.name))
184 .map(|schema| schema.write())
185 .collect();
186 (None, table_stmts)
187 }
188 }
189 "postgres" | "postgresql" => {
190 #[cfg(not(feature = "sqlx-postgres"))]
191 {
192 panic!("postgres feature is off")
193 }
194 #[cfg(feature = "sqlx-postgres")]
195 {
196 use sea_schema::postgres::discovery::SchemaDiscovery;
197 use sqlx::Postgres;
198
199 println!("Connecting to Postgres ...");
200 let schema = database_schema.as_deref().unwrap_or("public");
201 let connection = sqlx_connect::<Postgres>(
202 max_connections,
203 acquire_timeout,
204 url.as_str(),
205 Some(schema),
206 )
207 .await?;
208 println!("Discovering schema ...");
209 let schema_discovery = SchemaDiscovery::new(connection, schema);
210 let schema = schema_discovery.discover().await?;
211 let table_stmts = schema
212 .tables
213 .into_iter()
214 .filter(|schema| filter_tables(&schema.info.name))
215 .filter(|schema| filter_hidden_tables(&schema.info.name))
216 .filter(|schema| filter_skip_tables(&schema.info.name))
217 .map(|schema| schema.write())
218 .collect();
219 (database_schema, table_stmts)
220 }
221 }
222 _ => unimplemented!("{} is not supported", url.scheme()),
223 };
224 println!("... discovered.");
225
226 let writer_context = EntityWriterContext::new(
227 if expanded_format {
228 EntityFormat::Expanded
229 } else if frontend_format {
230 EntityFormat::Frontend
231 } else if let Some(entity_format) = entity_format {
232 EntityFormat::from_str(&entity_format).expect("Invalid entity-format option")
233 } else {
234 EntityFormat::default()
235 },
236 WithPrelude::from_str(&with_prelude).expect("Invalid prelude option"),
237 WithSerde::from_str(&with_serde).expect("Invalid serde derive option"),
238 with_copy_enums,
239 date_time_crate.into(),
240 big_integer_type.into(),
241 schema_name,
242 lib,
243 serde_skip_deserializing_primary_key,
244 serde_skip_hidden_column,
245 model_extra_derives,
246 model_extra_attributes,
247 enum_extra_derives,
248 enum_extra_attributes,
249 column_extra_derives,
250 seaography,
251 impl_active_model_behavior,
252 banner_version.into(),
253 );
254 let entity_writer = EntityTransformer::transform(table_stmts)?;
255
256 let dir = Path::new(&output_dir);
257 fs::create_dir_all(dir)?;
258
259 if er_diagram {
260 let diagram = entity_writer.generate_er_diagram();
261 let diagram_path = dir.join("entities.mermaid");
262 fs::write(&diagram_path, &diagram)?;
263 println!("Writing {}", diagram_path.display());
264 }
265
266 let output = entity_writer.generate(&writer_context);
267
268 let mut merge_fallback_files: Vec<String> = Vec::new();
269
270 for OutputFile { name, content } in output.files.iter() {
271 let file_path = dir.join(name);
272 println!("Writing {}", file_path.display());
273
274 if !matches!(
275 name.as_str(),
276 "mod.rs" | "lib.rs" | "prelude.rs" | "sea_orm_active_enums.rs"
277 ) && file_path.exists()
278 && preserve_user_modifications
279 {
280 let prev_content = fs::read_to_string(&file_path)?;
281 match merge_entity_files(&prev_content, content) {
282 Ok(merged) => {
283 fs::write(file_path, merged)?;
284 }
285 Err(MergeReport {
286 output,
287 warnings,
288 fallback_applied,
289 }) => {
290 for message in warnings {
291 eprintln!("{message}");
292 }
293 fs::write(file_path, output)?;
294 if fallback_applied {
295 merge_fallback_files.push(name.clone());
296 }
297 }
298 }
299 } else {
300 fs::write(file_path, content)?;
301 };
302 }
303
304 for OutputFile { name, .. } in output.files.iter() {
306 let exit_status = Command::new("rustfmt").arg(dir.join(name)).status()?; if !exit_status.success() {
308 return Err(format!("Fail to format file `{name}`").into());
310 }
311 }
312
313 if merge_fallback_files.is_empty() {
314 println!("... Done.");
315 } else {
316 return Err(format!(
317 "Merge fallback applied for {} file(s): \n{}",
318 merge_fallback_files.len(),
319 merge_fallback_files.join("\n")
320 )
321 .into());
322 }
323 }
324 }
325
326 Ok(())
327}
328
329async fn sqlx_connect<DB>(
330 max_connections: u32,
331 acquire_timeout: u64,
332 url: &str,
333 schema: Option<&str>,
334) -> Result<sqlx::Pool<DB>, Box<dyn Error>>
335where
336 DB: sqlx::Database,
337 for<'a> &'a mut <DB as sqlx::Database>::Connection: sqlx::Executor<'a>,
338{
339 let mut pool_options = sqlx::pool::PoolOptions::<DB>::new()
340 .max_connections(max_connections)
341 .acquire_timeout(time::Duration::from_secs(acquire_timeout));
342 if let Some(schema) = schema {
345 let sql = format!("SET search_path = '{schema}'");
346 pool_options = pool_options.after_connect(move |conn, _| {
347 let sql = sql.clone();
348 Box::pin(async move {
349 sqlx::Executor::execute(conn, sql.as_str())
350 .await
351 .map(|_| ())
352 })
353 });
354 }
355 pool_options.connect(url).await.map_err(Into::into)
356}
357
358impl From<DateTimeCrate> for CodegenDateTimeCrate {
359 fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate {
360 match date_time_crate {
361 DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono,
362 DateTimeCrate::Time => CodegenDateTimeCrate::Time,
363 }
364 }
365}
366
367impl From<BigIntegerType> for CodegenBigIntegerType {
368 fn from(date_time_crate: BigIntegerType) -> CodegenBigIntegerType {
369 match date_time_crate {
370 BigIntegerType::I64 => CodegenBigIntegerType::I64,
371 BigIntegerType::I32 => CodegenBigIntegerType::I32,
372 }
373 }
374}
375
376impl From<BannerVersion> for CodegenBannerVersion {
377 fn from(banner_version: BannerVersion) -> CodegenBannerVersion {
378 match banner_version {
379 BannerVersion::Off => CodegenBannerVersion::Off,
380 BannerVersion::Major => CodegenBannerVersion::Major,
381 BannerVersion::Minor => CodegenBannerVersion::Minor,
382 BannerVersion::Patch => CodegenBannerVersion::Patch,
383 }
384 }
385}
386
#[cfg(test)]
mod tests {
    use clap::Parser;

    use super::*;
    use crate::{Cli, Commands};

    /// Parses `sea-orm-cli generate entity --database-url <database_url>` and
    /// runs the resulting command to completion, unwrapping its result so that
    /// both returned errors and panics surface as test panics.
    fn generate_entity_from(database_url: &str) {
        let cli = Cli::parse_from([
            "sea-orm-cli",
            "generate",
            "entity",
            "--database-url",
            database_url,
        ]);

        match cli.command {
            Commands::Generate { command } => {
                smol::block_on(run_generate_command(command, cli.verbose)).unwrap();
            }
            _ => unreachable!(),
        }
    }

    /// A URL without a scheme is rejected by the URL parser.
    #[test]
    #[should_panic(
        expected = "called `Result::unwrap()` on an `Err` value: RelativeUrlWithoutBase"
    )]
    fn test_generate_entity_no_protocol() {
        generate_entity_from("://root:root@localhost:3306/database");
    }

    /// A URL with no path at all carries no database name.
    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: postgresql://root:root@localhost:3306"
    )]
    fn test_generate_entity_no_database_section() {
        generate_entity_from("postgresql://root:root@localhost:3306");
    }

    /// A URL whose path is just `/` carries an empty database name.
    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: mysql://root:root@localhost:3306/"
    )]
    fn test_generate_entity_no_database_path() {
        generate_entity_from("mysql://root:root@localhost:3306/");
    }

    /// A URL without a host is rejected by the URL parser.
    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: EmptyHost")]
    fn test_generate_entity_no_host() {
        generate_entity_from("postgres://root:root@/database");
    }
}