1use crate::{BannerVersion, BigIntegerType, DateTimeCrate, GenerateSubcommands};
2use core::time;
3use sea_orm_codegen::{
4 BannerVersion as CodegenBannerVersion, BigIntegerType as CodegenBigIntegerType,
5 DateTimeCrate as CodegenDateTimeCrate, EntityFormat, EntityTransformer, EntityWriterContext,
6 MergeReport, OutputFile, WithPrelude, WithSerde, merge_entity_files,
7};
8use std::{error::Error, fs, path::Path, process::Command, str::FromStr};
9use tracing_subscriber::{EnvFilter, prelude::*};
10use url::Url;
11
/// Splits `s` on every comma that is not nested inside `()`, `[]` or `{}`.
///
/// Each resulting piece is trimmed of surrounding whitespace and pieces that are empty
/// after trimming are discarded. Every bracket kind keeps its own nesting counter, and a
/// stray closing bracket never pushes its counter below zero.
fn split_by_comma_ignoring_parentheses(s: &str) -> Vec<String> {
    /// Pushes the trimmed `piece` onto `pieces` unless it is blank.
    fn flush(piece: &str, pieces: &mut Vec<String>) {
        let trimmed = piece.trim();
        if !trimmed.is_empty() {
            pieces.push(trimmed.to_owned());
        }
    }

    // Nesting depth per bracket kind: [parentheses, square brackets, braces].
    let mut depths = [0usize; 3];
    let mut pieces: Vec<String> = Vec::new();
    let mut piece = String::new();

    for ch in s.chars() {
        let at_top_level = depths.iter().all(|&depth| depth == 0);
        match ch {
            // A separator: close the current piece and drop the comma itself.
            ',' if at_top_level => {
                flush(&piece, &mut pieces);
                piece.clear();
                continue;
            }
            '(' => depths[0] += 1,
            '[' => depths[1] += 1,
            '{' => depths[2] += 1,
            ')' => depths[0] = depths[0].saturating_sub(1),
            ']' => depths[1] = depths[1].saturating_sub(1),
            '}' => depths[2] = depths[2].saturating_sub(1),
            _ => {}
        }
        piece.push(ch);
    }

    flush(&piece, &mut pieces);
    pieces
}
69
70fn process_comma_separated_values(values: Vec<String>) -> Vec<String> {
74 values
75 .into_iter()
76 .flat_map(|s| split_by_comma_ignoring_parentheses(&s))
77 .collect()
78}
79
80pub async fn run_generate_command(
81 command: GenerateSubcommands,
82 verbose: bool,
83) -> Result<(), Box<dyn Error>> {
84 match command {
85 GenerateSubcommands::Entity {
86 entity_format,
87 compact_format: _,
88 expanded_format,
89 frontend_format,
90 include_hidden_tables,
91 tables,
92 ignore_tables,
93 max_connections,
94 acquire_timeout,
95 output_dir,
96 database_schema,
97 database_url,
98 with_prelude,
99 with_serde,
100 serde_skip_deserializing_primary_key,
101 serde_skip_hidden_column,
102 with_copy_enums,
103 date_time_crate,
104 big_integer_type,
105 lib,
106 model_extra_derives,
107 model_extra_attributes,
108 enum_extra_derives,
109 enum_extra_attributes,
110 column_extra_derives,
111 seaography,
112 impl_active_model_behavior,
113 preserve_user_modifications,
114 banner_version,
115 er_diagram,
116 } => {
117 if verbose {
118 let _ = tracing_subscriber::fmt()
119 .with_max_level(tracing::Level::DEBUG)
120 .with_test_writer()
121 .try_init();
122 } else {
123 let filter_layer = EnvFilter::try_new("sea_orm_codegen=info").unwrap();
124 let fmt_layer = tracing_subscriber::fmt::layer()
125 .with_target(false)
126 .with_level(false)
127 .without_time();
128
129 let _ = tracing_subscriber::registry()
130 .with(filter_layer)
131 .with(fmt_layer)
132 .try_init();
133 }
134
135 let url = Url::parse(&database_url)?;
138
139 let is_sqlite = url.scheme() == "sqlite";
144
145 let filter_tables =
147 |table: &String| -> bool { tables.is_empty() || tables.contains(table) };
148
149 let filter_hidden_tables = |table: &str| -> bool {
150 if include_hidden_tables {
151 true
152 } else {
153 !table.starts_with('_')
154 }
155 };
156
157 let filter_skip_tables = |table: &String| -> bool { !ignore_tables.contains(table) };
158
159 let _database_name = if !is_sqlite {
160 let database_name = url
166 .path_segments()
167 .unwrap_or_else(|| {
168 panic!(
169 "There is no database name as part of the url path: {}",
170 url.as_str()
171 )
172 })
173 .next()
174 .unwrap();
175
176 if database_name.is_empty() {
178 panic!(
179 "There is no database name as part of the url path: {}",
180 url.as_str()
181 );
182 }
183
184 database_name
185 } else {
186 Default::default()
187 };
188
189 let (schema_name, table_stmts) = match url.scheme() {
190 "mysql" => {
191 #[cfg(not(feature = "sqlx-mysql"))]
192 {
193 panic!("mysql feature is off")
194 }
195 #[cfg(feature = "sqlx-mysql")]
196 {
197 use sea_schema::mysql::discovery::SchemaDiscovery;
198 use sqlx::MySql;
199
200 println!("Connecting to MySQL ...");
201 let connection = sqlx_connect::<MySql>(
202 max_connections,
203 acquire_timeout,
204 url.as_str(),
205 None,
206 )
207 .await?;
208 println!("Discovering schema ...");
209 let schema_discovery = SchemaDiscovery::new(connection, _database_name);
210 let schema = schema_discovery.discover().await?;
211 let table_stmts = schema
212 .tables
213 .into_iter()
214 .filter(|schema| filter_tables(&schema.info.name))
215 .filter(|schema| filter_hidden_tables(&schema.info.name))
216 .filter(|schema| filter_skip_tables(&schema.info.name))
217 .map(|schema| schema.write())
218 .collect();
219 (None, table_stmts)
220 }
221 }
222 "sqlite" => {
223 #[cfg(not(feature = "sqlx-sqlite"))]
224 {
225 panic!("sqlite feature is off")
226 }
227 #[cfg(feature = "sqlx-sqlite")]
228 {
229 use sea_schema::sqlite::discovery::SchemaDiscovery;
230 use sqlx::Sqlite;
231
232 println!("Connecting to SQLite ...");
233 let connection = sqlx_connect::<Sqlite>(
234 max_connections,
235 acquire_timeout,
236 url.as_str(),
237 None,
238 )
239 .await?;
240 println!("Discovering schema ...");
241 let schema_discovery = SchemaDiscovery::new(connection);
242 let schema = schema_discovery
243 .discover()
244 .await?
245 .merge_indexes_into_table();
246 let table_stmts = schema
247 .tables
248 .into_iter()
249 .filter(|schema| filter_tables(&schema.name))
250 .filter(|schema| filter_hidden_tables(&schema.name))
251 .filter(|schema| filter_skip_tables(&schema.name))
252 .map(|schema| schema.write())
253 .collect();
254 (None, table_stmts)
255 }
256 }
257 "postgres" | "postgresql" => {
258 #[cfg(not(feature = "sqlx-postgres"))]
259 {
260 panic!("postgres feature is off")
261 }
262 #[cfg(feature = "sqlx-postgres")]
263 {
264 use sea_schema::postgres::discovery::SchemaDiscovery;
265 use sqlx::Postgres;
266
267 println!("Connecting to Postgres ...");
268 let schema = database_schema.as_deref().unwrap_or("public");
269 let connection = sqlx_connect::<Postgres>(
270 max_connections,
271 acquire_timeout,
272 url.as_str(),
273 Some(schema),
274 )
275 .await?;
276 println!("Discovering schema ...");
277 let schema_discovery = SchemaDiscovery::new(connection, schema);
278 let schema = schema_discovery.discover().await?;
279 let table_stmts = schema
280 .tables
281 .into_iter()
282 .filter(|schema| filter_tables(&schema.info.name))
283 .filter(|schema| filter_hidden_tables(&schema.info.name))
284 .filter(|schema| filter_skip_tables(&schema.info.name))
285 .map(|schema| schema.write())
286 .collect();
287 (database_schema, table_stmts)
288 }
289 }
290 _ => unimplemented!("{} is not supported", url.scheme()),
291 };
292 println!("... discovered.");
293
294 let model_extra_derives = process_comma_separated_values(model_extra_derives);
298 let model_extra_attributes = process_comma_separated_values(model_extra_attributes);
299 let enum_extra_derives = process_comma_separated_values(enum_extra_derives);
300 let enum_extra_attributes = process_comma_separated_values(enum_extra_attributes);
301 let column_extra_derives = process_comma_separated_values(column_extra_derives);
302
303 let writer_context = EntityWriterContext::new(
304 if expanded_format {
305 EntityFormat::Expanded
306 } else if frontend_format {
307 EntityFormat::Frontend
308 } else if let Some(entity_format) = entity_format {
309 EntityFormat::from_str(&entity_format).expect("Invalid entity-format option")
310 } else {
311 EntityFormat::default()
312 },
313 WithPrelude::from_str(&with_prelude).expect("Invalid prelude option"),
314 WithSerde::from_str(&with_serde).expect("Invalid serde derive option"),
315 with_copy_enums,
316 date_time_crate.into(),
317 big_integer_type.into(),
318 schema_name,
319 lib,
320 serde_skip_deserializing_primary_key,
321 serde_skip_hidden_column,
322 model_extra_derives,
323 model_extra_attributes,
324 enum_extra_derives,
325 enum_extra_attributes,
326 column_extra_derives,
327 seaography,
328 impl_active_model_behavior,
329 banner_version.into(),
330 );
331 let entity_writer = EntityTransformer::transform(table_stmts)?;
332
333 let dir = Path::new(&output_dir);
334 fs::create_dir_all(dir)?;
335
336 if er_diagram {
337 let diagram = entity_writer.generate_er_diagram();
338 let diagram_path = dir.join("entities.mermaid");
339 fs::write(&diagram_path, &diagram)?;
340 println!("Writing {}", diagram_path.display());
341 }
342
343 let output = entity_writer.generate(&writer_context);
344
345 let mut merge_fallback_files: Vec<String> = Vec::new();
346
347 for OutputFile { name, content } in output.files.iter() {
348 let file_path = dir.join(name);
349 println!("Writing {}", file_path.display());
350
351 if !matches!(
352 name.as_str(),
353 "mod.rs" | "lib.rs" | "prelude.rs" | "sea_orm_active_enums.rs"
354 ) && file_path.exists()
355 && preserve_user_modifications
356 {
357 let prev_content = fs::read_to_string(&file_path)?;
358 match merge_entity_files(&prev_content, content) {
359 Ok(merged) => {
360 fs::write(file_path, merged)?;
361 }
362 Err(MergeReport {
363 output,
364 warnings,
365 fallback_applied,
366 }) => {
367 for message in warnings {
368 eprintln!("{message}");
369 }
370 fs::write(file_path, output)?;
371 if fallback_applied {
372 merge_fallback_files.push(name.clone());
373 }
374 }
375 }
376 } else {
377 fs::write(file_path, content)?;
378 };
379 }
380
381 for OutputFile { name, .. } in output.files.iter() {
383 let exit_status = Command::new("rustfmt").arg(dir.join(name)).status()?; if !exit_status.success() {
385 return Err(format!("Fail to format file `{name}`").into());
387 }
388 }
389
390 if merge_fallback_files.is_empty() {
391 println!("... Done.");
392 } else {
393 return Err(format!(
394 "Merge fallback applied for {} file(s): \n{}",
395 merge_fallback_files.len(),
396 merge_fallback_files.join("\n")
397 )
398 .into());
399 }
400 }
401 }
402
403 Ok(())
404}
405
406async fn sqlx_connect<DB>(
407 max_connections: u32,
408 acquire_timeout: u64,
409 url: &str,
410 schema: Option<&str>,
411) -> Result<sqlx::Pool<DB>, Box<dyn Error>>
412where
413 DB: sqlx::Database,
414 for<'a> &'a mut <DB as sqlx::Database>::Connection: sqlx::Executor<'a>,
415{
416 let mut pool_options = sqlx::pool::PoolOptions::<DB>::new()
417 .max_connections(max_connections)
418 .acquire_timeout(time::Duration::from_secs(acquire_timeout));
419 if let Some(schema) = schema {
422 let sql = format!("SET search_path = '{schema}'");
423 pool_options = pool_options.after_connect(move |conn, _| {
424 let sql = sql.clone();
425 Box::pin(async move {
426 sqlx::Executor::execute(conn, sqlx::AssertSqlSafe(sql))
427 .await
428 .map(|_| ())
429 })
430 });
431 }
432 pool_options.connect(url).await.map_err(Into::into)
433}
434
435impl From<DateTimeCrate> for CodegenDateTimeCrate {
436 fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate {
437 match date_time_crate {
438 DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono,
439 DateTimeCrate::Time => CodegenDateTimeCrate::Time,
440 }
441 }
442}
443
444impl From<BigIntegerType> for CodegenBigIntegerType {
445 fn from(date_time_crate: BigIntegerType) -> CodegenBigIntegerType {
446 match date_time_crate {
447 BigIntegerType::I64 => CodegenBigIntegerType::I64,
448 BigIntegerType::I32 => CodegenBigIntegerType::I32,
449 }
450 }
451}
452
453impl From<BannerVersion> for CodegenBannerVersion {
454 fn from(banner_version: BannerVersion) -> CodegenBannerVersion {
455 match banner_version {
456 BannerVersion::Off => CodegenBannerVersion::Off,
457 BannerVersion::Major => CodegenBannerVersion::Major,
458 BannerVersion::Minor => CodegenBannerVersion::Minor,
459 BannerVersion::Patch => CodegenBannerVersion::Patch,
460 }
461 }
462}
463
#[cfg(test)]
mod tests {
    use clap::Parser;

