sea_orm_cli/commands/
generate.rs

1use core::time;
2use sea_orm_codegen::{
3    DateTimeCrate as CodegenDateTimeCrate, EntityFormat, EntityTransformer, EntityWriterContext,
4    MergeReport, OutputFile, WithPrelude, WithSerde, merge_entity_files,
5};
6use std::{error::Error, fs, path::Path, process::Command, str::FromStr};
7use tracing_subscriber::{EnvFilter, prelude::*};
8use url::Url;
9
10use crate::{DateTimeCrate, GenerateSubcommands};
11
12pub async fn run_generate_command(
13    command: GenerateSubcommands,
14    verbose: bool,
15) -> Result<(), Box<dyn Error>> {
16    match command {
17        GenerateSubcommands::Entity {
18            entity_format,
19            compact_format: _,
20            expanded_format,
21            frontend_format,
22            include_hidden_tables,
23            tables,
24            ignore_tables,
25            max_connections,
26            acquire_timeout,
27            output_dir,
28            database_schema,
29            database_url,
30            with_prelude,
31            with_serde,
32            serde_skip_deserializing_primary_key,
33            serde_skip_hidden_column,
34            with_copy_enums,
35            date_time_crate,
36            lib,
37            model_extra_derives,
38            model_extra_attributes,
39            enum_extra_derives,
40            enum_extra_attributes,
41            column_extra_derives,
42            seaography,
43            impl_active_model_behavior,
44            preserve_user_modifications,
45        } => {
46            if verbose {
47                let _ = tracing_subscriber::fmt()
48                    .with_max_level(tracing::Level::DEBUG)
49                    .with_test_writer()
50                    .try_init();
51            } else {
52                let filter_layer = EnvFilter::try_new("sea_orm_codegen=info").unwrap();
53                let fmt_layer = tracing_subscriber::fmt::layer()
54                    .with_target(false)
55                    .with_level(false)
56                    .without_time();
57
58                let _ = tracing_subscriber::registry()
59                    .with(filter_layer)
60                    .with(fmt_layer)
61                    .try_init();
62            }
63
64            // The database should be a valid URL that can be parsed
65            // protocol://username:password@host/database_name
66            let url = Url::parse(&database_url)?;
67
68            // Make sure we have all the required url components
69            //
70            // Missing scheme will have been caught by the Url::parse() call
71            // above
72            let is_sqlite = url.scheme() == "sqlite";
73
74            // Closures for filtering tables
75            let filter_tables =
76                |table: &String| -> bool { tables.is_empty() || tables.contains(table) };
77
78            let filter_hidden_tables = |table: &str| -> bool {
79                if include_hidden_tables {
80                    true
81                } else {
82                    !table.starts_with('_')
83                }
84            };
85
86            let filter_skip_tables = |table: &String| -> bool { !ignore_tables.contains(table) };
87
88            let _database_name = if !is_sqlite {
89                // The database name should be the first element of the path string
90                //
91                // Throwing an error if there is no database name since it might be
92                // accepted by the database without it, while we're looking to dump
93                // information from a particular database
94                let database_name = url
95                    .path_segments()
96                    .unwrap_or_else(|| {
97                        panic!(
98                            "There is no database name as part of the url path: {}",
99                            url.as_str()
100                        )
101                    })
102                    .next()
103                    .unwrap();
104
105                // An empty string as the database name is also an error
106                if database_name.is_empty() {
107                    panic!(
108                        "There is no database name as part of the url path: {}",
109                        url.as_str()
110                    );
111                }
112
113                database_name
114            } else {
115                Default::default()
116            };
117
118            let (schema_name, table_stmts) = match url.scheme() {
119                "mysql" => {
120                    #[cfg(not(feature = "sqlx-mysql"))]
121                    {
122                        panic!("mysql feature is off")
123                    }
124                    #[cfg(feature = "sqlx-mysql")]
125                    {
126                        use sea_schema::mysql::discovery::SchemaDiscovery;
127                        use sqlx::MySql;
128
129                        println!("Connecting to MySQL ...");
130                        let connection = sqlx_connect::<MySql>(
131                            max_connections,
132                            acquire_timeout,
133                            url.as_str(),
134                            None,
135                        )
136                        .await?;
137                        println!("Discovering schema ...");
138                        let schema_discovery = SchemaDiscovery::new(connection, _database_name);
139                        let schema = schema_discovery.discover().await?;
140                        let table_stmts = schema
141                            .tables
142                            .into_iter()
143                            .filter(|schema| filter_tables(&schema.info.name))
144                            .filter(|schema| filter_hidden_tables(&schema.info.name))
145                            .filter(|schema| filter_skip_tables(&schema.info.name))
146                            .map(|schema| schema.write())
147                            .collect();
148                        (None, table_stmts)
149                    }
150                }
151                "sqlite" => {
152                    #[cfg(not(feature = "sqlx-sqlite"))]
153                    {
154                        panic!("sqlite feature is off")
155                    }
156                    #[cfg(feature = "sqlx-sqlite")]
157                    {
158                        use sea_schema::sqlite::discovery::SchemaDiscovery;
159                        use sqlx::Sqlite;
160
161                        println!("Connecting to SQLite ...");
162                        let connection = sqlx_connect::<Sqlite>(
163                            max_connections,
164                            acquire_timeout,
165                            url.as_str(),
166                            None,
167                        )
168                        .await?;
169                        println!("Discovering schema ...");
170                        let schema_discovery = SchemaDiscovery::new(connection);
171                        let schema = schema_discovery
172                            .discover()
173                            .await?
174                            .merge_indexes_into_table();
175                        let table_stmts = schema
176                            .tables
177                            .into_iter()
178                            .filter(|schema| filter_tables(&schema.name))
179                            .filter(|schema| filter_hidden_tables(&schema.name))
180                            .filter(|schema| filter_skip_tables(&schema.name))
181                            .map(|schema| schema.write())
182                            .collect();
183                        (None, table_stmts)
184                    }
185                }
186                "postgres" | "postgresql" => {
187                    #[cfg(not(feature = "sqlx-postgres"))]
188                    {
189                        panic!("postgres feature is off")
190                    }
191                    #[cfg(feature = "sqlx-postgres")]
192                    {
193                        use sea_schema::postgres::discovery::SchemaDiscovery;
194                        use sqlx::Postgres;
195
196                        println!("Connecting to Postgres ...");
197                        let schema = database_schema.as_deref().unwrap_or("public");
198                        let connection = sqlx_connect::<Postgres>(
199                            max_connections,
200                            acquire_timeout,
201                            url.as_str(),
202                            Some(schema),
203                        )
204                        .await?;
205                        println!("Discovering schema ...");
206                        let schema_discovery = SchemaDiscovery::new(connection, schema);
207                        let schema = schema_discovery.discover().await?;
208                        let table_stmts = schema
209                            .tables
210                            .into_iter()
211                            .filter(|schema| filter_tables(&schema.info.name))
212                            .filter(|schema| filter_hidden_tables(&schema.info.name))
213                            .filter(|schema| filter_skip_tables(&schema.info.name))
214                            .map(|schema| schema.write())
215                            .collect();
216                        (database_schema, table_stmts)
217                    }
218                }
219                _ => unimplemented!("{} is not supported", url.scheme()),
220            };
221            println!("... discovered.");
222
223            let writer_context = EntityWriterContext::new(
224                if expanded_format {
225                    EntityFormat::Expanded
226                } else if frontend_format {
227                    EntityFormat::Frontend
228                } else if let Some(entity_format) = entity_format {
229                    EntityFormat::from_str(&entity_format).expect("Invalid entity-format option")
230                } else {
231                    EntityFormat::default()
232                },
233                WithPrelude::from_str(&with_prelude).expect("Invalid prelude option"),
234                WithSerde::from_str(&with_serde).expect("Invalid serde derive option"),
235                with_copy_enums,
236                date_time_crate.into(),
237                schema_name,
238                lib,
239                serde_skip_deserializing_primary_key,
240                serde_skip_hidden_column,
241                model_extra_derives,
242                model_extra_attributes,
243                enum_extra_derives,
244                enum_extra_attributes,
245                column_extra_derives,
246                seaography,
247                impl_active_model_behavior,
248            );
249            let output = EntityTransformer::transform(table_stmts)?.generate(&writer_context);
250
251            let dir = Path::new(&output_dir);
252            fs::create_dir_all(dir)?;
253
254            let mut merge_fallback_files: Vec<String> = Vec::new();
255
256            for OutputFile { name, content } in output.files.iter() {
257                let file_path = dir.join(name);
258                println!("Writing {}", file_path.display());
259
260                if !matches!(
261                    name.as_str(),
262                    "mod.rs" | "lib.rs" | "prelude.rs" | "sea_orm_active_enums.rs"
263                ) && file_path.exists()
264                    && preserve_user_modifications
265                {
266                    let prev_content = fs::read_to_string(&file_path)?;
267                    match merge_entity_files(&prev_content, content) {
268                        Ok(merged) => {
269                            fs::write(file_path, merged)?;
270                        }
271                        Err(MergeReport {
272                            output,
273                            warnings,
274                            fallback_applied,
275                        }) => {
276                            for message in warnings {
277                                eprintln!("{message}");
278                            }
279                            fs::write(file_path, output)?;
280                            if fallback_applied {
281                                merge_fallback_files.push(name.clone());
282                            }
283                        }
284                    }
285                } else {
286                    fs::write(file_path, content)?;
287                };
288            }
289
290            // Format each of the files
291            for OutputFile { name, .. } in output.files.iter() {
292                let exit_status = Command::new("rustfmt").arg(dir.join(name)).status()?; // Get the status code
293                if !exit_status.success() {
294                    // Propagate the error if any
295                    return Err(format!("Fail to format file `{name}`").into());
296                }
297            }
298
299            if merge_fallback_files.is_empty() {
300                println!("... Done.");
301            } else {
302                return Err(format!(
303                    "Merge fallback applied for {} file(s): \n{}",
304                    merge_fallback_files.len(),
305                    merge_fallback_files.join("\n")
306                )
307                .into());
308            }
309        }
310    }
311
312    Ok(())
313}
314
315async fn sqlx_connect<DB>(
316    max_connections: u32,
317    acquire_timeout: u64,
318    url: &str,
319    schema: Option<&str>,
320) -> Result<sqlx::Pool<DB>, Box<dyn Error>>
321where
322    DB: sqlx::Database,
323    for<'a> &'a mut <DB as sqlx::Database>::Connection: sqlx::Executor<'a>,
324{
325    let mut pool_options = sqlx::pool::PoolOptions::<DB>::new()
326        .max_connections(max_connections)
327        .acquire_timeout(time::Duration::from_secs(acquire_timeout));
328    // Set search_path for Postgres, E.g. Some("public") by default
329    // MySQL & SQLite connection initialize with schema `None`
330    if let Some(schema) = schema {
331        let sql = format!("SET search_path = '{schema}'");
332        pool_options = pool_options.after_connect(move |conn, _| {
333            let sql = sql.clone();
334            Box::pin(async move {
335                sqlx::Executor::execute(conn, sql.as_str())
336                    .await
337                    .map(|_| ())
338            })
339        });
340    }
341    pool_options.connect(url).await.map_err(Into::into)
342}
343
344impl From<DateTimeCrate> for CodegenDateTimeCrate {
345    fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate {
346        match date_time_crate {
347            DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono,
348            DateTimeCrate::Time => CodegenDateTimeCrate::Time,
349        }
350    }
351}
352
#[cfg(test)]
mod tests {
    use clap::Parser;

