sea_orm_cli/commands/generate.rs

1use core::time;
2use sea_orm_codegen::{
3    DateTimeCrate as CodegenDateTimeCrate, EntityTransformer, EntityWriterContext, OutputFile,
4    WithPrelude, WithSerde,
5};
6use std::{error::Error, fs, io::Write, path::Path, process::Command, str::FromStr};
7use tracing_subscriber::{prelude::*, EnvFilter};
8use url::Url;
9
10use crate::{DateTimeCrate, GenerateSubcommands};
11
12pub async fn run_generate_command(
13    command: GenerateSubcommands,
14    verbose: bool,
15) -> Result<(), Box<dyn Error>> {
16    match command {
17        GenerateSubcommands::Entity {
18            compact_format: _,
19            expanded_format,
20            include_hidden_tables,
21            tables,
22            ignore_tables,
23            max_connections,
24            acquire_timeout,
25            output_dir,
26            database_schema,
27            database_url,
28            with_prelude,
29            with_serde,
30            serde_skip_deserializing_primary_key,
31            serde_skip_hidden_column,
32            with_copy_enums,
33            date_time_crate,
34            lib,
35            model_extra_derives,
36            model_extra_attributes,
37            enum_extra_derives,
38            enum_extra_attributes,
39            seaography,
40            impl_active_model_behavior,
41        } => {
42            if verbose {
43                let _ = tracing_subscriber::fmt()
44                    .with_max_level(tracing::Level::DEBUG)
45                    .with_test_writer()
46                    .try_init();
47            } else {
48                let filter_layer = EnvFilter::try_new("sea_orm_codegen=info").unwrap();
49                let fmt_layer = tracing_subscriber::fmt::layer()
50                    .with_target(false)
51                    .with_level(false)
52                    .without_time();
53
54                let _ = tracing_subscriber::registry()
55                    .with(filter_layer)
56                    .with(fmt_layer)
57                    .try_init();
58            }
59
60            // The database should be a valid URL that can be parsed
61            // protocol://username:password@host/database_name
62            let url = Url::parse(&database_url)?;
63
64            // Make sure we have all the required url components
65            //
66            // Missing scheme will have been caught by the Url::parse() call
67            // above
68            let is_sqlite = url.scheme() == "sqlite";
69
70            // Closures for filtering tables
71            let filter_tables =
72                |table: &String| -> bool { tables.is_empty() || tables.contains(table) };
73
74            let filter_hidden_tables = |table: &str| -> bool {
75                if include_hidden_tables {
76                    true
77                } else {
78                    !table.starts_with('_')
79                }
80            };
81
82            let filter_skip_tables = |table: &String| -> bool { !ignore_tables.contains(table) };
83
84            let database_name = if !is_sqlite {
85                // The database name should be the first element of the path string
86                //
87                // Throwing an error if there is no database name since it might be
88                // accepted by the database without it, while we're looking to dump
89                // information from a particular database
90                let database_name = url
91                    .path_segments()
92                    .unwrap_or_else(|| {
93                        panic!(
94                            "There is no database name as part of the url path: {}",
95                            url.as_str()
96                        )
97                    })
98                    .next()
99                    .unwrap();
100
101                // An empty string as the database name is also an error
102                if database_name.is_empty() {
103                    panic!(
104                        "There is no database name as part of the url path: {}",
105                        url.as_str()
106                    );
107                }
108
109                database_name
110            } else {
111                Default::default()
112            };
113
114            let (schema_name, table_stmts) = match url.scheme() {
115                "mysql" => {
116                    use sea_schema::mysql::discovery::SchemaDiscovery;
117                    use sqlx::MySql;
118
119                    println!("Connecting to MySQL ...");
120                    let connection =
121                        sqlx_connect::<MySql>(max_connections, acquire_timeout, url.as_str(), None)
122                            .await?;
123
124                    println!("Discovering schema ...");
125                    let schema_discovery = SchemaDiscovery::new(connection, database_name);
126                    let schema = schema_discovery.discover().await?;
127                    let table_stmts = schema
128                        .tables
129                        .into_iter()
130                        .filter(|schema| filter_tables(&schema.info.name))
131                        .filter(|schema| filter_hidden_tables(&schema.info.name))
132                        .filter(|schema| filter_skip_tables(&schema.info.name))
133                        .map(|schema| schema.write())
134                        .collect();
135                    (None, table_stmts)
136                }
137                "sqlite" => {
138                    use sea_schema::sqlite::discovery::SchemaDiscovery;
139                    use sqlx::Sqlite;
140
141                    println!("Connecting to SQLite ...");
142                    let connection = sqlx_connect::<Sqlite>(
143                        max_connections,
144                        acquire_timeout,
145                        url.as_str(),
146                        None,
147                    )
148                    .await?;
149
150                    println!("Discovering schema ...");
151                    let schema_discovery = SchemaDiscovery::new(connection);
152                    let schema = schema_discovery
153                        .discover()
154                        .await?
155                        .merge_indexes_into_table();
156                    let table_stmts = schema
157                        .tables
158                        .into_iter()
159                        .filter(|schema| filter_tables(&schema.name))
160                        .filter(|schema| filter_hidden_tables(&schema.name))
161                        .filter(|schema| filter_skip_tables(&schema.name))
162                        .map(|schema| schema.write())
163                        .collect();
164                    (None, table_stmts)
165                }
166                "postgres" | "postgresql" => {
167                    use sea_schema::postgres::discovery::SchemaDiscovery;
168                    use sqlx::Postgres;
169
170                    println!("Connecting to Postgres ...");
171                    let schema = database_schema.as_deref().unwrap_or("public");
172                    let connection = sqlx_connect::<Postgres>(
173                        max_connections,
174                        acquire_timeout,
175                        url.as_str(),
176                        Some(schema),
177                    )
178                    .await?;
179                    println!("Discovering schema ...");
180                    let schema_discovery = SchemaDiscovery::new(connection, schema);
181                    let schema = schema_discovery.discover().await?;
182                    let table_stmts = schema
183                        .tables
184                        .into_iter()
185                        .filter(|schema| filter_tables(&schema.info.name))
186                        .filter(|schema| filter_hidden_tables(&schema.info.name))
187                        .filter(|schema| filter_skip_tables(&schema.info.name))
188                        .map(|schema| schema.write())
189                        .collect();
190                    (database_schema, table_stmts)
191                }
192                _ => unimplemented!("{} is not supported", url.scheme()),
193            };
194            println!("... discovered.");
195
196            let writer_context = EntityWriterContext::new(
197                expanded_format,
198                WithPrelude::from_str(&with_prelude).expect("Invalid prelude option"),
199                WithSerde::from_str(&with_serde).expect("Invalid serde derive option"),
200                with_copy_enums,
201                date_time_crate.into(),
202                schema_name,
203                lib,
204                serde_skip_deserializing_primary_key,
205                serde_skip_hidden_column,
206                model_extra_derives,
207                model_extra_attributes,
208                enum_extra_derives,
209                enum_extra_attributes,
210                seaography,
211                impl_active_model_behavior,
212            );
213            let output = EntityTransformer::transform(table_stmts)?.generate(&writer_context);
214
215            let dir = Path::new(&output_dir);
216            fs::create_dir_all(dir)?;
217
218            for OutputFile { name, content } in output.files.iter() {
219                let file_path = dir.join(name);
220                println!("Writing {}", file_path.display());
221                let mut file = fs::File::create(file_path)?;
222                file.write_all(content.as_bytes())?;
223            }
224
225            // Format each of the files
226            for OutputFile { name, .. } in output.files.iter() {
227                let exit_status = Command::new("rustfmt").arg(dir.join(name)).status()?; // Get the status code
228                if !exit_status.success() {
229                    // Propagate the error if any
230                    return Err(format!("Fail to format file `{name}`").into());
231                }
232            }
233
234            println!("... Done.");
235        }
236    }
237
238    Ok(())
239}
240
241async fn sqlx_connect<DB>(
242    max_connections: u32,
243    acquire_timeout: u64,
244    url: &str,
245    schema: Option<&str>,
246) -> Result<sqlx::Pool<DB>, Box<dyn Error>>
247where
248    DB: sqlx::Database,
249    for<'a> &'a mut <DB as sqlx::Database>::Connection: sqlx::Executor<'a>,
250{
251    let mut pool_options = sqlx::pool::PoolOptions::<DB>::new()
252        .max_connections(max_connections)
253        .acquire_timeout(time::Duration::from_secs(acquire_timeout));
254    // Set search_path for Postgres, E.g. Some("public") by default
255    // MySQL & SQLite connection initialize with schema `None`
256    if let Some(schema) = schema {
257        let sql = format!("SET search_path = '{schema}'");
258        pool_options = pool_options.after_connect(move |conn, _| {
259            let sql = sql.clone();
260            Box::pin(async move {
261                sqlx::Executor::execute(conn, sql.as_str())
262                    .await
263                    .map(|_| ())
264            })
265        });
266    }
267    pool_options.connect(url).await.map_err(Into::into)
268}
269
270impl From<DateTimeCrate> for CodegenDateTimeCrate {
271    fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate {
272        match date_time_crate {
273            DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono,
274            DateTimeCrate::Time => CodegenDateTimeCrate::Time,
275        }
276    }
277}
278
#[cfg(test)]
mod tests {
    use clap::Parser;

