sea_orm_cli/commands/generate.rs

1use core::time;
2use sea_orm_codegen::{
3    DateTimeCrate as CodegenDateTimeCrate, EntityTransformer, EntityWriterContext, OutputFile,
4    WithPrelude, WithSerde,
5};
6use std::{error::Error, fs, io::Write, path::Path, process::Command, str::FromStr};
7use tracing_subscriber::{prelude::*, EnvFilter};
8use url::Url;
9
10use crate::{DateTimeCrate, GenerateSubcommands};
11
12pub async fn run_generate_command(
13    command: GenerateSubcommands,
14    verbose: bool,
15) -> Result<(), Box<dyn Error>> {
16    match command {
17        GenerateSubcommands::Entity {
18            compact_format: _,
19            expanded_format,
20            frontend_format,
21            include_hidden_tables,
22            tables,
23            ignore_tables,
24            max_connections,
25            acquire_timeout,
26            output_dir,
27            database_schema,
28            database_url,
29            with_prelude,
30            with_serde,
31            serde_skip_deserializing_primary_key,
32            serde_skip_hidden_column,
33            with_copy_enums,
34            date_time_crate,
35            lib,
36            model_extra_derives,
37            model_extra_attributes,
38            enum_extra_derives,
39            enum_extra_attributes,
40            seaography,
41            impl_active_model_behavior,
42        } => {
43            if verbose {
44                let _ = tracing_subscriber::fmt()
45                    .with_max_level(tracing::Level::DEBUG)
46                    .with_test_writer()
47                    .try_init();
48            } else {
49                let filter_layer = EnvFilter::try_new("sea_orm_codegen=info").unwrap();
50                let fmt_layer = tracing_subscriber::fmt::layer()
51                    .with_target(false)
52                    .with_level(false)
53                    .without_time();
54
55                let _ = tracing_subscriber::registry()
56                    .with(filter_layer)
57                    .with(fmt_layer)
58                    .try_init();
59            }
60
61            // The database should be a valid URL that can be parsed
62            // protocol://username:password@host/database_name
63            let url = Url::parse(&database_url)?;
64
65            // Make sure we have all the required url components
66            //
67            // Missing scheme will have been caught by the Url::parse() call
68            // above
69            let is_sqlite = url.scheme() == "sqlite";
70
71            // Closures for filtering tables
72            let filter_tables =
73                |table: &String| -> bool { tables.is_empty() || tables.contains(table) };
74
75            let filter_hidden_tables = |table: &str| -> bool {
76                if include_hidden_tables {
77                    true
78                } else {
79                    !table.starts_with('_')
80                }
81            };
82
83            let filter_skip_tables = |table: &String| -> bool { !ignore_tables.contains(table) };
84
85            let _database_name = if !is_sqlite {
86                // The database name should be the first element of the path string
87                //
88                // Throwing an error if there is no database name since it might be
89                // accepted by the database without it, while we're looking to dump
90                // information from a particular database
91                let database_name = url
92                    .path_segments()
93                    .unwrap_or_else(|| {
94                        panic!(
95                            "There is no database name as part of the url path: {}",
96                            url.as_str()
97                        )
98                    })
99                    .next()
100                    .unwrap();
101
102                // An empty string as the database name is also an error
103                if database_name.is_empty() {
104                    panic!(
105                        "There is no database name as part of the url path: {}",
106                        url.as_str()
107                    );
108                }
109
110                database_name
111            } else {
112                Default::default()
113            };
114
115            let (schema_name, table_stmts) = match url.scheme() {
116                "mysql" => {
117                    #[cfg(not(feature = "sqlx-mysql"))]
118                    {
119                        panic!("mysql feature is off")
120                    }
121                    #[cfg(feature = "sqlx-mysql")]
122                    {
123                        use sea_schema::mysql::discovery::SchemaDiscovery;
124                        use sqlx::MySql;
125
126                        println!("Connecting to MySQL ...");
127                        let connection = sqlx_connect::<MySql>(
128                            max_connections,
129                            acquire_timeout,
130                            url.as_str(),
131                            None,
132                        )
133                        .await?;
134                        println!("Discovering schema ...");
135                        let schema_discovery = SchemaDiscovery::new(connection, _database_name);
136                        let schema = schema_discovery.discover().await?;
137                        let table_stmts = schema
138                            .tables
139                            .into_iter()
140                            .filter(|schema| filter_tables(&schema.info.name))
141                            .filter(|schema| filter_hidden_tables(&schema.info.name))
142                            .filter(|schema| filter_skip_tables(&schema.info.name))
143                            .map(|schema| schema.write())
144                            .collect();
145                        (None, table_stmts)
146                    }
147                }
148                "sqlite" => {
149                    #[cfg(not(feature = "sqlx-sqlite"))]
150                    {
151                        panic!("sqlite feature is off")
152                    }
153                    #[cfg(feature = "sqlx-sqlite")]
154                    {
155                        use sea_schema::sqlite::discovery::SchemaDiscovery;
156                        use sqlx::Sqlite;
157
158                        println!("Connecting to SQLite ...");
159                        let connection = sqlx_connect::<Sqlite>(
160                            max_connections,
161                            acquire_timeout,
162                            url.as_str(),
163                            None,
164                        )
165                        .await?;
166                        println!("Discovering schema ...");
167                        let schema_discovery = SchemaDiscovery::new(connection);
168                        let schema = schema_discovery
169                            .discover()
170                            .await?
171                            .merge_indexes_into_table();
172                        let table_stmts = schema
173                            .tables
174                            .into_iter()
175                            .filter(|schema| filter_tables(&schema.name))
176                            .filter(|schema| filter_hidden_tables(&schema.name))
177                            .filter(|schema| filter_skip_tables(&schema.name))
178                            .map(|schema| schema.write())
179                            .collect();
180                        (None, table_stmts)
181                    }
182                }
183                "postgres" | "postgresql" => {
184                    #[cfg(not(feature = "sqlx-postgres"))]
185                    {
186                        panic!("postgres feature is off")
187                    }
188                    #[cfg(feature = "sqlx-postgres")]
189                    {
190                        use sea_schema::postgres::discovery::SchemaDiscovery;
191                        use sqlx::Postgres;
192
193                        println!("Connecting to Postgres ...");
194                        let schema = database_schema.as_deref().unwrap_or("public");
195                        let connection = sqlx_connect::<Postgres>(
196                            max_connections,
197                            acquire_timeout,
198                            url.as_str(),
199                            Some(schema),
200                        )
201                        .await?;
202                        println!("Discovering schema ...");
203                        let schema_discovery = SchemaDiscovery::new(connection, schema);
204                        let schema = schema_discovery.discover().await?;
205                        let table_stmts = schema
206                            .tables
207                            .into_iter()
208                            .filter(|schema| filter_tables(&schema.info.name))
209                            .filter(|schema| filter_hidden_tables(&schema.info.name))
210                            .filter(|schema| filter_skip_tables(&schema.info.name))
211                            .map(|schema| schema.write())
212                            .collect();
213                        (database_schema, table_stmts)
214                    }
215                }
216                _ => unimplemented!("{} is not supported", url.scheme()),
217            };
218            println!("... discovered.");
219
220            let writer_context = EntityWriterContext::new(
221                expanded_format,
222                frontend_format,
223                WithPrelude::from_str(&with_prelude).expect("Invalid prelude option"),
224                WithSerde::from_str(&with_serde).expect("Invalid serde derive option"),
225                with_copy_enums,
226                date_time_crate.into(),
227                schema_name,
228                lib,
229                serde_skip_deserializing_primary_key,
230                serde_skip_hidden_column,
231                model_extra_derives,
232                model_extra_attributes,
233                enum_extra_derives,
234                enum_extra_attributes,
235                seaography,
236                impl_active_model_behavior,
237            );
238            let output = EntityTransformer::transform(table_stmts)?.generate(&writer_context);
239
240            let dir = Path::new(&output_dir);
241            fs::create_dir_all(dir)?;
242
243            for OutputFile { name, content } in output.files.iter() {
244                let file_path = dir.join(name);
245                println!("Writing {}", file_path.display());
246                let mut file = fs::File::create(file_path)?;
247                file.write_all(content.as_bytes())?;
248            }
249
250            // Format each of the files
251            for OutputFile { name, .. } in output.files.iter() {
252                let exit_status = Command::new("rustfmt").arg(dir.join(name)).status()?; // Get the status code
253                if !exit_status.success() {
254                    // Propagate the error if any
255                    return Err(format!("Fail to format file `{name}`").into());
256                }
257            }
258
259            println!("... Done.");
260        }
261    }
262
263    Ok(())
264}
265
266async fn sqlx_connect<DB>(
267    max_connections: u32,
268    acquire_timeout: u64,
269    url: &str,
270    schema: Option<&str>,
271) -> Result<sqlx::Pool<DB>, Box<dyn Error>>
272where
273    DB: sqlx::Database,
274    for<'a> &'a mut <DB as sqlx::Database>::Connection: sqlx::Executor<'a>,
275{
276    let mut pool_options = sqlx::pool::PoolOptions::<DB>::new()
277        .max_connections(max_connections)
278        .acquire_timeout(time::Duration::from_secs(acquire_timeout));
279    // Set search_path for Postgres, E.g. Some("public") by default
280    // MySQL & SQLite connection initialize with schema `None`
281    if let Some(schema) = schema {
282        let sql = format!("SET search_path = '{schema}'");
283        pool_options = pool_options.after_connect(move |conn, _| {
284            let sql = sql.clone();
285            Box::pin(async move {
286                sqlx::Executor::execute(conn, sql.as_str())
287                    .await
288                    .map(|_| ())
289            })
290        });
291    }
292    pool_options.connect(url).await.map_err(Into::into)
293}
294
295impl From<DateTimeCrate> for CodegenDateTimeCrate {
296    fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate {
297        match date_time_crate {
298            DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono,
299            DateTimeCrate::Time => CodegenDateTimeCrate::Time,
300        }
301    }
302}
303
#[cfg(test)]
mod tests {
    use clap::Parser;