    use super::*;
    use crate::{Cli, Commands};

    /// Parses `sea-orm-cli generate entity --database-url <database_url>`, runs
    /// the resulting command to completion and unwraps its result, so any error
    /// surfaces as a panic the `should_panic` tests can match on.
    fn generate_entity_with_url(database_url: &str) {
        let cli = Cli::parse_from([
            "sea-orm-cli",
            "generate",
            "entity",
            "--database-url",
            database_url,
        ]);
        let verbose = cli.verbose;
        match cli.command {
            Commands::Generate { command } => {
                smol::block_on(run_generate_command(command, verbose)).unwrap();
            }
            _ => unreachable!(),
        }
    }

    #[test]
    #[should_panic(
        expected = "called `Result::unwrap()` on an `Err` value: RelativeUrlWithoutBase"
    )]
    fn test_generate_entity_no_protocol() {
        // No scheme at all: URL parsing fails before anything else runs.
        generate_entity_with_url("://root:root@localhost:3306/database");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: postgresql://root:root@localhost:3306"
    )]
    fn test_generate_entity_no_database_section() {
        // No path at all after the authority.
        generate_entity_with_url("postgresql://root:root@localhost:3306");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: mysql://root:root@localhost:3306/"
    )]
    fn test_generate_entity_no_database_path() {
        // A path is present but its first segment is empty.
        generate_entity_with_url("mysql://root:root@localhost:3306/");
    }

    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: EmptyHost")]
    fn test_generate_entity_no_host() {
        // Credentials are followed directly by the path, leaving the host empty.
        generate_entity_with_url("postgres://root:root@/database");
    }
}