// sea_orm_cli/commands/generate.rs

1use core::time;
2use sea_orm_codegen::{
3    DateTimeCrate as CodegenDateTimeCrate, EntityTransformer, EntityWriterContext, OutputFile,
4    WithPrelude, WithSerde,
5};
6use std::{error::Error, fs, io::Write, path::Path, process::Command, str::FromStr};
7use tracing_subscriber::{prelude::*, EnvFilter};
8use url::Url;
9
10use crate::{DateTimeCrate, GenerateSubcommands};
11
12pub async fn run_generate_command(
13    command: GenerateSubcommands,
14    verbose: bool,
15) -> Result<(), Box<dyn Error>> {
16    match command {
17        GenerateSubcommands::Entity {
18            compact_format: _,
19            expanded_format,
20            include_hidden_tables,
21            tables,
22            ignore_tables,
23            max_connections,
24            acquire_timeout,
25            output_dir,
26            database_schema,
27            database_url,
28            with_prelude,
29            with_serde,
30            serde_skip_deserializing_primary_key,
31            serde_skip_hidden_column,
32            with_copy_enums,
33            date_time_crate,
34            lib,
35            model_extra_derives,
36            model_extra_attributes,
37            enum_extra_derives,
38            enum_extra_attributes,
39            seaography,
40            impl_active_model_behavior,
41        } => {
42            if verbose {
43                let _ = tracing_subscriber::fmt()
44                    .with_max_level(tracing::Level::DEBUG)
45                    .with_test_writer()
46                    .try_init();
47            } else {
48                let filter_layer = EnvFilter::try_new("sea_orm_codegen=info").unwrap();
49                let fmt_layer = tracing_subscriber::fmt::layer()
50                    .with_target(false)
51                    .with_level(false)
52                    .without_time();
53
54                let _ = tracing_subscriber::registry()
55                    .with(filter_layer)
56                    .with(fmt_layer)
57                    .try_init();
58            }
59
60            // The database should be a valid URL that can be parsed
61            // protocol://username:password@host/database_name
62            let url = Url::parse(&database_url)?;
63
64            // Make sure we have all the required url components
65            //
66            // Missing scheme will have been caught by the Url::parse() call
67            // above
68            let is_sqlite = url.scheme() == "sqlite";
69
70            // Closures for filtering tables
71            let filter_tables =
72                |table: &String| -> bool { tables.is_empty() || tables.contains(table) };
73
74            let filter_hidden_tables = |table: &str| -> bool {
75                if include_hidden_tables {
76                    true
77                } else {
78                    !table.starts_with('_')
79                }
80            };
81
82            let filter_skip_tables = |table: &String| -> bool { !ignore_tables.contains(table) };
83
84            let _database_name = if !is_sqlite {
85                // The database name should be the first element of the path string
86                //
87                // Throwing an error if there is no database name since it might be
88                // accepted by the database without it, while we're looking to dump
89                // information from a particular database
90                let database_name = url
91                    .path_segments()
92                    .unwrap_or_else(|| {
93                        panic!(
94                            "There is no database name as part of the url path: {}",
95                            url.as_str()
96                        )
97                    })
98                    .next()
99                    .unwrap();
100
101                // An empty string as the database name is also an error
102                if database_name.is_empty() {
103                    panic!(
104                        "There is no database name as part of the url path: {}",
105                        url.as_str()
106                    );
107                }
108
109                database_name
110            } else {
111                Default::default()
112            };
113
114            let (schema_name, table_stmts) = match url.scheme() {
115                "mysql" => {
116                    #[cfg(not(feature = "sqlx-mysql"))]
117                    {
118                        panic!("mysql feature is off")
119                    }
120                    #[cfg(feature = "sqlx-mysql")]
121                    {
122                        use sea_schema::mysql::discovery::SchemaDiscovery;
123                        use sqlx::MySql;
124
125                        println!("Connecting to MySQL ...");
126                        let connection = sqlx_connect::<MySql>(
127                            max_connections,
128                            acquire_timeout,
129                            url.as_str(),
130                            None,
131                        )
132                        .await?;
133                        println!("Discovering schema ...");
134                        let schema_discovery = SchemaDiscovery::new(connection, _database_name);
135                        let schema = schema_discovery.discover().await?;
136                        let table_stmts = schema
137                            .tables
138                            .into_iter()
139                            .filter(|schema| filter_tables(&schema.info.name))
140                            .filter(|schema| filter_hidden_tables(&schema.info.name))
141                            .filter(|schema| filter_skip_tables(&schema.info.name))
142                            .map(|schema| schema.write())
143                            .collect();
144                        (None, table_stmts)
145                    }
146                }
147                "sqlite" => {
148                    #[cfg(not(feature = "sqlx-sqlite"))]
149                    {
150                        panic!("sqlite feature is off")
151                    }
152                    #[cfg(feature = "sqlx-sqlite")]
153                    {
154                        use sea_schema::sqlite::discovery::SchemaDiscovery;
155                        use sqlx::Sqlite;
156
157                        println!("Connecting to SQLite ...");
158                        let connection = sqlx_connect::<Sqlite>(
159                            max_connections,
160                            acquire_timeout,
161                            url.as_str(),
162                            None,
163                        )
164                        .await?;
165                        println!("Discovering schema ...");
166                        let schema_discovery = SchemaDiscovery::new(connection);
167                        let schema = schema_discovery
168                            .discover()
169                            .await?
170                            .merge_indexes_into_table();
171                        let table_stmts = schema
172                            .tables
173                            .into_iter()
174                            .filter(|schema| filter_tables(&schema.name))
175                            .filter(|schema| filter_hidden_tables(&schema.name))
176                            .filter(|schema| filter_skip_tables(&schema.name))
177                            .map(|schema| schema.write())
178                            .collect();
179                        (None, table_stmts)
180                    }
181                }
182                "postgres" | "postgresql" => {
183                    #[cfg(not(feature = "sqlx-postgres"))]
184                    {
185                        panic!("postgres feature is off")
186                    }
187                    #[cfg(feature = "sqlx-postgres")]
188                    {
189                        use sea_schema::postgres::discovery::SchemaDiscovery;
190                        use sqlx::Postgres;
191
192                        println!("Connecting to Postgres ...");
193                        let schema = database_schema.as_deref().unwrap_or("public");
194                        let connection = sqlx_connect::<Postgres>(
195                            max_connections,
196                            acquire_timeout,
197                            url.as_str(),
198                            Some(schema),
199                        )
200                        .await?;
201                        println!("Discovering schema ...");
202                        let schema_discovery = SchemaDiscovery::new(connection, schema);
203                        let schema = schema_discovery.discover().await?;
204                        let table_stmts = schema
205                            .tables
206                            .into_iter()
207                            .filter(|schema| filter_tables(&schema.info.name))
208                            .filter(|schema| filter_hidden_tables(&schema.info.name))
209                            .filter(|schema| filter_skip_tables(&schema.info.name))
210                            .map(|schema| schema.write())
211                            .collect();
212                        (database_schema, table_stmts)
213                    }
214                }
215                _ => unimplemented!("{} is not supported", url.scheme()),
216            };
217            println!("... discovered.");
218
219            let writer_context = EntityWriterContext::new(
220                expanded_format,
221                WithPrelude::from_str(&with_prelude).expect("Invalid prelude option"),
222                WithSerde::from_str(&with_serde).expect("Invalid serde derive option"),
223                with_copy_enums,
224                date_time_crate.into(),
225                schema_name,
226                lib,
227                serde_skip_deserializing_primary_key,
228                serde_skip_hidden_column,
229                model_extra_derives,
230                model_extra_attributes,
231                enum_extra_derives,
232                enum_extra_attributes,
233                seaography,
234                impl_active_model_behavior,
235            );
236            let output = EntityTransformer::transform(table_stmts)?.generate(&writer_context);
237
238            let dir = Path::new(&output_dir);
239            fs::create_dir_all(dir)?;
240
241            for OutputFile { name, content } in output.files.iter() {
242                let file_path = dir.join(name);
243                println!("Writing {}", file_path.display());
244                let mut file = fs::File::create(file_path)?;
245                file.write_all(content.as_bytes())?;
246            }
247
248            // Format each of the files
249            for OutputFile { name, .. } in output.files.iter() {
250                let exit_status = Command::new("rustfmt").arg(dir.join(name)).status()?; // Get the status code
251                if !exit_status.success() {
252                    // Propagate the error if any
253                    return Err(format!("Fail to format file `{name}`").into());
254                }
255            }
256
257            println!("... Done.");
258        }
259    }
260
261    Ok(())
262}
263
264async fn sqlx_connect<DB>(
265    max_connections: u32,
266    acquire_timeout: u64,
267    url: &str,
268    schema: Option<&str>,
269) -> Result<sqlx::Pool<DB>, Box<dyn Error>>
270where
271    DB: sqlx::Database,
272    for<'a> &'a mut <DB as sqlx::Database>::Connection: sqlx::Executor<'a>,
273{
274    let mut pool_options = sqlx::pool::PoolOptions::<DB>::new()
275        .max_connections(max_connections)
276        .acquire_timeout(time::Duration::from_secs(acquire_timeout));
277    // Set search_path for Postgres, E.g. Some("public") by default
278    // MySQL & SQLite connection initialize with schema `None`
279    if let Some(schema) = schema {
280        let sql = format!("SET search_path = '{schema}'");
281        pool_options = pool_options.after_connect(move |conn, _| {
282            let sql = sql.clone();
283            Box::pin(async move {
284                sqlx::Executor::execute(conn, sql.as_str())
285                    .await
286                    .map(|_| ())
287            })
288        });
289    }
290    pool_options.connect(url).await.map_err(Into::into)
291}
292
293impl From<DateTimeCrate> for CodegenDateTimeCrate {
294    fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate {
295        match date_time_crate {
296            DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono,
297            DateTimeCrate::Time => CodegenDateTimeCrate::Time,
298        }
299    }
300}
301
#[cfg(test)]
mod tests {
    use clap::Parser;

