1use core::time;
2use sea_orm_codegen::{
3 DateTimeCrate as CodegenDateTimeCrate, EntityFormat, EntityTransformer, EntityWriterContext,
4 MergeReport, OutputFile, WithPrelude, WithSerde, merge_entity_files,
5};
6use std::{error::Error, fs, path::Path, process::Command, str::FromStr};
7use tracing_subscriber::{EnvFilter, prelude::*};
8use url::Url;
9
10use crate::{DateTimeCrate, GenerateSubcommands};
11
12pub async fn run_generate_command(
13 command: GenerateSubcommands,
14 verbose: bool,
15) -> Result<(), Box<dyn Error>> {
16 match command {
17 GenerateSubcommands::Entity {
18 entity_format,
19 compact_format: _,
20 expanded_format,
21 frontend_format,
22 include_hidden_tables,
23 tables,
24 ignore_tables,
25 max_connections,
26 acquire_timeout,
27 output_dir,
28 database_schema,
29 database_url,
30 with_prelude,
31 with_serde,
32 serde_skip_deserializing_primary_key,
33 serde_skip_hidden_column,
34 with_copy_enums,
35 date_time_crate,
36 lib,
37 model_extra_derives,
38 model_extra_attributes,
39 enum_extra_derives,
40 enum_extra_attributes,
41 column_extra_derives,
42 seaography,
43 impl_active_model_behavior,
44 preserve_user_modifications,
45 } => {
46 if verbose {
47 let _ = tracing_subscriber::fmt()
48 .with_max_level(tracing::Level::DEBUG)
49 .with_test_writer()
50 .try_init();
51 } else {
52 let filter_layer = EnvFilter::try_new("sea_orm_codegen=info").unwrap();
53 let fmt_layer = tracing_subscriber::fmt::layer()
54 .with_target(false)
55 .with_level(false)
56 .without_time();
57
58 let _ = tracing_subscriber::registry()
59 .with(filter_layer)
60 .with(fmt_layer)
61 .try_init();
62 }
63
64 let url = Url::parse(&database_url)?;
67
68 let is_sqlite = url.scheme() == "sqlite";
73
74 let filter_tables =
76 |table: &String| -> bool { tables.is_empty() || tables.contains(table) };
77
78 let filter_hidden_tables = |table: &str| -> bool {
79 if include_hidden_tables {
80 true
81 } else {
82 !table.starts_with('_')
83 }
84 };
85
86 let filter_skip_tables = |table: &String| -> bool { !ignore_tables.contains(table) };
87
88 let _database_name = if !is_sqlite {
89 let database_name = url
95 .path_segments()
96 .unwrap_or_else(|| {
97 panic!(
98 "There is no database name as part of the url path: {}",
99 url.as_str()
100 )
101 })
102 .next()
103 .unwrap();
104
105 if database_name.is_empty() {
107 panic!(
108 "There is no database name as part of the url path: {}",
109 url.as_str()
110 );
111 }
112
113 database_name
114 } else {
115 Default::default()
116 };
117
118 let (schema_name, table_stmts) = match url.scheme() {
119 "mysql" => {
120 #[cfg(not(feature = "sqlx-mysql"))]
121 {
122 panic!("mysql feature is off")
123 }
124 #[cfg(feature = "sqlx-mysql")]
125 {
126 use sea_schema::mysql::discovery::SchemaDiscovery;
127 use sqlx::MySql;
128
129 println!("Connecting to MySQL ...");
130 let connection = sqlx_connect::<MySql>(
131 max_connections,
132 acquire_timeout,
133 url.as_str(),
134 None,
135 )
136 .await?;
137 println!("Discovering schema ...");
138 let schema_discovery = SchemaDiscovery::new(connection, _database_name);
139 let schema = schema_discovery.discover().await?;
140 let table_stmts = schema
141 .tables
142 .into_iter()
143 .filter(|schema| filter_tables(&schema.info.name))
144 .filter(|schema| filter_hidden_tables(&schema.info.name))
145 .filter(|schema| filter_skip_tables(&schema.info.name))
146 .map(|schema| schema.write())
147 .collect();
148 (None, table_stmts)
149 }
150 }
151 "sqlite" => {
152 #[cfg(not(feature = "sqlx-sqlite"))]
153 {
154 panic!("sqlite feature is off")
155 }
156 #[cfg(feature = "sqlx-sqlite")]
157 {
158 use sea_schema::sqlite::discovery::SchemaDiscovery;
159 use sqlx::Sqlite;
160
161 println!("Connecting to SQLite ...");
162 let connection = sqlx_connect::<Sqlite>(
163 max_connections,
164 acquire_timeout,
165 url.as_str(),
166 None,
167 )
168 .await?;
169 println!("Discovering schema ...");
170 let schema_discovery = SchemaDiscovery::new(connection);
171 let schema = schema_discovery
172 .discover()
173 .await?
174 .merge_indexes_into_table();
175 let table_stmts = schema
176 .tables
177 .into_iter()
178 .filter(|schema| filter_tables(&schema.name))
179 .filter(|schema| filter_hidden_tables(&schema.name))
180 .filter(|schema| filter_skip_tables(&schema.name))
181 .map(|schema| schema.write())
182 .collect();
183 (None, table_stmts)
184 }
185 }
186 "postgres" | "postgresql" => {
187 #[cfg(not(feature = "sqlx-postgres"))]
188 {
189 panic!("postgres feature is off")
190 }
191 #[cfg(feature = "sqlx-postgres")]
192 {
193 use sea_schema::postgres::discovery::SchemaDiscovery;
194 use sqlx::Postgres;
195
196 println!("Connecting to Postgres ...");
197 let schema = database_schema.as_deref().unwrap_or("public");
198 let connection = sqlx_connect::<Postgres>(
199 max_connections,
200 acquire_timeout,
201 url.as_str(),
202 Some(schema),
203 )
204 .await?;
205 println!("Discovering schema ...");
206 let schema_discovery = SchemaDiscovery::new(connection, schema);
207 let schema = schema_discovery.discover().await?;
208 let table_stmts = schema
209 .tables
210 .into_iter()
211 .filter(|schema| filter_tables(&schema.info.name))
212 .filter(|schema| filter_hidden_tables(&schema.info.name))
213 .filter(|schema| filter_skip_tables(&schema.info.name))
214 .map(|schema| schema.write())
215 .collect();
216 (database_schema, table_stmts)
217 }
218 }
219 _ => unimplemented!("{} is not supported", url.scheme()),
220 };
221 println!("... discovered.");
222
223 let writer_context = EntityWriterContext::new(
224 if expanded_format {
225 EntityFormat::Expanded
226 } else if frontend_format {
227 EntityFormat::Frontend
228 } else if let Some(entity_format) = entity_format {
229 EntityFormat::from_str(&entity_format).expect("Invalid entity-format option")
230 } else {
231 EntityFormat::default()
232 },
233 WithPrelude::from_str(&with_prelude).expect("Invalid prelude option"),
234 WithSerde::from_str(&with_serde).expect("Invalid serde derive option"),
235 with_copy_enums,
236 date_time_crate.into(),
237 schema_name,
238 lib,
239 serde_skip_deserializing_primary_key,
240 serde_skip_hidden_column,
241 model_extra_derives,
242 model_extra_attributes,
243 enum_extra_derives,
244 enum_extra_attributes,
245 column_extra_derives,
246 seaography,
247 impl_active_model_behavior,
248 );
249 let output = EntityTransformer::transform(table_stmts)?.generate(&writer_context);
250
251 let dir = Path::new(&output_dir);
252 fs::create_dir_all(dir)?;
253
254 let mut merge_fallback_files: Vec<String> = Vec::new();
255
256 for OutputFile { name, content } in output.files.iter() {
257 let file_path = dir.join(name);
258 println!("Writing {}", file_path.display());
259
260 if !matches!(
261 name.as_str(),
262 "mod.rs" | "lib.rs" | "prelude.rs" | "sea_orm_active_enums.rs"
263 ) && file_path.exists()
264 && preserve_user_modifications
265 {
266 let prev_content = fs::read_to_string(&file_path)?;
267 match merge_entity_files(&prev_content, content) {
268 Ok(merged) => {
269 fs::write(file_path, merged)?;
270 }
271 Err(MergeReport {
272 output,
273 warnings,
274 fallback_applied,
275 }) => {
276 for message in warnings {
277 eprintln!("{message}");
278 }
279 fs::write(file_path, output)?;
280 if fallback_applied {
281 merge_fallback_files.push(name.clone());
282 }
283 }
284 }
285 } else {
286 fs::write(file_path, content)?;
287 };
288 }
289
290 for OutputFile { name, .. } in output.files.iter() {
292 let exit_status = Command::new("rustfmt").arg(dir.join(name)).status()?; if !exit_status.success() {
294 return Err(format!("Fail to format file `{name}`").into());
296 }
297 }
298
299 if merge_fallback_files.is_empty() {
300 println!("... Done.");
301 } else {
302 return Err(format!(
303 "Merge fallback applied for {} file(s): \n{}",
304 merge_fallback_files.len(),
305 merge_fallback_files.join("\n")
306 )
307 .into());
308 }
309 }
310 }
311
312 Ok(())
313}
314
315async fn sqlx_connect<DB>(
316 max_connections: u32,
317 acquire_timeout: u64,
318 url: &str,
319 schema: Option<&str>,
320) -> Result<sqlx::Pool<DB>, Box<dyn Error>>
321where
322 DB: sqlx::Database,
323 for<'a> &'a mut <DB as sqlx::Database>::Connection: sqlx::Executor<'a>,
324{
325 let mut pool_options = sqlx::pool::PoolOptions::<DB>::new()
326 .max_connections(max_connections)
327 .acquire_timeout(time::Duration::from_secs(acquire_timeout));
328 if let Some(schema) = schema {
331 let sql = format!("SET search_path = '{schema}'");
332 pool_options = pool_options.after_connect(move |conn, _| {
333 let sql = sql.clone();
334 Box::pin(async move {
335 sqlx::Executor::execute(conn, sql.as_str())
336 .await
337 .map(|_| ())
338 })
339 });
340 }
341 pool_options.connect(url).await.map_err(Into::into)
342}
343
344impl From<DateTimeCrate> for CodegenDateTimeCrate {
345 fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate {
346 match date_time_crate {
347 DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono,
348 DateTimeCrate::Time => CodegenDateTimeCrate::Time,
349 }
350 }
351}
352
#[cfg(test)]
mod tests {
    use clap::Parser;

