// sea_orm_cli/commands/generate.rs

1use crate::{BigIntegerType, DateTimeCrate, GenerateSubcommands};
2use core::time;
3use sea_orm_codegen::{
4    BigIntegerType as CodegenBigIntegerType, DateTimeCrate as CodegenDateTimeCrate, EntityFormat,
5    EntityTransformer, EntityWriterContext, MergeReport, OutputFile, WithPrelude, WithSerde,
6    merge_entity_files,
7};
8use std::{error::Error, fs, path::Path, process::Command, str::FromStr};
9use tracing_subscriber::{EnvFilter, prelude::*};
10use url::Url;
11
12pub async fn run_generate_command(
13    command: GenerateSubcommands,
14    verbose: bool,
15) -> Result<(), Box<dyn Error>> {
16    match command {
17        GenerateSubcommands::Entity {
18            entity_format,
19            compact_format: _,
20            expanded_format,
21            frontend_format,
22            include_hidden_tables,
23            tables,
24            ignore_tables,
25            max_connections,
26            acquire_timeout,
27            output_dir,
28            database_schema,
29            database_url,
30            with_prelude,
31            with_serde,
32            serde_skip_deserializing_primary_key,
33            serde_skip_hidden_column,
34            with_copy_enums,
35            date_time_crate,
36            big_integer_type,
37            lib,
38            model_extra_derives,
39            model_extra_attributes,
40            enum_extra_derives,
41            enum_extra_attributes,
42            column_extra_derives,
43            seaography,
44            impl_active_model_behavior,
45            preserve_user_modifications,
46        } => {
47            if verbose {
48                let _ = tracing_subscriber::fmt()
49                    .with_max_level(tracing::Level::DEBUG)
50                    .with_test_writer()
51                    .try_init();
52            } else {
53                let filter_layer = EnvFilter::try_new("sea_orm_codegen=info").unwrap();
54                let fmt_layer = tracing_subscriber::fmt::layer()
55                    .with_target(false)
56                    .with_level(false)
57                    .without_time();
58
59                let _ = tracing_subscriber::registry()
60                    .with(filter_layer)
61                    .with(fmt_layer)
62                    .try_init();
63            }
64
65            // The database should be a valid URL that can be parsed
66            // protocol://username:password@host/database_name
67            let url = Url::parse(&database_url)?;
68
69            // Make sure we have all the required url components
70            //
71            // Missing scheme will have been caught by the Url::parse() call
72            // above
73            let is_sqlite = url.scheme() == "sqlite";
74
75            // Closures for filtering tables
76            let filter_tables =
77                |table: &String| -> bool { tables.is_empty() || tables.contains(table) };
78
79            let filter_hidden_tables = |table: &str| -> bool {
80                if include_hidden_tables {
81                    true
82                } else {
83                    !table.starts_with('_')
84                }
85            };
86
87            let filter_skip_tables = |table: &String| -> bool { !ignore_tables.contains(table) };
88
89            let _database_name = if !is_sqlite {
90                // The database name should be the first element of the path string
91                //
92                // Throwing an error if there is no database name since it might be
93                // accepted by the database without it, while we're looking to dump
94                // information from a particular database
95                let database_name = url
96                    .path_segments()
97                    .unwrap_or_else(|| {
98                        panic!(
99                            "There is no database name as part of the url path: {}",
100                            url.as_str()
101                        )
102                    })
103                    .next()
104                    .unwrap();
105
106                // An empty string as the database name is also an error
107                if database_name.is_empty() {
108                    panic!(
109                        "There is no database name as part of the url path: {}",
110                        url.as_str()
111                    );
112                }
113
114                database_name
115            } else {
116                Default::default()
117            };
118
119            let (schema_name, table_stmts) = match url.scheme() {
120                "mysql" => {
121                    #[cfg(not(feature = "sqlx-mysql"))]
122                    {
123                        panic!("mysql feature is off")
124                    }
125                    #[cfg(feature = "sqlx-mysql")]
126                    {
127                        use sea_schema::mysql::discovery::SchemaDiscovery;
128                        use sqlx::MySql;
129
130                        println!("Connecting to MySQL ...");
131                        let connection = sqlx_connect::<MySql>(
132                            max_connections,
133                            acquire_timeout,
134                            url.as_str(),
135                            None,
136                        )
137                        .await?;
138                        println!("Discovering schema ...");
139                        let schema_discovery = SchemaDiscovery::new(connection, _database_name);
140                        let schema = schema_discovery.discover().await?;
141                        let table_stmts = schema
142                            .tables
143                            .into_iter()
144                            .filter(|schema| filter_tables(&schema.info.name))
145                            .filter(|schema| filter_hidden_tables(&schema.info.name))
146                            .filter(|schema| filter_skip_tables(&schema.info.name))
147                            .map(|schema| schema.write())
148                            .collect();
149                        (None, table_stmts)
150                    }
151                }
152                "sqlite" => {
153                    #[cfg(not(feature = "sqlx-sqlite"))]
154                    {
155                        panic!("sqlite feature is off")
156                    }
157                    #[cfg(feature = "sqlx-sqlite")]
158                    {
159                        use sea_schema::sqlite::discovery::SchemaDiscovery;
160                        use sqlx::Sqlite;
161
162                        println!("Connecting to SQLite ...");
163                        let connection = sqlx_connect::<Sqlite>(
164                            max_connections,
165                            acquire_timeout,
166                            url.as_str(),
167                            None,
168                        )
169                        .await?;
170                        println!("Discovering schema ...");
171                        let schema_discovery = SchemaDiscovery::new(connection);
172                        let schema = schema_discovery
173                            .discover()
174                            .await?
175                            .merge_indexes_into_table();
176                        let table_stmts = schema
177                            .tables
178                            .into_iter()
179                            .filter(|schema| filter_tables(&schema.name))
180                            .filter(|schema| filter_hidden_tables(&schema.name))
181                            .filter(|schema| filter_skip_tables(&schema.name))
182                            .map(|schema| schema.write())
183                            .collect();
184                        (None, table_stmts)
185                    }
186                }
187                "postgres" | "postgresql" => {
188                    #[cfg(not(feature = "sqlx-postgres"))]
189                    {
190                        panic!("postgres feature is off")
191                    }
192                    #[cfg(feature = "sqlx-postgres")]
193                    {
194                        use sea_schema::postgres::discovery::SchemaDiscovery;
195                        use sqlx::Postgres;
196
197                        println!("Connecting to Postgres ...");
198                        let schema = database_schema.as_deref().unwrap_or("public");
199                        let connection = sqlx_connect::<Postgres>(
200                            max_connections,
201                            acquire_timeout,
202                            url.as_str(),
203                            Some(schema),
204                        )
205                        .await?;
206                        println!("Discovering schema ...");
207                        let schema_discovery = SchemaDiscovery::new(connection, schema);
208                        let schema = schema_discovery.discover().await?;
209                        let table_stmts = schema
210                            .tables
211                            .into_iter()
212                            .filter(|schema| filter_tables(&schema.info.name))
213                            .filter(|schema| filter_hidden_tables(&schema.info.name))
214                            .filter(|schema| filter_skip_tables(&schema.info.name))
215                            .map(|schema| schema.write())
216                            .collect();
217                        (database_schema, table_stmts)
218                    }
219                }
220                _ => unimplemented!("{} is not supported", url.scheme()),
221            };
222            println!("... discovered.");
223
224            let writer_context = EntityWriterContext::new(
225                if expanded_format {
226                    EntityFormat::Expanded
227                } else if frontend_format {
228                    EntityFormat::Frontend
229                } else if let Some(entity_format) = entity_format {
230                    EntityFormat::from_str(&entity_format).expect("Invalid entity-format option")
231                } else {
232                    EntityFormat::default()
233                },
234                WithPrelude::from_str(&with_prelude).expect("Invalid prelude option"),
235                WithSerde::from_str(&with_serde).expect("Invalid serde derive option"),
236                with_copy_enums,
237                date_time_crate.into(),
238                big_integer_type.into(),
239                schema_name,
240                lib,
241                serde_skip_deserializing_primary_key,
242                serde_skip_hidden_column,
243                model_extra_derives,
244                model_extra_attributes,
245                enum_extra_derives,
246                enum_extra_attributes,
247                column_extra_derives,
248                seaography,
249                impl_active_model_behavior,
250            );
251            let output = EntityTransformer::transform(table_stmts)?.generate(&writer_context);
252
253            let dir = Path::new(&output_dir);
254            fs::create_dir_all(dir)?;
255
256            let mut merge_fallback_files: Vec<String> = Vec::new();
257
258            for OutputFile { name, content } in output.files.iter() {
259                let file_path = dir.join(name);
260                println!("Writing {}", file_path.display());
261
262                if !matches!(
263                    name.as_str(),
264                    "mod.rs" | "lib.rs" | "prelude.rs" | "sea_orm_active_enums.rs"
265                ) && file_path.exists()
266                    && preserve_user_modifications
267                {
268                    let prev_content = fs::read_to_string(&file_path)?;
269                    match merge_entity_files(&prev_content, content) {
270                        Ok(merged) => {
271                            fs::write(file_path, merged)?;
272                        }
273                        Err(MergeReport {
274                            output,
275                            warnings,
276                            fallback_applied,
277                        }) => {
278                            for message in warnings {
279                                eprintln!("{message}");
280                            }
281                            fs::write(file_path, output)?;
282                            if fallback_applied {
283                                merge_fallback_files.push(name.clone());
284                            }
285                        }
286                    }
287                } else {
288                    fs::write(file_path, content)?;
289                };
290            }
291
292            // Format each of the files
293            for OutputFile { name, .. } in output.files.iter() {
294                let exit_status = Command::new("rustfmt").arg(dir.join(name)).status()?; // Get the status code
295                if !exit_status.success() {
296                    // Propagate the error if any
297                    return Err(format!("Fail to format file `{name}`").into());
298                }
299            }
300
301            if merge_fallback_files.is_empty() {
302                println!("... Done.");
303            } else {
304                return Err(format!(
305                    "Merge fallback applied for {} file(s): \n{}",
306                    merge_fallback_files.len(),
307                    merge_fallback_files.join("\n")
308                )
309                .into());
310            }
311        }
312    }
313
314    Ok(())
315}
316
317async fn sqlx_connect<DB>(
318    max_connections: u32,
319    acquire_timeout: u64,
320    url: &str,
321    schema: Option<&str>,
322) -> Result<sqlx::Pool<DB>, Box<dyn Error>>
323where
324    DB: sqlx::Database,
325    for<'a> &'a mut <DB as sqlx::Database>::Connection: sqlx::Executor<'a>,
326{
327    let mut pool_options = sqlx::pool::PoolOptions::<DB>::new()
328        .max_connections(max_connections)
329        .acquire_timeout(time::Duration::from_secs(acquire_timeout));
330    // Set search_path for Postgres, E.g. Some("public") by default
331    // MySQL & SQLite connection initialize with schema `None`
332    if let Some(schema) = schema {
333        let sql = format!("SET search_path = '{schema}'");
334        pool_options = pool_options.after_connect(move |conn, _| {
335            let sql = sql.clone();
336            Box::pin(async move {
337                sqlx::Executor::execute(conn, sql.as_str())
338                    .await
339                    .map(|_| ())
340            })
341        });
342    }
343    pool_options.connect(url).await.map_err(Into::into)
344}
345
346impl From<DateTimeCrate> for CodegenDateTimeCrate {
347    fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate {
348        match date_time_crate {
349            DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono,
350            DateTimeCrate::Time => CodegenDateTimeCrate::Time,
351        }
352    }
353}
354
355impl From<BigIntegerType> for CodegenBigIntegerType {
356    fn from(date_time_crate: BigIntegerType) -> CodegenBigIntegerType {
357        match date_time_crate {
358            BigIntegerType::I64 => CodegenBigIntegerType::I64,
359            BigIntegerType::I32 => CodegenBigIntegerType::I32,
360        }
361    }
362}
363
#[cfg(test)]
mod tests {
    use clap::Parser;

