1use core::time;
2use sea_orm_codegen::{
3 DateTimeCrate as CodegenDateTimeCrate, EntityTransformer, EntityWriterContext, OutputFile,
4 WithPrelude, WithSerde,
5};
6use std::{error::Error, fs, io::Write, path::Path, process::Command, str::FromStr};
7use tracing_subscriber::{prelude::*, EnvFilter};
8use url::Url;
9
10use crate::{DateTimeCrate, GenerateSubcommands};
11
12pub async fn run_generate_command(
13 command: GenerateSubcommands,
14 verbose: bool,
15) -> Result<(), Box<dyn Error>> {
16 match command {
17 GenerateSubcommands::Entity {
18 compact_format: _,
19 expanded_format,
20 include_hidden_tables,
21 tables,
22 ignore_tables,
23 max_connections,
24 acquire_timeout,
25 output_dir,
26 database_schema,
27 database_url,
28 with_prelude,
29 with_serde,
30 serde_skip_deserializing_primary_key,
31 serde_skip_hidden_column,
32 with_copy_enums,
33 date_time_crate,
34 lib,
35 model_extra_derives,
36 model_extra_attributes,
37 enum_extra_derives,
38 enum_extra_attributes,
39 seaography,
40 impl_active_model_behavior,
41 } => {
42 if verbose {
43 let _ = tracing_subscriber::fmt()
44 .with_max_level(tracing::Level::DEBUG)
45 .with_test_writer()
46 .try_init();
47 } else {
48 let filter_layer = EnvFilter::try_new("sea_orm_codegen=info").unwrap();
49 let fmt_layer = tracing_subscriber::fmt::layer()
50 .with_target(false)
51 .with_level(false)
52 .without_time();
53
54 let _ = tracing_subscriber::registry()
55 .with(filter_layer)
56 .with(fmt_layer)
57 .try_init();
58 }
59
60 let url = Url::parse(&database_url)?;
63
64 let is_sqlite = url.scheme() == "sqlite";
69
70 let filter_tables =
72 |table: &String| -> bool { tables.is_empty() || tables.contains(table) };
73
74 let filter_hidden_tables = |table: &str| -> bool {
75 if include_hidden_tables {
76 true
77 } else {
78 !table.starts_with('_')
79 }
80 };
81
82 let filter_skip_tables = |table: &String| -> bool { !ignore_tables.contains(table) };
83
84 let database_name = if !is_sqlite {
85 let database_name = url
91 .path_segments()
92 .unwrap_or_else(|| {
93 panic!(
94 "There is no database name as part of the url path: {}",
95 url.as_str()
96 )
97 })
98 .next()
99 .unwrap();
100
101 if database_name.is_empty() {
103 panic!(
104 "There is no database name as part of the url path: {}",
105 url.as_str()
106 );
107 }
108
109 database_name
110 } else {
111 Default::default()
112 };
113
114 let (schema_name, table_stmts) = match url.scheme() {
115 "mysql" => {
116 use sea_schema::mysql::discovery::SchemaDiscovery;
117 use sqlx::MySql;
118
119 println!("Connecting to MySQL ...");
120 let connection =
121 sqlx_connect::<MySql>(max_connections, acquire_timeout, url.as_str(), None)
122 .await?;
123
124 println!("Discovering schema ...");
125 let schema_discovery = SchemaDiscovery::new(connection, database_name);
126 let schema = schema_discovery.discover().await?;
127 let table_stmts = schema
128 .tables
129 .into_iter()
130 .filter(|schema| filter_tables(&schema.info.name))
131 .filter(|schema| filter_hidden_tables(&schema.info.name))
132 .filter(|schema| filter_skip_tables(&schema.info.name))
133 .map(|schema| schema.write())
134 .collect();
135 (None, table_stmts)
136 }
137 "sqlite" => {
138 use sea_schema::sqlite::discovery::SchemaDiscovery;
139 use sqlx::Sqlite;
140
141 println!("Connecting to SQLite ...");
142 let connection = sqlx_connect::<Sqlite>(
143 max_connections,
144 acquire_timeout,
145 url.as_str(),
146 None,
147 )
148 .await?;
149
150 println!("Discovering schema ...");
151 let schema_discovery = SchemaDiscovery::new(connection);
152 let schema = schema_discovery
153 .discover()
154 .await?
155 .merge_indexes_into_table();
156 let table_stmts = schema
157 .tables
158 .into_iter()
159 .filter(|schema| filter_tables(&schema.name))
160 .filter(|schema| filter_hidden_tables(&schema.name))
161 .filter(|schema| filter_skip_tables(&schema.name))
162 .map(|schema| schema.write())
163 .collect();
164 (None, table_stmts)
165 }
166 "postgres" | "postgresql" => {
167 use sea_schema::postgres::discovery::SchemaDiscovery;
168 use sqlx::Postgres;
169
170 println!("Connecting to Postgres ...");
171 let schema = database_schema.as_deref().unwrap_or("public");
172 let connection = sqlx_connect::<Postgres>(
173 max_connections,
174 acquire_timeout,
175 url.as_str(),
176 Some(schema),
177 )
178 .await?;
179 println!("Discovering schema ...");
180 let schema_discovery = SchemaDiscovery::new(connection, schema);
181 let schema = schema_discovery.discover().await?;
182 let table_stmts = schema
183 .tables
184 .into_iter()
185 .filter(|schema| filter_tables(&schema.info.name))
186 .filter(|schema| filter_hidden_tables(&schema.info.name))
187 .filter(|schema| filter_skip_tables(&schema.info.name))
188 .map(|schema| schema.write())
189 .collect();
190 (database_schema, table_stmts)
191 }
192 _ => unimplemented!("{} is not supported", url.scheme()),
193 };
194 println!("... discovered.");
195
196 let writer_context = EntityWriterContext::new(
197 expanded_format,
198 WithPrelude::from_str(&with_prelude).expect("Invalid prelude option"),
199 WithSerde::from_str(&with_serde).expect("Invalid serde derive option"),
200 with_copy_enums,
201 date_time_crate.into(),
202 schema_name,
203 lib,
204 serde_skip_deserializing_primary_key,
205 serde_skip_hidden_column,
206 model_extra_derives,
207 model_extra_attributes,
208 enum_extra_derives,
209 enum_extra_attributes,
210 seaography,
211 impl_active_model_behavior,
212 );
213 let output = EntityTransformer::transform(table_stmts)?.generate(&writer_context);
214
215 let dir = Path::new(&output_dir);
216 fs::create_dir_all(dir)?;
217
218 for OutputFile { name, content } in output.files.iter() {
219 let file_path = dir.join(name);
220 println!("Writing {}", file_path.display());
221 let mut file = fs::File::create(file_path)?;
222 file.write_all(content.as_bytes())?;
223 }
224
225 for OutputFile { name, .. } in output.files.iter() {
227 let exit_status = Command::new("rustfmt").arg(dir.join(name)).status()?; if !exit_status.success() {
229 return Err(format!("Fail to format file `{name}`").into());
231 }
232 }
233
234 println!("... Done.");
235 }
236 }
237
238 Ok(())
239}
240
241async fn sqlx_connect<DB>(
242 max_connections: u32,
243 acquire_timeout: u64,
244 url: &str,
245 schema: Option<&str>,
246) -> Result<sqlx::Pool<DB>, Box<dyn Error>>
247where
248 DB: sqlx::Database,
249 for<'a> &'a mut <DB as sqlx::Database>::Connection: sqlx::Executor<'a>,
250{
251 let mut pool_options = sqlx::pool::PoolOptions::<DB>::new()
252 .max_connections(max_connections)
253 .acquire_timeout(time::Duration::from_secs(acquire_timeout));
254 if let Some(schema) = schema {
257 let sql = format!("SET search_path = '{schema}'");
258 pool_options = pool_options.after_connect(move |conn, _| {
259 let sql = sql.clone();
260 Box::pin(async move {
261 sqlx::Executor::execute(conn, sql.as_str())
262 .await
263 .map(|_| ())
264 })
265 });
266 }
267 pool_options.connect(url).await.map_err(Into::into)
268}
269
270impl From<DateTimeCrate> for CodegenDateTimeCrate {
271 fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate {
272 match date_time_crate {
273 DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono,
274 DateTimeCrate::Time => CodegenDateTimeCrate::Time,
275 }
276 }
277}
278
#[cfg(test)]
mod tests {
    use clap::Parser;

