sea_orm_cli/commands/generate.rs

1use crate::{BannerVersion, BigIntegerType, DateTimeCrate, GenerateSubcommands};
2use core::time;
3use sea_orm_codegen::{
4    BannerVersion as CodegenBannerVersion, BigIntegerType as CodegenBigIntegerType,
5    DateTimeCrate as CodegenDateTimeCrate, EntityFormat, EntityTransformer, EntityWriterContext,
6    MergeReport, OutputFile, WithPrelude, WithSerde, merge_entity_files,
7};
8use std::{error::Error, fs, path::Path, process::Command, str::FromStr};
9use tracing_subscriber::{EnvFilter, prelude::*};
10use url::Url;
11
12pub async fn run_generate_command(
13    command: GenerateSubcommands,
14    verbose: bool,
15) -> Result<(), Box<dyn Error>> {
16    match command {
17        GenerateSubcommands::Entity {
18            entity_format,
19            compact_format: _,
20            expanded_format,
21            frontend_format,
22            include_hidden_tables,
23            tables,
24            ignore_tables,
25            max_connections,
26            acquire_timeout,
27            output_dir,
28            database_schema,
29            database_url,
30            with_prelude,
31            with_serde,
32            serde_skip_deserializing_primary_key,
33            serde_skip_hidden_column,
34            with_copy_enums,
35            date_time_crate,
36            big_integer_type,
37            lib,
38            model_extra_derives,
39            model_extra_attributes,
40            enum_extra_derives,
41            enum_extra_attributes,
42            column_extra_derives,
43            seaography,
44            impl_active_model_behavior,
45            preserve_user_modifications,
46            banner_version,
47        } => {
48            if verbose {
49                let _ = tracing_subscriber::fmt()
50                    .with_max_level(tracing::Level::DEBUG)
51                    .with_test_writer()
52                    .try_init();
53            } else {
54                let filter_layer = EnvFilter::try_new("sea_orm_codegen=info").unwrap();
55                let fmt_layer = tracing_subscriber::fmt::layer()
56                    .with_target(false)
57                    .with_level(false)
58                    .without_time();
59
60                let _ = tracing_subscriber::registry()
61                    .with(filter_layer)
62                    .with(fmt_layer)
63                    .try_init();
64            }
65
66            // The database should be a valid URL that can be parsed
67            // protocol://username:password@host/database_name
68            let url = Url::parse(&database_url)?;
69
70            // Make sure we have all the required url components
71            //
72            // Missing scheme will have been caught by the Url::parse() call
73            // above
74            let is_sqlite = url.scheme() == "sqlite";
75
76            // Closures for filtering tables
77            let filter_tables =
78                |table: &String| -> bool { tables.is_empty() || tables.contains(table) };
79
80            let filter_hidden_tables = |table: &str| -> bool {
81                if include_hidden_tables {
82                    true
83                } else {
84                    !table.starts_with('_')
85                }
86            };
87
88            let filter_skip_tables = |table: &String| -> bool { !ignore_tables.contains(table) };
89
90            let _database_name = if !is_sqlite {
91                // The database name should be the first element of the path string
92                //
93                // Throwing an error if there is no database name since it might be
94                // accepted by the database without it, while we're looking to dump
95                // information from a particular database
96                let database_name = url
97                    .path_segments()
98                    .unwrap_or_else(|| {
99                        panic!(
100                            "There is no database name as part of the url path: {}",
101                            url.as_str()
102                        )
103                    })
104                    .next()
105                    .unwrap();
106
107                // An empty string as the database name is also an error
108                if database_name.is_empty() {
109                    panic!(
110                        "There is no database name as part of the url path: {}",
111                        url.as_str()
112                    );
113                }
114
115                database_name
116            } else {
117                Default::default()
118            };
119
120            let (schema_name, table_stmts) = match url.scheme() {
121                "mysql" => {
122                    #[cfg(not(feature = "sqlx-mysql"))]
123                    {
124                        panic!("mysql feature is off")
125                    }
126                    #[cfg(feature = "sqlx-mysql")]
127                    {
128                        use sea_schema::mysql::discovery::SchemaDiscovery;
129                        use sqlx::MySql;
130
131                        println!("Connecting to MySQL ...");
132                        let connection = sqlx_connect::<MySql>(
133                            max_connections,
134                            acquire_timeout,
135                            url.as_str(),
136                            None,
137                        )
138                        .await?;
139                        println!("Discovering schema ...");
140                        let schema_discovery = SchemaDiscovery::new(connection, _database_name);
141                        let schema = schema_discovery.discover().await?;
142                        let table_stmts = schema
143                            .tables
144                            .into_iter()
145                            .filter(|schema| filter_tables(&schema.info.name))
146                            .filter(|schema| filter_hidden_tables(&schema.info.name))
147                            .filter(|schema| filter_skip_tables(&schema.info.name))
148                            .map(|schema| schema.write())
149                            .collect();
150                        (None, table_stmts)
151                    }
152                }
153                "sqlite" => {
154                    #[cfg(not(feature = "sqlx-sqlite"))]
155                    {
156                        panic!("sqlite feature is off")
157                    }
158                    #[cfg(feature = "sqlx-sqlite")]
159                    {
160                        use sea_schema::sqlite::discovery::SchemaDiscovery;
161                        use sqlx::Sqlite;
162
163                        println!("Connecting to SQLite ...");
164                        let connection = sqlx_connect::<Sqlite>(
165                            max_connections,
166                            acquire_timeout,
167                            url.as_str(),
168                            None,
169                        )
170                        .await?;
171                        println!("Discovering schema ...");
172                        let schema_discovery = SchemaDiscovery::new(connection);
173                        let schema = schema_discovery
174                            .discover()
175                            .await?
176                            .merge_indexes_into_table();
177                        let table_stmts = schema
178                            .tables
179                            .into_iter()
180                            .filter(|schema| filter_tables(&schema.name))
181                            .filter(|schema| filter_hidden_tables(&schema.name))
182                            .filter(|schema| filter_skip_tables(&schema.name))
183                            .map(|schema| schema.write())
184                            .collect();
185                        (None, table_stmts)
186                    }
187                }
188                "postgres" | "postgresql" => {
189                    #[cfg(not(feature = "sqlx-postgres"))]
190                    {
191                        panic!("postgres feature is off")
192                    }
193                    #[cfg(feature = "sqlx-postgres")]
194                    {
195                        use sea_schema::postgres::discovery::SchemaDiscovery;
196                        use sqlx::Postgres;
197
198                        println!("Connecting to Postgres ...");
199                        let schema = database_schema.as_deref().unwrap_or("public");
200                        let connection = sqlx_connect::<Postgres>(
201                            max_connections,
202                            acquire_timeout,
203                            url.as_str(),
204                            Some(schema),
205                        )
206                        .await?;
207                        println!("Discovering schema ...");
208                        let schema_discovery = SchemaDiscovery::new(connection, schema);
209                        let schema = schema_discovery.discover().await?;
210                        let table_stmts = schema
211                            .tables
212                            .into_iter()
213                            .filter(|schema| filter_tables(&schema.info.name))
214                            .filter(|schema| filter_hidden_tables(&schema.info.name))
215                            .filter(|schema| filter_skip_tables(&schema.info.name))
216                            .map(|schema| schema.write())
217                            .collect();
218                        (database_schema, table_stmts)
219                    }
220                }
221                _ => unimplemented!("{} is not supported", url.scheme()),
222            };
223            println!("... discovered.");
224
225            let writer_context = EntityWriterContext::new(
226                if expanded_format {
227                    EntityFormat::Expanded
228                } else if frontend_format {
229                    EntityFormat::Frontend
230                } else if let Some(entity_format) = entity_format {
231                    EntityFormat::from_str(&entity_format).expect("Invalid entity-format option")
232                } else {
233                    EntityFormat::default()
234                },
235                WithPrelude::from_str(&with_prelude).expect("Invalid prelude option"),
236                WithSerde::from_str(&with_serde).expect("Invalid serde derive option"),
237                with_copy_enums,
238                date_time_crate.into(),
239                big_integer_type.into(),
240                schema_name,
241                lib,
242                serde_skip_deserializing_primary_key,
243                serde_skip_hidden_column,
244                model_extra_derives,
245                model_extra_attributes,
246                enum_extra_derives,
247                enum_extra_attributes,
248                column_extra_derives,
249                seaography,
250                impl_active_model_behavior,
251                banner_version.into(),
252            );
253            let output = EntityTransformer::transform(table_stmts)?.generate(&writer_context);
254
255            let dir = Path::new(&output_dir);
256            fs::create_dir_all(dir)?;
257
258            let mut merge_fallback_files: Vec<String> = Vec::new();
259
260            for OutputFile { name, content } in output.files.iter() {
261                let file_path = dir.join(name);
262                println!("Writing {}", file_path.display());
263
264                if !matches!(
265                    name.as_str(),
266                    "mod.rs" | "lib.rs" | "prelude.rs" | "sea_orm_active_enums.rs"
267                ) && file_path.exists()
268                    && preserve_user_modifications
269                {
270                    let prev_content = fs::read_to_string(&file_path)?;
271                    match merge_entity_files(&prev_content, content) {
272                        Ok(merged) => {
273                            fs::write(file_path, merged)?;
274                        }
275                        Err(MergeReport {
276                            output,
277                            warnings,
278                            fallback_applied,
279                        }) => {
280                            for message in warnings {
281                                eprintln!("{message}");
282                            }
283                            fs::write(file_path, output)?;
284                            if fallback_applied {
285                                merge_fallback_files.push(name.clone());
286                            }
287                        }
288                    }
289                } else {
290                    fs::write(file_path, content)?;
291                };
292            }
293
294            // Format each of the files
295            for OutputFile { name, .. } in output.files.iter() {
296                let exit_status = Command::new("rustfmt").arg(dir.join(name)).status()?; // Get the status code
297                if !exit_status.success() {
298                    // Propagate the error if any
299                    return Err(format!("Fail to format file `{name}`").into());
300                }
301            }
302
303            if merge_fallback_files.is_empty() {
304                println!("... Done.");
305            } else {
306                return Err(format!(
307                    "Merge fallback applied for {} file(s): \n{}",
308                    merge_fallback_files.len(),
309                    merge_fallback_files.join("\n")
310                )
311                .into());
312            }
313        }
314    }
315
316    Ok(())
317}
318
319async fn sqlx_connect<DB>(
320    max_connections: u32,
321    acquire_timeout: u64,
322    url: &str,
323    schema: Option<&str>,
324) -> Result<sqlx::Pool<DB>, Box<dyn Error>>
325where
326    DB: sqlx::Database,
327    for<'a> &'a mut <DB as sqlx::Database>::Connection: sqlx::Executor<'a>,
328{
329    let mut pool_options = sqlx::pool::PoolOptions::<DB>::new()
330        .max_connections(max_connections)
331        .acquire_timeout(time::Duration::from_secs(acquire_timeout));
332    // Set search_path for Postgres, E.g. Some("public") by default
333    // MySQL & SQLite connection initialize with schema `None`
334    if let Some(schema) = schema {
335        let sql = format!("SET search_path = '{schema}'");
336        pool_options = pool_options.after_connect(move |conn, _| {
337            let sql = sql.clone();
338            Box::pin(async move {
339                sqlx::Executor::execute(conn, sql.as_str())
340                    .await
341                    .map(|_| ())
342            })
343        });
344    }
345    pool_options.connect(url).await.map_err(Into::into)
346}
347
348impl From<DateTimeCrate> for CodegenDateTimeCrate {
349    fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate {
350        match date_time_crate {
351            DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono,
352            DateTimeCrate::Time => CodegenDateTimeCrate::Time,
353        }
354    }
355}
356
357impl From<BigIntegerType> for CodegenBigIntegerType {
358    fn from(date_time_crate: BigIntegerType) -> CodegenBigIntegerType {
359        match date_time_crate {
360            BigIntegerType::I64 => CodegenBigIntegerType::I64,
361            BigIntegerType::I32 => CodegenBigIntegerType::I32,
362        }
363    }
364}
365
366impl From<BannerVersion> for CodegenBannerVersion {
367    fn from(banner_version: BannerVersion) -> CodegenBannerVersion {
368        match banner_version {
369            BannerVersion::Off => CodegenBannerVersion::Off,
370            BannerVersion::Major => CodegenBannerVersion::Major,
371            BannerVersion::Minor => CodegenBannerVersion::Minor,
372            BannerVersion::Patch => CodegenBannerVersion::Patch,
373        }
374    }
375}
376
#[cfg(test)]
mod tests {
    use clap::Parser;

