sea_orm_cli/commands/
generate.rs

1use sea_orm_codegen::{
2    DateTimeCrate as CodegenDateTimeCrate, EntityTransformer, EntityWriterContext, OutputFile,
3    WithSerde,
4};
5use std::{error::Error, fs, io::Write, path::Path, process::Command, str::FromStr};
6use tracing_subscriber::{prelude::*, EnvFilter};
7use url::Url;
8
9use crate::{DateTimeCrate, GenerateSubcommands};
10
11pub async fn run_generate_command(
12    command: GenerateSubcommands,
13    verbose: bool,
14) -> Result<(), Box<dyn Error>> {
15    match command {
16        GenerateSubcommands::Entity {
17            compact_format: _,
18            expanded_format,
19            include_hidden_tables,
20            tables,
21            ignore_tables,
22            max_connections,
23            output_dir,
24            database_schema,
25            database_url,
26            with_serde,
27            serde_skip_deserializing_primary_key,
28            serde_skip_hidden_column,
29            with_copy_enums,
30            date_time_crate,
31            lib,
32            model_extra_derives,
33            model_extra_attributes,
34            enum_extra_derives,
35            enum_extra_attributes,
36            seaography,
37        } => {
38            if verbose {
39                let _ = tracing_subscriber::fmt()
40                    .with_max_level(tracing::Level::DEBUG)
41                    .with_test_writer()
42                    .try_init();
43            } else {
44                let filter_layer = EnvFilter::try_new("sea_orm_codegen=info").unwrap();
45                let fmt_layer = tracing_subscriber::fmt::layer()
46                    .with_target(false)
47                    .with_level(false)
48                    .without_time();
49
50                let _ = tracing_subscriber::registry()
51                    .with(filter_layer)
52                    .with(fmt_layer)
53                    .try_init();
54            }
55
56            // The database should be a valid URL that can be parsed
57            // protocol://username:password@host/database_name
58            let url = Url::parse(&database_url)?;
59
60            // Make sure we have all the required url components
61            //
62            // Missing scheme will have been caught by the Url::parse() call
63            // above
64            let is_sqlite = url.scheme() == "sqlite";
65
66            // Closures for filtering tables
67            let filter_tables =
68                |table: &String| -> bool { tables.is_empty() || tables.contains(table) };
69
70            let filter_hidden_tables = |table: &str| -> bool {
71                if include_hidden_tables {
72                    true
73                } else {
74                    !table.starts_with('_')
75                }
76            };
77
78            let filter_skip_tables = |table: &String| -> bool { !ignore_tables.contains(table) };
79
80            let database_name = if !is_sqlite {
81                // The database name should be the first element of the path string
82                //
83                // Throwing an error if there is no database name since it might be
84                // accepted by the database without it, while we're looking to dump
85                // information from a particular database
86                let database_name = url
87                    .path_segments()
88                    .unwrap_or_else(|| {
89                        panic!(
90                            "There is no database name as part of the url path: {}",
91                            url.as_str()
92                        )
93                    })
94                    .next()
95                    .unwrap();
96
97                // An empty string as the database name is also an error
98                if database_name.is_empty() {
99                    panic!(
100                        "There is no database name as part of the url path: {}",
101                        url.as_str()
102                    );
103                }
104
105                database_name
106            } else {
107                Default::default()
108            };
109
110            let (schema_name, table_stmts) = match url.scheme() {
111                "mysql" => {
112                    use sea_schema::mysql::discovery::SchemaDiscovery;
113                    use sqlx::MySql;
114
115                    println!("Connecting to MySQL ...");
116                    let connection =
117                        sqlx_connect::<MySql>(max_connections, url.as_str(), None).await?;
118                    println!("Discovering schema ...");
119                    let schema_discovery = SchemaDiscovery::new(connection, database_name);
120                    let schema = schema_discovery.discover().await?;
121                    let table_stmts = schema
122                        .tables
123                        .into_iter()
124                        .filter(|schema| filter_tables(&schema.info.name))
125                        .filter(|schema| filter_hidden_tables(&schema.info.name))
126                        .filter(|schema| filter_skip_tables(&schema.info.name))
127                        .map(|schema| schema.write())
128                        .collect();
129                    (None, table_stmts)
130                }
131                "sqlite" => {
132                    use sea_schema::sqlite::discovery::SchemaDiscovery;
133                    use sqlx::Sqlite;
134
135                    println!("Connecting to SQLite ...");
136                    let connection =
137                        sqlx_connect::<Sqlite>(max_connections, url.as_str(), None).await?;
138                    println!("Discovering schema ...");
139                    let schema_discovery = SchemaDiscovery::new(connection);
140                    let schema = schema_discovery
141                        .discover()
142                        .await?
143                        .merge_indexes_into_table();
144                    let table_stmts = schema
145                        .tables
146                        .into_iter()
147                        .filter(|schema| filter_tables(&schema.name))
148                        .filter(|schema| filter_hidden_tables(&schema.name))
149                        .filter(|schema| filter_skip_tables(&schema.name))
150                        .map(|schema| schema.write())
151                        .collect();
152                    (None, table_stmts)
153                }
154                "postgres" | "postgresql" => {
155                    use sea_schema::postgres::discovery::SchemaDiscovery;
156                    use sqlx::Postgres;
157
158                    println!("Connecting to Postgres ...");
159                    let schema = database_schema.as_deref().unwrap_or("public");
160                    let connection =
161                        sqlx_connect::<Postgres>(max_connections, url.as_str(), Some(schema))
162                            .await?;
163                    println!("Discovering schema ...");
164                    let schema_discovery = SchemaDiscovery::new(connection, schema);
165                    let schema = schema_discovery.discover().await?;
166                    let table_stmts = schema
167                        .tables
168                        .into_iter()
169                        .filter(|schema| filter_tables(&schema.info.name))
170                        .filter(|schema| filter_hidden_tables(&schema.info.name))
171                        .filter(|schema| filter_skip_tables(&schema.info.name))
172                        .map(|schema| schema.write())
173                        .collect();
174                    (database_schema, table_stmts)
175                }
176                _ => unimplemented!("{} is not supported", url.scheme()),
177            };
178            println!("... discovered.");
179
180            let writer_context = EntityWriterContext::new(
181                expanded_format,
182                WithSerde::from_str(&with_serde).expect("Invalid serde derive option"),
183                with_copy_enums,
184                date_time_crate.into(),
185                schema_name,
186                lib,
187                serde_skip_deserializing_primary_key,
188                serde_skip_hidden_column,
189                model_extra_derives,
190                model_extra_attributes,
191                enum_extra_derives,
192                enum_extra_attributes,
193                seaography,
194            );
195            let output = EntityTransformer::transform(table_stmts)?.generate(&writer_context);
196
197            let dir = Path::new(&output_dir);
198            fs::create_dir_all(dir)?;
199
200            for OutputFile { name, content } in output.files.iter() {
201                let file_path = dir.join(name);
202                println!("Writing {}", file_path.display());
203                let mut file = fs::File::create(file_path)?;
204                file.write_all(content.as_bytes())?;
205            }
206
207            // Format each of the files
208            for OutputFile { name, .. } in output.files.iter() {
209                let exit_status = Command::new("rustfmt").arg(dir.join(name)).status()?; // Get the status code
210                if !exit_status.success() {
211                    // Propagate the error if any
212                    return Err(format!("Fail to format file `{name}`").into());
213                }
214            }
215
216            println!("... Done.");
217        }
218    }
219
220    Ok(())
221}
222
223async fn sqlx_connect<DB>(
224    max_connections: u32,
225    url: &str,
226    schema: Option<&str>,
227) -> Result<sqlx::Pool<DB>, Box<dyn Error>>
228where
229    DB: sqlx::Database,
230    for<'a> &'a mut <DB as sqlx::Database>::Connection: sqlx::Executor<'a>,
231{
232    let mut pool_options = sqlx::pool::PoolOptions::<DB>::new().max_connections(max_connections);
233    // Set search_path for Postgres, E.g. Some("public") by default
234    // MySQL & SQLite connection initialize with schema `None`
235    if let Some(schema) = schema {
236        let sql = format!("SET search_path = '{schema}'");
237        pool_options = pool_options.after_connect(move |conn, _| {
238            let sql = sql.clone();
239            Box::pin(async move {
240                sqlx::Executor::execute(conn, sql.as_str())
241                    .await
242                    .map(|_| ())
243            })
244        });
245    }
246    pool_options.connect(url).await.map_err(Into::into)
247}
248
249impl From<DateTimeCrate> for CodegenDateTimeCrate {
250    fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate {
251        match date_time_crate {
252            DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono,
253            DateTimeCrate::Time => CodegenDateTimeCrate::Time,
254        }
255    }
256}
257
#[cfg(test)]
mod tests {
    use clap::Parser;

