// sea_orm_cli/commands/generate.rs

use crate::{BannerVersion, BigIntegerType, DateTimeCrate, GenerateSubcommands};
use core::time;
use sea_orm_codegen::{
    BannerVersion as CodegenBannerVersion, BigIntegerType as CodegenBigIntegerType,
    DateTimeCrate as CodegenDateTimeCrate, EntityFormat, EntityTransformer, EntityWriterContext,
    MergeReport, OutputFile, WithPrelude, WithSerde, merge_entity_files,
};
use std::{error::Error, fs, path::Path, process::Command, str::FromStr};
use tracing_subscriber::{EnvFilter, prelude::*};
use url::Url;

12pub async fn run_generate_command(
13    command: GenerateSubcommands,
14    verbose: bool,
15) -> Result<(), Box<dyn Error>> {
16    match command {
17        GenerateSubcommands::Entity {
18            entity_format,
19            compact_format: _,
20            expanded_format,
21            frontend_format,
22            include_hidden_tables,
23            tables,
24            ignore_tables,
25            max_connections,
26            acquire_timeout,
27            output_dir,
28            database_schema,
29            database_url,
30            with_prelude,
31            with_serde,
32            serde_skip_deserializing_primary_key,
33            serde_skip_hidden_column,
34            with_copy_enums,
35            date_time_crate,
36            big_integer_type,
37            lib,
38            model_extra_derives,
39            model_extra_attributes,
40            enum_extra_derives,
41            enum_extra_attributes,
42            column_extra_derives,
43            seaography,
44            impl_active_model_behavior,
45            preserve_user_modifications,
46            banner_version,
47            er_diagram,
48        } => {
49            if verbose {
50                let _ = tracing_subscriber::fmt()
51                    .with_max_level(tracing::Level::DEBUG)
52                    .with_test_writer()
53                    .try_init();
54            } else {
55                let filter_layer = EnvFilter::try_new("sea_orm_codegen=info").unwrap();
56                let fmt_layer = tracing_subscriber::fmt::layer()
57                    .with_target(false)
58                    .with_level(false)
59                    .without_time();
60
61                let _ = tracing_subscriber::registry()
62                    .with(filter_layer)
63                    .with(fmt_layer)
64                    .try_init();
65            }
66
67            // The database should be a valid URL that can be parsed
68            // protocol://username:password@host/database_name
69            let url = Url::parse(&database_url)?;
70
71            // Make sure we have all the required url components
72            //
73            // Missing scheme will have been caught by the Url::parse() call
74            // above
75            let is_sqlite = url.scheme() == "sqlite";
76
77            // Closures for filtering tables
78            let filter_tables =
79                |table: &String| -> bool { tables.is_empty() || tables.contains(table) };
80
81            let filter_hidden_tables = |table: &str| -> bool {
82                if include_hidden_tables {
83                    true
84                } else {
85                    !table.starts_with('_')
86                }
87            };
88
89            let filter_skip_tables = |table: &String| -> bool { !ignore_tables.contains(table) };
90
91            let _database_name = if !is_sqlite {
92                // The database name should be the first element of the path string
93                //
94                // Throwing an error if there is no database name since it might be
95                // accepted by the database without it, while we're looking to dump
96                // information from a particular database
97                let database_name = url
98                    .path_segments()
99                    .unwrap_or_else(|| {
100                        panic!(
101                            "There is no database name as part of the url path: {}",
102                            url.as_str()
103                        )
104                    })
105                    .next()
106                    .unwrap();
107
108                // An empty string as the database name is also an error
109                if database_name.is_empty() {
110                    panic!(
111                        "There is no database name as part of the url path: {}",
112                        url.as_str()
113                    );
114                }
115
116                database_name
117            } else {
118                Default::default()
119            };
120
121            let (schema_name, table_stmts) = match url.scheme() {
122                "mysql" => {
123                    #[cfg(not(feature = "sqlx-mysql"))]
124                    {
125                        panic!("mysql feature is off")
126                    }
127                    #[cfg(feature = "sqlx-mysql")]
128                    {
129                        use sea_schema::mysql::discovery::SchemaDiscovery;
130                        use sqlx::MySql;
131
132                        println!("Connecting to MySQL ...");
133                        let connection = sqlx_connect::<MySql>(
134                            max_connections,
135                            acquire_timeout,
136                            url.as_str(),
137                            None,
138                        )
139                        .await?;
140                        println!("Discovering schema ...");
141                        let schema_discovery = SchemaDiscovery::new(connection, _database_name);
142                        let schema = schema_discovery.discover().await?;
143                        let table_stmts = schema
144                            .tables
145                            .into_iter()
146                            .filter(|schema| filter_tables(&schema.info.name))
147                            .filter(|schema| filter_hidden_tables(&schema.info.name))
148                            .filter(|schema| filter_skip_tables(&schema.info.name))
149                            .map(|schema| schema.write())
150                            .collect();
151                        (None, table_stmts)
152                    }
153                }
154                "sqlite" => {
155                    #[cfg(not(feature = "sqlx-sqlite"))]
156                    {
157                        panic!("sqlite feature is off")
158                    }
159                    #[cfg(feature = "sqlx-sqlite")]
160                    {
161                        use sea_schema::sqlite::discovery::SchemaDiscovery;
162                        use sqlx::Sqlite;
163
164                        println!("Connecting to SQLite ...");
165                        let connection = sqlx_connect::<Sqlite>(
166                            max_connections,
167                            acquire_timeout,
168                            url.as_str(),
169                            None,
170                        )
171                        .await?;
172                        println!("Discovering schema ...");
173                        let schema_discovery = SchemaDiscovery::new(connection);
174                        let schema = schema_discovery
175                            .discover()
176                            .await?
177                            .merge_indexes_into_table();
178                        let table_stmts = schema
179                            .tables
180                            .into_iter()
181                            .filter(|schema| filter_tables(&schema.name))
182                            .filter(|schema| filter_hidden_tables(&schema.name))
183                            .filter(|schema| filter_skip_tables(&schema.name))
184                            .map(|schema| schema.write())
185                            .collect();
186                        (None, table_stmts)
187                    }
188                }
189                "postgres" | "postgresql" => {
190                    #[cfg(not(feature = "sqlx-postgres"))]
191                    {
192                        panic!("postgres feature is off")
193                    }
194                    #[cfg(feature = "sqlx-postgres")]
195                    {
196                        use sea_schema::postgres::discovery::SchemaDiscovery;
197                        use sqlx::Postgres;
198
199                        println!("Connecting to Postgres ...");
200                        let schema = database_schema.as_deref().unwrap_or("public");
201                        let connection = sqlx_connect::<Postgres>(
202                            max_connections,
203                            acquire_timeout,
204                            url.as_str(),
205                            Some(schema),
206                        )
207                        .await?;
208                        println!("Discovering schema ...");
209                        let schema_discovery = SchemaDiscovery::new(connection, schema);
210                        let schema = schema_discovery.discover().await?;
211                        let table_stmts = schema
212                            .tables
213                            .into_iter()
214                            .filter(|schema| filter_tables(&schema.info.name))
215                            .filter(|schema| filter_hidden_tables(&schema.info.name))
216                            .filter(|schema| filter_skip_tables(&schema.info.name))
217                            .map(|schema| schema.write())
218                            .collect();
219                        (database_schema, table_stmts)
220                    }
221                }
222                _ => unimplemented!("{} is not supported", url.scheme()),
223            };
224            println!("... discovered.");
225
226            let writer_context = EntityWriterContext::new(
227                if expanded_format {
228                    EntityFormat::Expanded
229                } else if frontend_format {
230                    EntityFormat::Frontend
231                } else if let Some(entity_format) = entity_format {
232                    EntityFormat::from_str(&entity_format).expect("Invalid entity-format option")
233                } else {
234                    EntityFormat::default()
235                },
236                WithPrelude::from_str(&with_prelude).expect("Invalid prelude option"),
237                WithSerde::from_str(&with_serde).expect("Invalid serde derive option"),
238                with_copy_enums,
239                date_time_crate.into(),
240                big_integer_type.into(),
241                schema_name,
242                lib,
243                serde_skip_deserializing_primary_key,
244                serde_skip_hidden_column,
245                model_extra_derives,
246                model_extra_attributes,
247                enum_extra_derives,
248                enum_extra_attributes,
249                column_extra_derives,
250                seaography,
251                impl_active_model_behavior,
252                banner_version.into(),
253            );
254            let entity_writer = EntityTransformer::transform(table_stmts)?;
255
256            let dir = Path::new(&output_dir);
257            fs::create_dir_all(dir)?;
258
259            if er_diagram {
260                let diagram = entity_writer.generate_er_diagram();
261                let diagram_path = dir.join("entities.mermaid");
262                fs::write(&diagram_path, &diagram)?;
263                println!("Writing {}", diagram_path.display());
264            }
265
266            let output = entity_writer.generate(&writer_context);
267
268            let mut merge_fallback_files: Vec<String> = Vec::new();
269
270            for OutputFile { name, content } in output.files.iter() {
271                let file_path = dir.join(name);
272                println!("Writing {}", file_path.display());
273
274                if !matches!(
275                    name.as_str(),
276                    "mod.rs" | "lib.rs" | "prelude.rs" | "sea_orm_active_enums.rs"
277                ) && file_path.exists()
278                    && preserve_user_modifications
279                {
280                    let prev_content = fs::read_to_string(&file_path)?;
281                    match merge_entity_files(&prev_content, content) {
282                        Ok(merged) => {
283                            fs::write(file_path, merged)?;
284                        }
285                        Err(MergeReport {
286                            output,
287                            warnings,
288                            fallback_applied,
289                        }) => {
290                            for message in warnings {
291                                eprintln!("{message}");
292                            }
293                            fs::write(file_path, output)?;
294                            if fallback_applied {
295                                merge_fallback_files.push(name.clone());
296                            }
297                        }
298                    }
299                } else {
300                    fs::write(file_path, content)?;
301                };
302            }
303
304            // Format each of the files
305            for OutputFile { name, .. } in output.files.iter() {
306                let exit_status = Command::new("rustfmt").arg(dir.join(name)).status()?; // Get the status code
307                if !exit_status.success() {
308                    // Propagate the error if any
309                    return Err(format!("Fail to format file `{name}`").into());
310                }
311            }
312
313            if merge_fallback_files.is_empty() {
314                println!("... Done.");
315            } else {
316                return Err(format!(
317                    "Merge fallback applied for {} file(s): \n{}",
318                    merge_fallback_files.len(),
319                    merge_fallback_files.join("\n")
320                )
321                .into());
322            }
323        }
324    }
325
326    Ok(())
327}
328
329async fn sqlx_connect<DB>(
330    max_connections: u32,
331    acquire_timeout: u64,
332    url: &str,
333    schema: Option<&str>,
334) -> Result<sqlx::Pool<DB>, Box<dyn Error>>
335where
336    DB: sqlx::Database,
337    for<'a> &'a mut <DB as sqlx::Database>::Connection: sqlx::Executor<'a>,
338{
339    let mut pool_options = sqlx::pool::PoolOptions::<DB>::new()
340        .max_connections(max_connections)
341        .acquire_timeout(time::Duration::from_secs(acquire_timeout));
342    // Set search_path for Postgres, E.g. Some("public") by default
343    // MySQL & SQLite connection initialize with schema `None`
344    if let Some(schema) = schema {
345        let sql = format!("SET search_path = '{schema}'");
346        pool_options = pool_options.after_connect(move |conn, _| {
347            let sql = sql.clone();
348            Box::pin(async move {
349                sqlx::Executor::execute(conn, sql.as_str())
350                    .await
351                    .map(|_| ())
352            })
353        });
354    }
355    pool_options.connect(url).await.map_err(Into::into)
356}
357
358impl From<DateTimeCrate> for CodegenDateTimeCrate {
359    fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate {
360        match date_time_crate {
361            DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono,
362            DateTimeCrate::Time => CodegenDateTimeCrate::Time,
363        }
364    }
365}
366
367impl From<BigIntegerType> for CodegenBigIntegerType {
368    fn from(date_time_crate: BigIntegerType) -> CodegenBigIntegerType {
369        match date_time_crate {
370            BigIntegerType::I64 => CodegenBigIntegerType::I64,
371            BigIntegerType::I32 => CodegenBigIntegerType::I32,
372        }
373    }
374}
375
376impl From<BannerVersion> for CodegenBannerVersion {
377    fn from(banner_version: BannerVersion) -> CodegenBannerVersion {
378        match banner_version {
379            BannerVersion::Off => CodegenBannerVersion::Off,
380            BannerVersion::Major => CodegenBannerVersion::Major,
381            BannerVersion::Minor => CodegenBannerVersion::Minor,
382            BannerVersion::Patch => CodegenBannerVersion::Patch,
383        }
384    }
385}
386
#[cfg(test)]
mod tests {
    use clap::Parser;