    use super::*;
    use crate::{Cli, Commands};

    /// Runs `generate entity` with the given database URL and unwraps the result.
    ///
    /// The command line parsed is
    /// `sea-orm-cli generate entity --database-url <database_url>`. The command
    /// runs to completion on `smol`, so any returned error becomes a test panic.
    fn run_generate_entity(database_url: &str) {
        let cli = Cli::parse_from([
            "sea-orm-cli",
            "generate",
            "entity",
            "--database-url",
            database_url,
        ]);

        match cli.command {
            Commands::Generate { command } => {
                smol::block_on(run_generate_command(command, cli.verbose)).unwrap();
            }
            _ => unreachable!(),
        }
    }

    #[test]
    #[should_panic(
        expected = "called `Result::unwrap()` on an `Err` value: RelativeUrlWithoutBase"
    )]
    fn test_generate_entity_no_protocol() {
        run_generate_entity("://root:root@localhost:3306/database");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: postgresql://root:root@localhost:3306"
    )]
    fn test_generate_entity_no_database_section() {
        run_generate_entity("postgresql://root:root@localhost:3306");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: mysql://root:root@localhost:3306/"
    )]
    fn test_generate_entity_no_database_path() {
        run_generate_entity("mysql://root:root@localhost:3306/");
    }

    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: EmptyHost")]
    fn test_generate_entity_no_host() {
        run_generate_entity("postgres://root:root@/database");
    }
}