    use super::*;
    use crate::{Cli, Commands};

    /// Parses `sea-orm-cli generate entity --database-url <database_url>` and runs it
    /// to completion, unwrapping the result so that any returned error becomes a panic.
    fn generate_entity_with_url(database_url: &str) {
        let args = [
            "sea-orm-cli",
            "generate",
            "entity",
            "--database-url",
            database_url,
        ];
        let cli = Cli::parse_from(args);

        if let Commands::Generate { command } = cli.command {
            smol::block_on(run_generate_command(command, cli.verbose)).unwrap();
        } else {
            unreachable!()
        }
    }

    #[test]
    #[should_panic(
        expected = "called `Result::unwrap()` on an `Err` value: RelativeUrlWithoutBase"
    )]
    fn test_generate_entity_no_protocol() {
        generate_entity_with_url("://root:root@localhost:3306/database");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: postgresql://root:root@localhost:3306"
    )]
    fn test_generate_entity_no_database_section() {
        generate_entity_with_url("postgresql://root:root@localhost:3306");
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: mysql://root:root@localhost:3306/"
    )]
    fn test_generate_entity_no_database_path() {
        generate_entity_with_url("mysql://root:root@localhost:3306/");
    }

    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: EmptyHost")]
    fn test_generate_entity_no_host() {
        generate_entity_with_url("postgres://root:root@/database");
    }
}