Skip to main content

sea_orm_cli/commands/
generate.rs

1use crate::{BannerVersion, BigIntegerType, DateTimeCrate, GenerateSubcommands};
2use core::time;
3use sea_orm_codegen::{
4    BannerVersion as CodegenBannerVersion, BigIntegerType as CodegenBigIntegerType,
5    DateTimeCrate as CodegenDateTimeCrate, EntityFormat, EntityTransformer, EntityWriterContext,
6    MergeReport, OutputFile, WithPrelude, WithSerde, merge_entity_files,
7};
8use std::{error::Error, fs, path::Path, process::Command, str::FromStr};
9use tracing_subscriber::{EnvFilter, prelude::*};
10use url::Url;
11
/// Split a string by comma while respecting bracket nesting and string literals.
///
/// A comma only acts as a separator at the top level, i.e. when it is not
/// inside `(...)`, `[...]`, `{...}` or a double-quoted string literal. This
/// allows attributes like `test(a, b)` or `doc = "a, b"` to be treated as a
/// single value instead of being split into `test(a` and ` b)`.
///
/// Each resulting segment is trimmed of surrounding whitespace and empty
/// segments are dropped. Unbalanced closing brackets are tolerated (each depth
/// counter saturates at zero). Inside a string literal a backslash escapes the
/// next character, so `\"` does not end the literal; an unterminated literal
/// extends to the end of the input.
fn split_by_comma_ignoring_parentheses(s: &str) -> Vec<String> {
    let mut result = Vec::new();
    let mut current = String::new();
    let mut paren_depth = 0usize;
    let mut bracket_depth = 0usize;
    let mut brace_depth = 0usize;
    // While inside `"..."`, brackets and commas are plain text.
    let mut in_string = false;
    // Set right after a backslash inside a string literal.
    let mut escaped = false;

    for c in s.chars() {
        if in_string {
            current.push(c);
            if escaped {
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '"' {
                in_string = false;
            }
            continue;
        }

        match c {
            '"' => in_string = true,
            '(' => paren_depth += 1,
            ')' => paren_depth = paren_depth.saturating_sub(1),
            '[' => bracket_depth += 1,
            ']' => bracket_depth = bracket_depth.saturating_sub(1),
            '{' => brace_depth += 1,
            '}' => brace_depth = brace_depth.saturating_sub(1),
            ',' if paren_depth == 0 && bracket_depth == 0 && brace_depth == 0 => {
                // Top-level separator: flush the current segment, drop the comma.
                let segment = current.trim();
                if !segment.is_empty() {
                    result.push(segment.to_string());
                }
                current.clear();
                continue;
            }
            _ => {}
        }
        current.push(c);
    }

    // Add the last segment
    let segment = current.trim();
    if !segment.is_empty() {
        result.push(segment.to_string());
    }

    result
}
69
70/// Process a vector of strings that may contain comma-separated values with nested parentheses.
71/// This handles the case where clap no longer splits by comma, so we need to manually split
72/// each string while respecting parentheses nesting.
73fn process_comma_separated_values(values: Vec<String>) -> Vec<String> {
74    values
75        .into_iter()
76        .flat_map(|s| split_by_comma_ignoring_parentheses(&s))
77        .collect()
78}
79
/// Run a `sea-orm-cli generate` subcommand.
///
/// For `GenerateSubcommands::Entity` this:
/// 1. installs a tracing subscriber (DEBUG level when `verbose`, otherwise an
///    `sea_orm_codegen=info` filter without target, level or timestamps);
/// 2. parses `database_url` and, for non-SQLite URLs, requires a database name
///    as the first path segment;
/// 3. connects to MySQL, SQLite or Postgres (each behind its `sqlx-*` feature)
///    and discovers the schema, keeping only tables selected by `tables`,
///    `ignore_tables` and `include_hidden_tables`;
/// 4. writes the generated entity files (and `entities.mermaid` when
///    `er_diagram` is set) into `output_dir`, merging with existing files when
///    `preserve_user_modifications` is set;
/// 5. runs `rustfmt` on every generated file.
///
/// # Errors
/// Returns an error if the URL cannot be parsed, connecting or schema
/// discovery fails, entity transformation fails, a file or directory cannot be
/// read/written, `rustfmt` cannot be spawned or exits unsuccessfully, or a
/// merge fallback was applied to at least one file.
///
/// # Panics
/// Panics when a non-SQLite URL has no database name, when the database
/// feature for the URL scheme is disabled or the scheme is unsupported, and
/// when the `entity_format`, `with_prelude` or `with_serde` strings are invalid.
// NOTE(review): several of the panics above are triggered by user input and
// could be returned as `Err` instead; the unit tests currently match on the
// panic messages, so any change must keep those messages.
pub async fn run_generate_command(
    command: GenerateSubcommands,
    verbose: bool,
) -> Result<(), Box<dyn Error>> {
    match command {
        GenerateSubcommands::Entity {
            entity_format,
            // Not read here; presumably Compact is `EntityFormat::default()` — confirm.
            compact_format: _,
            expanded_format,
            frontend_format,
            include_hidden_tables,
            tables,
            ignore_tables,
            max_connections,
            acquire_timeout,
            output_dir,
            database_schema,
            database_url,
            with_prelude,
            with_serde,
            serde_skip_deserializing_primary_key,
            serde_skip_hidden_column,
            with_copy_enums,
            date_time_crate,
            big_integer_type,
            lib,
            model_extra_derives,
            model_extra_attributes,
            enum_extra_derives,
            enum_extra_attributes,
            column_extra_derives,
            seaography,
            impl_active_model_behavior,
            preserve_user_modifications,
            banner_version,
            er_diagram,
        } => {
            // The `try_init` results are discarded (`let _`), so failing to
            // install a subscriber is not treated as an error.
            if verbose {
                let _ = tracing_subscriber::fmt()
                    .with_max_level(tracing::Level::DEBUG)
                    .with_test_writer()
                    .try_init();
            } else {
                let filter_layer = EnvFilter::try_new("sea_orm_codegen=info").unwrap();
                let fmt_layer = tracing_subscriber::fmt::layer()
                    .with_target(false)
                    .with_level(false)
                    .without_time();

                let _ = tracing_subscriber::registry()
                    .with(filter_layer)
                    .with(fmt_layer)
                    .try_init();
            }

            // The database should be a valid URL that can be parsed
            // protocol://username:password@host/database_name
            let url = Url::parse(&database_url)?;

            // Make sure we have all the required url components
            //
            // Missing scheme will have been caught by the Url::parse() call
            // above
            let is_sqlite = url.scheme() == "sqlite";

            // Closures for filtering tables
            // An empty `tables` list selects every table.
            let filter_tables =
                |table: &String| -> bool { tables.is_empty() || tables.contains(table) };

            // Tables whose name starts with `_` are skipped unless
            // `include_hidden_tables` is set.
            let filter_hidden_tables = |table: &str| -> bool {
                if include_hidden_tables {
                    true
                } else {
                    !table.starts_with('_')
                }
            };

            // Tables listed in `ignore_tables` are always skipped.
            let filter_skip_tables = |table: &String| -> bool { !ignore_tables.contains(table) };

            // Only the MySQL branch below reads this value; the leading
            // underscore presumably silences the unused-variable lint when
            // `sqlx-mysql` is disabled.
            let _database_name = if !is_sqlite {
                // The database name should be the first element of the path string
                //
                // Throwing an error if there is no database name since it might be
                // accepted by the database without it, while we're looking to dump
                // information from a particular database
                let database_name = url
                    .path_segments()
                    .unwrap_or_else(|| {
                        panic!(
                            "There is no database name as part of the url path: {}",
                            url.as_str()
                        )
                    })
                    .next()
                    .unwrap();

                // An empty string as the database name is also an error
                if database_name.is_empty() {
                    panic!(
                        "There is no database name as part of the url path: {}",
                        url.as_str()
                    );
                }

                database_name
            } else {
                Default::default()
            };

            // `schema_name` is `Some` only for Postgres when `--database-schema`
            // was given; it is forwarded to the writer context below.
            let (schema_name, table_stmts) = match url.scheme() {
                "mysql" => {
                    #[cfg(not(feature = "sqlx-mysql"))]
                    {
                        panic!("mysql feature is off")
                    }
                    #[cfg(feature = "sqlx-mysql")]
                    {
                        use sea_schema::mysql::discovery::SchemaDiscovery;
                        use sqlx::MySql;

                        println!("Connecting to MySQL ...");
                        let connection = sqlx_connect::<MySql>(
                            max_connections,
                            acquire_timeout,
                            url.as_str(),
                            None,
                        )
                        .await?;
                        println!("Discovering schema ...");
                        let schema_discovery = SchemaDiscovery::new(connection, _database_name);
                        let schema = schema_discovery.discover().await?;
                        let table_stmts = schema
                            .tables
                            .into_iter()
                            .filter(|schema| filter_tables(&schema.info.name))
                            .filter(|schema| filter_hidden_tables(&schema.info.name))
                            .filter(|schema| filter_skip_tables(&schema.info.name))
                            .map(|schema| schema.write())
                            .collect();
                        (None, table_stmts)
                    }
                }
                "sqlite" => {
                    #[cfg(not(feature = "sqlx-sqlite"))]
                    {
                        panic!("sqlite feature is off")
                    }
                    #[cfg(feature = "sqlx-sqlite")]
                    {
                        use sea_schema::sqlite::discovery::SchemaDiscovery;
                        use sqlx::Sqlite;

                        println!("Connecting to SQLite ...");
                        let connection = sqlx_connect::<Sqlite>(
                            max_connections,
                            acquire_timeout,
                            url.as_str(),
                            None,
                        )
                        .await?;
                        println!("Discovering schema ...");
                        let schema_discovery = SchemaDiscovery::new(connection);
                        let schema = schema_discovery
                            .discover()
                            .await?
                            .merge_indexes_into_table();
                        let table_stmts = schema
                            .tables
                            .into_iter()
                            .filter(|schema| filter_tables(&schema.name))
                            .filter(|schema| filter_hidden_tables(&schema.name))
                            .filter(|schema| filter_skip_tables(&schema.name))
                            .map(|schema| schema.write())
                            .collect();
                        (None, table_stmts)
                    }
                }
                "postgres" | "postgresql" => {
                    #[cfg(not(feature = "sqlx-postgres"))]
                    {
                        panic!("postgres feature is off")
                    }
                    #[cfg(feature = "sqlx-postgres")]
                    {
                        use sea_schema::postgres::discovery::SchemaDiscovery;
                        use sqlx::Postgres;

                        println!("Connecting to Postgres ...");
                        // Discover `public` unless another schema was requested.
                        let schema = database_schema.as_deref().unwrap_or("public");
                        let connection = sqlx_connect::<Postgres>(
                            max_connections,
                            acquire_timeout,
                            url.as_str(),
                            Some(schema),
                        )
                        .await?;
                        println!("Discovering schema ...");
                        let schema_discovery = SchemaDiscovery::new(connection, schema);
                        let schema = schema_discovery.discover().await?;
                        let table_stmts = schema
                            .tables
                            .into_iter()
                            .filter(|schema| filter_tables(&schema.info.name))
                            .filter(|schema| filter_hidden_tables(&schema.info.name))
                            .filter(|schema| filter_skip_tables(&schema.info.name))
                            .map(|schema| schema.write())
                            .collect();
                        (database_schema, table_stmts)
                    }
                }
                _ => unimplemented!("{} is not supported", url.scheme()),
            };
            println!("... discovered.");

            // Process extra derives and attributes, splitting by comma while respecting parentheses
            // This handles cases like `--model-extra-attributes 'cfg_attr(debug_assertions, derive(Debug))'`
            // which should be treated as a single attribute, not split into `cfg_attr(debug_assertions` and ` derive(Debug))`
            let model_extra_derives = process_comma_separated_values(model_extra_derives);
            let model_extra_attributes = process_comma_separated_values(model_extra_attributes);
            let enum_extra_derives = process_comma_separated_values(enum_extra_derives);
            let enum_extra_attributes = process_comma_separated_values(enum_extra_attributes);
            let column_extra_derives = process_comma_separated_values(column_extra_derives);

            // Format precedence: `--expanded-format`, then `--frontend-format`,
            // then `--entity-format <name>`, then the default format.
            let writer_context = EntityWriterContext::new(
                if expanded_format {
                    EntityFormat::Expanded
                } else if frontend_format {
                    EntityFormat::Frontend
                } else if let Some(entity_format) = entity_format {
                    EntityFormat::from_str(&entity_format).expect("Invalid entity-format option")
                } else {
                    EntityFormat::default()
                },
                WithPrelude::from_str(&with_prelude).expect("Invalid prelude option"),
                WithSerde::from_str(&with_serde).expect("Invalid serde derive option"),
                with_copy_enums,
                date_time_crate.into(),
                big_integer_type.into(),
                schema_name,
                lib,
                serde_skip_deserializing_primary_key,
                serde_skip_hidden_column,
                model_extra_derives,
                model_extra_attributes,
                enum_extra_derives,
                enum_extra_attributes,
                column_extra_derives,
                seaography,
                impl_active_model_behavior,
                banner_version.into(),
            );
            let entity_writer = EntityTransformer::transform(table_stmts)?;

            let dir = Path::new(&output_dir);
            fs::create_dir_all(dir)?;

            if er_diagram {
                let diagram = entity_writer.generate_er_diagram();
                let diagram_path = dir.join("entities.mermaid");
                fs::write(&diagram_path, &diagram)?;
                println!("Writing {}", diagram_path.display());
            }

            let output = entity_writer.generate(&writer_context);

            // Names of files whose merge had to fall back; reported as an
            // error once every file has been written and formatted.
            let mut merge_fallback_files: Vec<String> = Vec::new();

            for OutputFile { name, content } in output.files.iter() {
                let file_path = dir.join(name);
                println!("Writing {}", file_path.display());

                // The aggregate files listed here are always overwritten; other
                // existing files are merged with their previous content when
                // `preserve_user_modifications` is set.
                if !matches!(
                    name.as_str(),
                    "mod.rs" | "lib.rs" | "prelude.rs" | "sea_orm_active_enums.rs"
                ) && file_path.exists()
                    && preserve_user_modifications
                {
                    let prev_content = fs::read_to_string(&file_path)?;
                    match merge_entity_files(&prev_content, content) {
                        Ok(merged) => {
                            fs::write(file_path, merged)?;
                        }
                        // Even on failure the report's `output` is written; its
                        // warnings go to stderr.
                        Err(MergeReport {
                            output,
                            warnings,
                            fallback_applied,
                        }) => {
                            for message in warnings {
                                eprintln!("{message}");
                            }
                            fs::write(file_path, output)?;
                            if fallback_applied {
                                merge_fallback_files.push(name.clone());
                            }
                        }
                    }
                } else {
                    fs::write(file_path, content)?;
                };
            }

            // Format each of the files
            for OutputFile { name, .. } in output.files.iter() {
                let exit_status = Command::new("rustfmt").arg(dir.join(name)).status()?; // Get the status code
                if !exit_status.success() {
                    // Propagate the error if any
                    return Err(format!("Fail to format file `{name}`").into());
                }
            }

            if merge_fallback_files.is_empty() {
                println!("... Done.");
            } else {
                return Err(format!(
                    "Merge fallback applied for {} file(s): \n{}",
                    merge_fallback_files.len(),
                    merge_fallback_files.join("\n")
                )
                .into());
            }
        }
    }

    Ok(())
}
405
406async fn sqlx_connect<DB>(
407    max_connections: u32,
408    acquire_timeout: u64,
409    url: &str,
410    schema: Option<&str>,
411) -> Result<sqlx::Pool<DB>, Box<dyn Error>>
412where
413    DB: sqlx::Database,
414    for<'a> &'a mut <DB as sqlx::Database>::Connection: sqlx::Executor<'a>,
415{
416    let mut pool_options = sqlx::pool::PoolOptions::<DB>::new()
417        .max_connections(max_connections)
418        .acquire_timeout(time::Duration::from_secs(acquire_timeout));
419    // Set search_path for Postgres, E.g. Some("public") by default
420    // MySQL & SQLite connection initialize with schema `None`
421    if let Some(schema) = schema {
422        let sql = format!("SET search_path = '{schema}'");
423        pool_options = pool_options.after_connect(move |conn, _| {
424            let sql = sql.clone();
425            Box::pin(async move {
426                sqlx::Executor::execute(conn, sqlx::AssertSqlSafe(sql))
427                    .await
428                    .map(|_| ())
429            })
430        });
431    }
432    pool_options.connect(url).await.map_err(Into::into)
433}
434
435impl From<DateTimeCrate> for CodegenDateTimeCrate {
436    fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate {
437        match date_time_crate {
438            DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono,
439            DateTimeCrate::Time => CodegenDateTimeCrate::Time,
440        }
441    }
442}
443
444impl From<BigIntegerType> for CodegenBigIntegerType {
445    fn from(date_time_crate: BigIntegerType) -> CodegenBigIntegerType {
446        match date_time_crate {
447            BigIntegerType::I64 => CodegenBigIntegerType::I64,
448            BigIntegerType::I32 => CodegenBigIntegerType::I32,
449        }
450    }
451}
452
453impl From<BannerVersion> for CodegenBannerVersion {
454    fn from(banner_version: BannerVersion) -> CodegenBannerVersion {
455        match banner_version {
456            BannerVersion::Off => CodegenBannerVersion::Off,
457            BannerVersion::Major => CodegenBannerVersion::Major,
458            BannerVersion::Minor => CodegenBannerVersion::Minor,
459            BannerVersion::Patch => CodegenBannerVersion::Patch,
460        }
461    }
462}
463
#[cfg(test)]
mod tests {
    use clap::Parser;

    use super::*;
    use crate::{Cli, Commands};

    /// Parse `sea-orm-cli generate entity --database-url <database_url>` and
    /// run the resulting command on `smol`, unwrapping the result so that any
    /// error becomes a panic the `should_panic` tests below can match against.
    fn run_entity_generation(database_url: &str) {
        let cli = Cli::parse_from([
            "sea-orm-cli",
            "generate",
            "entity",
            "--database-url",
            database_url,
        ]);

        let Commands::Generate { command } = cli.command else {
            unreachable!()
        };
        smol::block_on(run_generate_command(command, cli.verbose)).unwrap();
    }

    /// Assert that splitting `input` yields exactly the `expected` segments.
    fn assert_split(input: &str, expected: &[&str]) {
        assert_eq!(split_by_comma_ignoring_parentheses(input), expected);
    }

    #[test]
    #[should_panic(
        expected = "called `Result::unwrap()` on an `Err` value: RelativeUrlWithoutBase"
    )]
    fn test_generate_entity_no_protocol() {
        run_entity_generation("://root:root@localhost:3306/database");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: postgresql://root:root@localhost:3306"
    )]
    fn test_generate_entity_no_database_section() {
        run_entity_generation("postgresql://root:root@localhost:3306");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: mysql://root:root@localhost:3306/"
    )]
    fn test_generate_entity_no_database_path() {
        run_entity_generation("mysql://root:root@localhost:3306/");
    }

    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: EmptyHost")]
    fn test_generate_entity_no_host() {
        run_entity_generation("postgres://root:root@/database");
    }

    #[test]
    fn test_split_by_comma_simple() {
        // Plain comma-separated values split on every comma.
        assert_split("a,b,c", &["a", "b", "c"]);
    }

    #[test]
    fn test_split_by_comma_with_parentheses() {
        // A comma inside parentheses is not a separator...
        assert_split("test(a, b)", &["test(a, b)"]);
        // ...even when the group sits next to other top-level values.
        assert_split("attr1,test(a, b)", &["attr1", "test(a, b)"]);
    }

    #[test]
    fn test_split_by_comma_with_nested_parentheses() {
        // Nesting depth is tracked, so inner groups stay intact.
        assert_split(
            "cfg_attr(debug_assertions, derive(Debug))",
            &["cfg_attr(debug_assertions, derive(Debug))"],
        );
        assert_split(
            "cfg_attr(feature1, attr(a, b)),cfg_attr(feature2, attr(c, d))",
            &[
                "cfg_attr(feature1, attr(a, b))",
                "cfg_attr(feature2, attr(c, d))",
            ],
        );
    }

    #[test]
    fn test_split_by_comma_with_brackets() {
        // Quoted values inside parentheses survive untouched.
        assert_split(
            "serde(rename_all = \"camelCase\"),ts(export)",
            &["serde(rename_all = \"camelCase\")", "ts(export)"],
        );
        // Square brackets shield their commas as well.
        assert_split("attr[key, value],other", &["attr[key, value]", "other"]);
    }

    #[test]
    fn test_split_by_comma_with_braces() {
        // Curly braces shield their commas too.
        assert_split("derive{a, b},other", &["derive{a, b}", "other"]);
    }

    #[test]
    fn test_split_by_comma_empty() {
        // Empty or whitespace-only input produces no segments.
        assert_split("", &[]);
        assert_split("   ", &[]);
    }

    #[test]
    fn test_split_by_comma_whitespace_handling() {
        // Outer whitespace is trimmed from each segment...
        assert_split("  a  ,  b  ", &["a", "b"]);
        // ...but whitespace inside a group is kept verbatim.
        assert_split("test( a , b )", &["test( a , b )"]);
    }

    #[test]
    fn test_process_comma_separated_values() {
        // Every input string is split on its own and the pieces are flattened in order.
        let input = vec![
            "attr1,attr2".to_string(),
            "test(a, b)".to_string(),
            "attr3".to_string(),
        ];
        let result = process_comma_separated_values(input);
        assert_eq!(result, vec!["attr1", "attr2", "test(a, b)", "attr3"]);
    }

    #[test]
    fn test_split_by_comma_real_world_examples() {
        // cfg_attr wrapping a derive, followed by a serde attribute.
        assert_split(
            "cfg_attr(debug_assertions, derive(Debug)),serde(rename_all = \"camelCase\")",
            &[
                "cfg_attr(debug_assertions, derive(Debug))",
                "serde(rename_all = \"camelCase\")",
            ],
        );
        // Two multi-item derive groups.
        assert_split(
            "derive(Debug, Clone),derive(Serialize, Deserialize)",
            &["derive(Debug, Clone)", "derive(Serialize, Deserialize)"],
        );
    }
}