    use super::*;
    use crate::{Cli, Commands};

    /// Parses `sea-orm-cli generate entity --database-url <database_url>` and runs it to
    /// completion. The result is unwrapped, so any error surfaces as a panic that the
    /// `should_panic` expectations below match against.
    fn generate_entity_with_url(database_url: &str) {
        let cli = Cli::parse_from([
            "sea-orm-cli",
            "generate",
            "entity",
            "--database-url",
            database_url,
        ]);

        match cli.command {
            Commands::Generate { command } => {
                smol::block_on(run_generate_command(command, cli.verbose)).unwrap();
            }
            _ => unreachable!(),
        }
    }

    #[test]
    #[should_panic(
        expected = "called `Result::unwrap()` on an `Err` value: RelativeUrlWithoutBase"
    )]
    fn test_generate_entity_no_protocol() {
        generate_entity_with_url("://root:root@localhost:3306/database");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: postgresql://root:root@localhost:3306"
    )]
    fn test_generate_entity_no_database_section() {
        generate_entity_with_url("postgresql://root:root@localhost:3306");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: mysql://root:root@localhost:3306/"
    )]
    fn test_generate_entity_no_database_path() {
        generate_entity_with_url("mysql://root:root@localhost:3306/");
    }

    // NOTE(review): this test actually attempts a connection to localhost:3306, so its
    // outcome may depend on what is listening there - confirm in CI.
    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: PoolTimedOut")]
    fn test_generate_entity_no_password() {
        generate_entity_with_url("mysql://root:@localhost:3306/database");
    }

    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: EmptyHost")]
    fn test_generate_entity_no_host() {
        generate_entity_with_url("postgres://root:root@/database");
    }
}