    use super::*;
    use crate::{Cli, Commands};

    /// Runs `sea-orm-cli generate entity --database-url <database_url>` and unwraps the
    /// outcome, so any error returned by `run_generate_command` becomes a test panic.
    fn generate_entity_with_url(database_url: &str) {
        let cli = Cli::parse_from([
            "sea-orm-cli",
            "generate",
            "entity",
            "--database-url",
            database_url,
        ]);
        let verbose = cli.verbose;
        match cli.command {
            Commands::Generate { command } => {
                smol::block_on(run_generate_command(command, verbose)).unwrap();
            }
            _ => unreachable!(),
        }
    }

    #[test]
    #[should_panic(
        expected = "called `Result::unwrap()` on an `Err` value: RelativeUrlWithoutBase"
    )]
    fn test_generate_entity_no_protocol() {
        generate_entity_with_url("://root:root@localhost:3306/database");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: postgresql://root:root@localhost:3306"
    )]
    fn test_generate_entity_no_database_section() {
        generate_entity_with_url("postgresql://root:root@localhost:3306");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: mysql://root:root@localhost:3306/"
    )]
    fn test_generate_entity_no_database_path() {
        generate_entity_with_url("mysql://root:root@localhost:3306/");
    }

    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: EmptyHost")]
    fn test_generate_entity_no_host() {
        generate_entity_with_url("postgres://root:root@/database");
    }

    #[test]
    fn test_split_by_comma_simple() {
        assert_eq!(
            split_by_comma_ignoring_parentheses("a,b,c"),
            vec!["a", "b", "c"]
        );
    }

    #[test]
    fn test_split_by_comma_with_parentheses() {
        assert_eq!(
            split_by_comma_ignoring_parentheses("test(a, b)"),
            vec!["test(a, b)"]
        );
        assert_eq!(
            split_by_comma_ignoring_parentheses("attr1,test(a, b)"),
            vec!["attr1", "test(a, b)"]
        );
    }

    #[test]
    fn test_split_by_comma_with_nested_parentheses() {
        assert_eq!(
            split_by_comma_ignoring_parentheses("cfg_attr(debug_assertions, derive(Debug))"),
            vec!["cfg_attr(debug_assertions, derive(Debug))"]
        );
        assert_eq!(
            split_by_comma_ignoring_parentheses(
                "cfg_attr(feature1, attr(a, b)),cfg_attr(feature2, attr(c, d))",
            ),
            vec![
                "cfg_attr(feature1, attr(a, b))",
                "cfg_attr(feature2, attr(c, d))"
            ]
        );
    }

    #[test]
    fn test_split_by_comma_with_brackets() {
        assert_eq!(
            split_by_comma_ignoring_parentheses("serde(rename_all = \"camelCase\"),ts(export)"),
            vec!["serde(rename_all = \"camelCase\")", "ts(export)"]
        );
        assert_eq!(
            split_by_comma_ignoring_parentheses("attr[key, value],other"),
            vec!["attr[key, value]", "other"]
        );
    }

    #[test]
    fn test_split_by_comma_with_braces() {
        assert_eq!(
            split_by_comma_ignoring_parentheses("derive{a, b},other"),
            vec!["derive{a, b}", "other"]
        );
    }

    #[test]
    fn test_split_by_comma_empty() {
        assert!(split_by_comma_ignoring_parentheses("").is_empty());
        assert!(split_by_comma_ignoring_parentheses(" ").is_empty());
    }

    #[test]
    fn test_split_by_comma_whitespace_handling() {
        assert_eq!(
            split_by_comma_ignoring_parentheses(" a , b "),
            vec!["a", "b"]
        );
        assert_eq!(
            split_by_comma_ignoring_parentheses("test( a , b )"),
            vec!["test( a , b )"]
        );
    }

    #[test]
    fn test_process_comma_separated_values() {
        let input: Vec<String> = ["attr1,attr2", "test(a, b)", "attr3"]
            .iter()
            .map(|value| value.to_string())
            .collect();
        assert_eq!(
            process_comma_separated_values(input),
            vec!["attr1", "attr2", "test(a, b)", "attr3"]
        );
    }

    #[test]
    fn test_split_by_comma_real_world_examples() {
        assert_eq!(
            split_by_comma_ignoring_parentheses(
                "cfg_attr(debug_assertions, derive(Debug)),serde(rename_all = \"camelCase\")",
            ),
            vec![
                "cfg_attr(debug_assertions, derive(Debug))",
                "serde(rename_all = \"camelCase\")"
            ]
        );
        assert_eq!(
            split_by_comma_ignoring_parentheses(
                "derive(Debug, Clone),derive(Serialize, Deserialize)",
            ),
            vec!["derive(Debug, Clone)", "derive(Serialize, Deserialize)"]
        );
    }
}