    use super::*;
    use crate::{Cli, Commands};

    /// Parse `sea-orm-cli generate entity --database-url <database_url>`, run the
    /// resulting command to completion, and unwrap its result. Any error therefore
    /// surfaces as a panic that the `should_panic` tests below match against.
    fn generate_entity_from_url(database_url: &str) {
        let cli = Cli::parse_from([
            "sea-orm-cli",
            "generate",
            "entity",
            "--database-url",
            database_url,
        ]);
        let verbose = cli.verbose;

        match cli.command {
            Commands::Generate { command } => {
                smol::block_on(run_generate_command(command, verbose)).unwrap();
            }
            _ => unreachable!(),
        }
    }

    #[test]
    #[should_panic(
        expected = "called `Result::unwrap()` on an `Err` value: RelativeUrlWithoutBase"
    )]
    fn test_generate_entity_no_protocol() {
        generate_entity_from_url("://root:root@localhost:3306/database");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: postgresql://root:root@localhost:3306"
    )]
    fn test_generate_entity_no_database_section() {
        generate_entity_from_url("postgresql://root:root@localhost:3306");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: mysql://root:root@localhost:3306/"
    )]
    fn test_generate_entity_no_database_path() {
        generate_entity_from_url("mysql://root:root@localhost:3306/");
    }

    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: PoolTimedOut")]
    fn test_generate_entity_no_password() {
        generate_entity_from_url("mysql://root:@localhost:3306/database");
    }

    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: EmptyHost")]
    fn test_generate_entity_no_host() {
        generate_entity_from_url("postgres://root:root@/database");
    }
}