    use super::*;
    use crate::{Cli, Commands};

    /// Parses `sea-orm-cli generate entity --database-url <database_url>`, runs the
    /// resulting command to completion and unwraps its result, so any returned error
    /// becomes a panic that `#[should_panic]` can match on.
    fn generate_entity(database_url: &str) {
        let cli = Cli::parse_from([
            "sea-orm-cli",
            "generate",
            "entity",
            "--database-url",
            database_url,
        ]);
        let verbose = cli.verbose;
        match cli.command {
            Commands::Generate { command } => {
                smol::block_on(run_generate_command(command, verbose)).unwrap();
            }
            _ => unreachable!(),
        }
    }

    #[test]
    #[should_panic(
        expected = "called `Result::unwrap()` on an `Err` value: RelativeUrlWithoutBase"
    )]
    fn test_generate_entity_no_protocol() {
        generate_entity("://root:root@localhost:3306/database");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: postgresql://root:root@localhost:3306"
    )]
    fn test_generate_entity_no_database_section() {
        generate_entity("postgresql://root:root@localhost:3306");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: mysql://root:root@localhost:3306/"
    )]
    fn test_generate_entity_no_database_path() {
        generate_entity("mysql://root:root@localhost:3306/");
    }

    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: EmptyHost")]
    fn test_generate_entity_no_host() {
        generate_entity("postgres://root:root@/database");
    }
}