    use super::*;
    use crate::{Cli, Commands};

    /// Parses `sea-orm-cli generate entity --database-url <database_url>` and
    /// runs the resulting command to completion, unwrapping its result so that
    /// any returned error surfaces as a panic.
    fn generate_entity_from(database_url: &str) {
        let cli = Cli::parse_from([
            "sea-orm-cli",
            "generate",
            "entity",
            "--database-url",
            database_url,
        ]);

        match cli.command {
            Commands::Generate { command } => {
                smol::block_on(run_generate_command(command, cli.verbose)).unwrap();
            }
            _ => unreachable!(),
        }
    }

    /// A URL without a scheme is rejected by URL parsing.
    #[test]
    #[should_panic(
        expected = "called `Result::unwrap()` on an `Err` value: RelativeUrlWithoutBase"
    )]
    fn test_generate_entity_no_protocol() {
        generate_entity_from("://root:root@localhost:3306/database");
    }

    /// A URL with no path at all carries no database name.
    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: postgresql://root:root@localhost:3306"
    )]
    fn test_generate_entity_no_database_section() {
        generate_entity_from("postgresql://root:root@localhost:3306");
    }

    /// A URL whose path is just `/` has an empty database name.
    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: mysql://root:root@localhost:3306/"
    )]
    fn test_generate_entity_no_database_path() {
        generate_entity_from("mysql://root:root@localhost:3306/");
    }

    /// An empty password passes URL validation; the failure is expected to
    /// come from the connection attempt instead.
    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: PoolTimedOut")]
    fn test_generate_entity_no_password() {
        generate_entity_from("mysql://root:@localhost:3306/database");
    }

    /// A URL with credentials but no host is rejected by URL parsing.
    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: EmptyHost")]
    fn test_generate_entity_no_host() {
        generate_entity_from("postgres://root:root@/database");
    }
}