    use super::*;
    use crate::{Cli, Commands};

    /// Parses `sea-orm-cli generate entity --database-url <database_url>` and
    /// runs it to completion, unwrapping the result so any error surfaces as a
    /// test panic for the `should_panic` expectations below to match.
    fn generate_entity(database_url: &str) {
        let cli = Cli::parse_from([
            "sea-orm-cli",
            "generate",
            "entity",
            "--database-url",
            database_url,
        ]);
        let verbose = cli.verbose;
        match cli.command {
            Commands::Generate { command } => {
                smol::block_on(run_generate_command(command, verbose)).unwrap();
            }
            _ => unreachable!(),
        }
    }

    #[test]
    #[should_panic(
        expected = "called `Result::unwrap()` on an `Err` value: RelativeUrlWithoutBase"
    )]
    fn test_generate_entity_no_protocol() {
        generate_entity("://root:root@localhost:3306/database");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: postgresql://root:root@localhost:3306"
    )]
    fn test_generate_entity_no_database_section() {
        generate_entity("postgresql://root:root@localhost:3306");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: mysql://root:root@localhost:3306/"
    )]
    fn test_generate_entity_no_database_path() {
        generate_entity("mysql://root:root@localhost:3306/");
    }

    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: EmptyHost")]
    fn test_generate_entity_no_host() {
        generate_entity("postgres://root:root@/database");
    }
}