    use super::*;
    use crate::{Cli, Commands};

    /// Builds `sea-orm-cli generate entity --database-url <database_url>`.
    /// Then it runs the command to completion.
    ///
    /// The result is unwrapped, so any returned error surfaces as a test panic.
    fn generate_entity_from_url(database_url: &str) {
        let cli = Cli::parse_from([
            "sea-orm-cli",
            "generate",
            "entity",
            "--database-url",
            database_url,
        ]);

        match cli.command {
            Commands::Generate { command } => {
                smol::block_on(run_generate_command(command, cli.verbose)).unwrap();
            }
            _ => unreachable!(),
        }
    }

    #[test]
    #[should_panic(
        expected = "called `Result::unwrap()` on an `Err` value: RelativeUrlWithoutBase"
    )]
    fn test_generate_entity_no_protocol() {
        generate_entity_from_url("://root:root@localhost:3306/database");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: postgresql://root:root@localhost:3306"
    )]
    fn test_generate_entity_no_database_section() {
        generate_entity_from_url("postgresql://root:root@localhost:3306");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: mysql://root:root@localhost:3306/"
    )]
    fn test_generate_entity_no_database_path() {
        generate_entity_from_url("mysql://root:root@localhost:3306/");
    }

    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: PoolTimedOut")]
    fn test_generate_entity_no_password() {
        generate_entity_from_url("mysql://root:@localhost:3306/database");
    }

    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: EmptyHost")]
    fn test_generate_entity_no_host() {
        generate_entity_from_url("postgres://root:root@/database");
    }
}