use crate::{BannerVersion, BigIntegerType, DateTimeCrate, GenerateSubcommands};
use core::time;
use sea_orm_codegen::{
    BannerVersion as CodegenBannerVersion, BigIntegerType as CodegenBigIntegerType,
    DateTimeCrate as CodegenDateTimeCrate, EntityFormat, EntityTransformer, EntityWriterContext,
    MergeReport, OutputFile, WithPrelude, WithSerde, merge_entity_files,
};
use std::{error::Error, fs, path::Path, process::Command, str::FromStr};
use tracing_subscriber::{EnvFilter, prelude::*};
use url::Url;

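/// Entry point for the `generate entity` subcommand: connects to the database,
/// discovers its schema, and writes one entity file per table into `output_dir`.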
pub async fn run_generate_command(
    command: GenerateSubcommands,
    verbose: bool,
) -> Result<(), Box<dyn Error>> {
    match command {
        GenerateSubcommands::Entity {
            entity_format,
            compact_format: _,
            expanded_format,
            frontend_format,
            include_hidden_tables,
            tables,
            ignore_tables,
            max_connections,
            acquire_timeout,
            output_dir,
            database_schema,
            database_url,
            with_prelude,
            with_serde,
            serde_skip_deserializing_primary_key,
            serde_skip_hidden_column,
            with_copy_enums,
            date_time_crate,
            big_integer_type,
            lib,
            model_extra_derives,
            model_extra_attributes,
            enum_extra_derives,
            enum_extra_attributes,
            column_extra_derives,
            seaography,
            impl_active_model_behavior,
            preserve_user_modifications,
            banner_version,
        } => {
            if verbose {
                let _ = tracing_subscriber::fmt()
                    .with_max_level(tracing::Level::DEBUG)
                    .with_test_writer()
                    .try_init();
            } else {
                let filter_layer = EnvFilter::try_new("sea_orm_codegen=info").unwrap();
                let fmt_layer = tracing_subscriber::fmt::layer()
                    .with_target(false)
                    .with_level(false)
                    .without_time();

                let _ = tracing_subscriber::registry()
                    .with(filter_layer)
                    .with(fmt_layer)
                    .try_init();
            }

            let url = Url::parse(&database_url)?;

            let is_sqlite = url.scheme() == "sqlite";

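            // Table filters: keep only the explicitly requested tables (if any were
            // given), drop `_`-prefixed tables unless `include_hidden_tables` is set,
            // and drop anything listed in `ignore_tables`.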
            let filter_tables =
                |table: &String| -> bool { tables.is_empty() || tables.contains(table) };

            let filter_hidden_tables = |table: &str| -> bool {
                if include_hidden_tables {
                    true
                } else {
                    !table.starts_with('_')
                }
            };

            let filter_skip_tables = |table: &String| -> bool { !ignore_tables.contains(table) };

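            // MySQL and Postgres URLs must carry the database name as the first path
            // segment; SQLite URLs point at a file instead, so the name stays empty.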
            let _database_name = if !is_sqlite {
                let database_name = url
                    .path_segments()
                    .unwrap_or_else(|| {
                        panic!(
                            "There is no database name as part of the url path: {}",
                            url.as_str()
                        )
                    })
                    .next()
                    .unwrap();

                if database_name.is_empty() {
                    panic!(
                        "There is no database name as part of the url path: {}",
                        url.as_str()
                    );
                }

                database_name
            } else {
                Default::default()
            };

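            // Discover the schema through sea-schema. Each backend sits behind its
            // own `sqlx-*` cargo feature, so an arm panics if the matching feature
            // was disabled at build time.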
            let (schema_name, table_stmts) = match url.scheme() {
                "mysql" => {
                    #[cfg(not(feature = "sqlx-mysql"))]
                    {
                        panic!("mysql feature is off")
                    }
                    #[cfg(feature = "sqlx-mysql")]
                    {
                        use sea_schema::mysql::discovery::SchemaDiscovery;
                        use sqlx::MySql;

                        println!("Connecting to MySQL ...");
                        let connection = sqlx_connect::<MySql>(
                            max_connections,
                            acquire_timeout,
                            url.as_str(),
                            None,
                        )
                        .await?;
                        println!("Discovering schema ...");
                        let schema_discovery = SchemaDiscovery::new(connection, _database_name);
                        let schema = schema_discovery.discover().await?;
                        let table_stmts = schema
                            .tables
                            .into_iter()
                            .filter(|schema| filter_tables(&schema.info.name))
                            .filter(|schema| filter_hidden_tables(&schema.info.name))
                            .filter(|schema| filter_skip_tables(&schema.info.name))
                            .map(|schema| schema.write())
                            .collect();
                        (None, table_stmts)
                    }
                }
                "sqlite" => {
                    #[cfg(not(feature = "sqlx-sqlite"))]
                    {
                        panic!("sqlite feature is off")
                    }
                    #[cfg(feature = "sqlx-sqlite")]
                    {
                        use sea_schema::sqlite::discovery::SchemaDiscovery;
                        use sqlx::Sqlite;

                        println!("Connecting to SQLite ...");
                        let connection = sqlx_connect::<Sqlite>(
                            max_connections,
                            acquire_timeout,
                            url.as_str(),
                            None,
                        )
                        .await?;
                        println!("Discovering schema ...");
                        let schema_discovery = SchemaDiscovery::new(connection);
                        let schema = schema_discovery
                            .discover()
                            .await?
                            .merge_indexes_into_table();
                        let table_stmts = schema
                            .tables
                            .into_iter()
                            .filter(|schema| filter_tables(&schema.name))
                            .filter(|schema| filter_hidden_tables(&schema.name))
                            .filter(|schema| filter_skip_tables(&schema.name))
                            .map(|schema| schema.write())
                            .collect();
                        (None, table_stmts)
                    }
                }
                "postgres" | "postgresql" => {
                    #[cfg(not(feature = "sqlx-postgres"))]
                    {
                        panic!("postgres feature is off")
                    }
                    #[cfg(feature = "sqlx-postgres")]
                    {
                        use sea_schema::postgres::discovery::SchemaDiscovery;
                        use sqlx::Postgres;

                        println!("Connecting to Postgres ...");
                        let schema = database_schema.as_deref().unwrap_or("public");
                        let connection = sqlx_connect::<Postgres>(
                            max_connections,
                            acquire_timeout,
                            url.as_str(),
                            Some(schema),
                        )
                        .await?;
                        println!("Discovering schema ...");
                        let schema_discovery = SchemaDiscovery::new(connection, schema);
                        let schema = schema_discovery.discover().await?;
                        let table_stmts = schema
                            .tables
                            .into_iter()
                            .filter(|schema| filter_tables(&schema.info.name))
                            .filter(|schema| filter_hidden_tables(&schema.info.name))
                            .filter(|schema| filter_skip_tables(&schema.info.name))
                            .map(|schema| schema.write())
                            .collect();
                        (database_schema, table_stmts)
                    }
                }
                _ => unimplemented!("{} is not supported", url.scheme()),
            };
            println!("... discovered.");

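            // Bundle all codegen options into a single EntityWriterContext; the
            // mutually exclusive format flags collapse into one EntityFormat value.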
            let writer_context = EntityWriterContext::new(
                if expanded_format {
                    EntityFormat::Expanded
                } else if frontend_format {
                    EntityFormat::Frontend
                } else if let Some(entity_format) = entity_format {
                    EntityFormat::from_str(&entity_format).expect("Invalid entity-format option")
                } else {
                    EntityFormat::default()
                },
                WithPrelude::from_str(&with_prelude).expect("Invalid prelude option"),
                WithSerde::from_str(&with_serde).expect("Invalid serde derive option"),
                with_copy_enums,
                date_time_crate.into(),
                big_integer_type.into(),
                schema_name,
                lib,
                serde_skip_deserializing_primary_key,
                serde_skip_hidden_column,
                model_extra_derives,
                model_extra_attributes,
                enum_extra_derives,
                enum_extra_attributes,
                column_extra_derives,
                seaography,
                impl_active_model_behavior,
                banner_version.into(),
            );
            let output = EntityTransformer::transform(table_stmts)?.generate(&writer_context);

            let dir = Path::new(&output_dir);
            fs::create_dir_all(dir)?;

            let mut merge_fallback_files: Vec<String> = Vec::new();

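            // Write the generated files. Scaffold files (mod.rs, lib.rs, prelude.rs,
            // sea_orm_active_enums.rs) are always overwritten; existing entity files
            // are merged with the new output when `preserve_user_modifications` is set.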
            for OutputFile { name, content } in output.files.iter() {
                let file_path = dir.join(name);
                println!("Writing {}", file_path.display());

                if !matches!(
                    name.as_str(),
                    "mod.rs" | "lib.rs" | "prelude.rs" | "sea_orm_active_enums.rs"
                ) && file_path.exists()
                    && preserve_user_modifications
                {
                    let prev_content = fs::read_to_string(&file_path)?;
                    match merge_entity_files(&prev_content, content) {
                        Ok(merged) => {
                            fs::write(file_path, merged)?;
                        }
                        Err(MergeReport {
                            output,
                            warnings,
                            fallback_applied,
                        }) => {
                            for message in warnings {
                                eprintln!("{message}");
                            }
                            fs::write(file_path, output)?;
                            if fallback_applied {
                                merge_fallback_files.push(name.clone());
                            }
                        }
                    }
                } else {
                    fs::write(file_path, content)?;
                };
            }

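            // Run rustfmt over every generated file so the output matches standard
            // formatting; a failed run aborts the command.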
            for OutputFile { name, .. } in output.files.iter() {
                let exit_status = Command::new("rustfmt").arg(dir.join(name)).status()?;
                if !exit_status.success() {
                    return Err(format!("Failed to format file `{name}`").into());
                }
            }

            if merge_fallback_files.is_empty() {
                println!("... Done.");
            } else {
                return Err(format!(
                    "Merge fallback applied for {} file(s): \n{}",
                    merge_fallback_files.len(),
                    merge_fallback_files.join("\n")
                )
                .into());
            }
        }
    }

    Ok(())
}

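/// Builds a sqlx connection pool with the requested size and acquire timeout.
/// For Postgres, `schema` is applied to every new connection by setting
/// `search_path` in an `after_connect` hook.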
async fn sqlx_connect<DB>(
    max_connections: u32,
    acquire_timeout: u64,
    url: &str,
    schema: Option<&str>,
) -> Result<sqlx::Pool<DB>, Box<dyn Error>>
where
    DB: sqlx::Database,
    for<'a> &'a mut <DB as sqlx::Database>::Connection: sqlx::Executor<'a>,
{
    let mut pool_options = sqlx::pool::PoolOptions::<DB>::new()
        .max_connections(max_connections)
        .acquire_timeout(time::Duration::from_secs(acquire_timeout));
    if let Some(schema) = schema {
        let sql = format!("SET search_path = '{schema}'");
        pool_options = pool_options.after_connect(move |conn, _| {
            let sql = sql.clone();
            Box::pin(async move {
                sqlx::Executor::execute(conn, sql.as_str())
                    .await
                    .map(|_| ())
            })
        });
    }
    pool_options.connect(url).await.map_err(Into::into)
}

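// The CLI defines its own copies of these option enums; the impls below map
// them to the sea_orm_codegen equivalents consumed by EntityWriterContext.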
impl From<DateTimeCrate> for CodegenDateTimeCrate {
    fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate {
        match date_time_crate {
            DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono,
            DateTimeCrate::Time => CodegenDateTimeCrate::Time,
        }
    }
}

impl From<BigIntegerType> for CodegenBigIntegerType {
    fn from(big_integer_type: BigIntegerType) -> CodegenBigIntegerType {
        match big_integer_type {
            BigIntegerType::I64 => CodegenBigIntegerType::I64,
            BigIntegerType::I32 => CodegenBigIntegerType::I32,
        }
    }
}

impl From<BannerVersion> for CodegenBannerVersion {
    fn from(banner_version: BannerVersion) -> CodegenBannerVersion {
        match banner_version {
            BannerVersion::Off => CodegenBannerVersion::Off,
            BannerVersion::Major => CodegenBannerVersion::Major,
            BannerVersion::Minor => CodegenBannerVersion::Minor,
            BannerVersion::Patch => CodegenBannerVersion::Patch,
        }
    }
}

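// The tests below only exercise URL validation failures (missing scheme, host,
// or database name); none of them reaches a live database.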
#[cfg(test)]
mod tests {
    use clap::Parser;

    use super::*;
    use crate::{Cli, Commands};

    #[test]
    #[should_panic(
        expected = "called `Result::unwrap()` on an `Err` value: RelativeUrlWithoutBase"
    )]
    fn test_generate_entity_no_protocol() {
        let cli = Cli::parse_from([
            "sea-orm-cli",
            "generate",
            "entity",
            "--database-url",
            "://root:root@localhost:3306/database",
        ]);

        match cli.command {
            Commands::Generate { command } => {
                smol::block_on(run_generate_command(command, cli.verbose)).unwrap();
            }
            _ => unreachable!(),
        }
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: postgresql://root:root@localhost:3306"
    )]
    fn test_generate_entity_no_database_section() {
        let cli = Cli::parse_from([
            "sea-orm-cli",
            "generate",
            "entity",
            "--database-url",
            "postgresql://root:root@localhost:3306",
        ]);

        match cli.command {
            Commands::Generate { command } => {
                smol::block_on(run_generate_command(command, cli.verbose)).unwrap();
            }
            _ => unreachable!(),
        }
    }

    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: mysql://root:root@localhost:3306/"
    )]
    fn test_generate_entity_no_database_path() {
        let cli = Cli::parse_from([
            "sea-orm-cli",
            "generate",
            "entity",
            "--database-url",
            "mysql://root:root@localhost:3306/",
        ]);

        match cli.command {
            Commands::Generate { command } => {
                smol::block_on(run_generate_command(command, cli.verbose)).unwrap();
            }
            _ => unreachable!(),
        }
    }

    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: EmptyHost")]
    fn test_generate_entity_no_host() {
        let cli = Cli::parse_from([
            "sea-orm-cli",
            "generate",
            "entity",
            "--database-url",
            "postgres://root:root@/database",
        ]);

        match cli.command {
            Commands::Generate { command } => {
                smol::block_on(run_generate_command(command, cli.verbose)).unwrap();
            }
            _ => unreachable!(),
        }
    }
}