    use super::*;
    use crate::{Cli, Commands};

    /// Parses `sea-orm-cli generate entity --database-url <database_url>`, runs the
    /// resulting command to completion and unwraps its result, so that any error
    /// surfaces as a panic the `#[should_panic]` attributes below can match against.
    fn generate_entity_from_url(database_url: &str) {
        let cli = Cli::parse_from([
            "sea-orm-cli",
            "generate",
            "entity",
            "--database-url",
            database_url,
        ]);

        match cli.command {
            Commands::Generate { command } => {
                smol::block_on(run_generate_command(command, cli.verbose)).unwrap();
            }
            _ => unreachable!(),
        }
    }

    #[test]
    #[should_panic(
        expected = "called `Result::unwrap()` on an `Err` value: RelativeUrlWithoutBase"
    )]
    fn test_generate_entity_no_protocol() {
        generate_entity_from_url("://root:root@localhost:3306/database");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: postgresql://root:root@localhost:3306"
    )]
    fn test_generate_entity_no_database_section() {
        generate_entity_from_url("postgresql://root:root@localhost:3306");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: mysql://root:root@localhost:3306/"
    )]
    fn test_generate_entity_no_database_path() {
        generate_entity_from_url("mysql://root:root@localhost:3306/");
    }

    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: PoolTimedOut")]
    fn test_generate_entity_no_password() {
        generate_entity_from_url("mysql://root:@localhost:3306/database");
    }

    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: EmptyHost")]
    fn test_generate_entity_no_host() {
        generate_entity_from_url("postgres://root:root@/database");
    }
}