1use sea_orm_codegen::{
2 DateTimeCrate as CodegenDateTimeCrate, EntityTransformer, EntityWriterContext, OutputFile,
3 WithSerde,
4};
5use std::{error::Error, fs, io::Write, path::Path, process::Command, str::FromStr};
6use tracing_subscriber::{prelude::*, EnvFilter};
7use url::Url;
8
9use crate::{DateTimeCrate, GenerateSubcommands};
10
11pub async fn run_generate_command(
12 command: GenerateSubcommands,
13 verbose: bool,
14) -> Result<(), Box<dyn Error>> {
15 match command {
16 GenerateSubcommands::Entity {
17 compact_format: _,
18 expanded_format,
19 include_hidden_tables,
20 tables,
21 ignore_tables,
22 max_connections,
23 output_dir,
24 database_schema,
25 database_url,
26 with_serde,
27 serde_skip_deserializing_primary_key,
28 serde_skip_hidden_column,
29 with_copy_enums,
30 date_time_crate,
31 lib,
32 model_extra_derives,
33 model_extra_attributes,
34 enum_extra_derives,
35 enum_extra_attributes,
36 seaography,
37 } => {
38 if verbose {
39 let _ = tracing_subscriber::fmt()
40 .with_max_level(tracing::Level::DEBUG)
41 .with_test_writer()
42 .try_init();
43 } else {
44 let filter_layer = EnvFilter::try_new("sea_orm_codegen=info").unwrap();
45 let fmt_layer = tracing_subscriber::fmt::layer()
46 .with_target(false)
47 .with_level(false)
48 .without_time();
49
50 let _ = tracing_subscriber::registry()
51 .with(filter_layer)
52 .with(fmt_layer)
53 .try_init();
54 }
55
56 let url = Url::parse(&database_url)?;
59
60 let is_sqlite = url.scheme() == "sqlite";
65
66 let filter_tables =
68 |table: &String| -> bool { tables.is_empty() || tables.contains(table) };
69
70 let filter_hidden_tables = |table: &str| -> bool {
71 if include_hidden_tables {
72 true
73 } else {
74 !table.starts_with('_')
75 }
76 };
77
78 let filter_skip_tables = |table: &String| -> bool { !ignore_tables.contains(table) };
79
80 let database_name = if !is_sqlite {
81 let database_name = url
87 .path_segments()
88 .unwrap_or_else(|| {
89 panic!(
90 "There is no database name as part of the url path: {}",
91 url.as_str()
92 )
93 })
94 .next()
95 .unwrap();
96
97 if database_name.is_empty() {
99 panic!(
100 "There is no database name as part of the url path: {}",
101 url.as_str()
102 );
103 }
104
105 database_name
106 } else {
107 Default::default()
108 };
109
110 let (schema_name, table_stmts) = match url.scheme() {
111 "mysql" => {
112 use sea_schema::mysql::discovery::SchemaDiscovery;
113 use sqlx::MySql;
114
115 println!("Connecting to MySQL ...");
116 let connection =
117 sqlx_connect::<MySql>(max_connections, url.as_str(), None).await?;
118 println!("Discovering schema ...");
119 let schema_discovery = SchemaDiscovery::new(connection, database_name);
120 let schema = schema_discovery.discover().await?;
121 let table_stmts = schema
122 .tables
123 .into_iter()
124 .filter(|schema| filter_tables(&schema.info.name))
125 .filter(|schema| filter_hidden_tables(&schema.info.name))
126 .filter(|schema| filter_skip_tables(&schema.info.name))
127 .map(|schema| schema.write())
128 .collect();
129 (None, table_stmts)
130 }
131 "sqlite" => {
132 use sea_schema::sqlite::discovery::SchemaDiscovery;
133 use sqlx::Sqlite;
134
135 println!("Connecting to SQLite ...");
136 let connection =
137 sqlx_connect::<Sqlite>(max_connections, url.as_str(), None).await?;
138 println!("Discovering schema ...");
139 let schema_discovery = SchemaDiscovery::new(connection);
140 let schema = schema_discovery
141 .discover()
142 .await?
143 .merge_indexes_into_table();
144 let table_stmts = schema
145 .tables
146 .into_iter()
147 .filter(|schema| filter_tables(&schema.name))
148 .filter(|schema| filter_hidden_tables(&schema.name))
149 .filter(|schema| filter_skip_tables(&schema.name))
150 .map(|schema| schema.write())
151 .collect();
152 (None, table_stmts)
153 }
154 "postgres" | "postgresql" => {
155 use sea_schema::postgres::discovery::SchemaDiscovery;
156 use sqlx::Postgres;
157
158 println!("Connecting to Postgres ...");
159 let schema = database_schema.as_deref().unwrap_or("public");
160 let connection =
161 sqlx_connect::<Postgres>(max_connections, url.as_str(), Some(schema))
162 .await?;
163 println!("Discovering schema ...");
164 let schema_discovery = SchemaDiscovery::new(connection, schema);
165 let schema = schema_discovery.discover().await?;
166 let table_stmts = schema
167 .tables
168 .into_iter()
169 .filter(|schema| filter_tables(&schema.info.name))
170 .filter(|schema| filter_hidden_tables(&schema.info.name))
171 .filter(|schema| filter_skip_tables(&schema.info.name))
172 .map(|schema| schema.write())
173 .collect();
174 (database_schema, table_stmts)
175 }
176 _ => unimplemented!("{} is not supported", url.scheme()),
177 };
178 println!("... discovered.");
179
180 let writer_context = EntityWriterContext::new(
181 expanded_format,
182 WithSerde::from_str(&with_serde).expect("Invalid serde derive option"),
183 with_copy_enums,
184 date_time_crate.into(),
185 schema_name,
186 lib,
187 serde_skip_deserializing_primary_key,
188 serde_skip_hidden_column,
189 model_extra_derives,
190 model_extra_attributes,
191 enum_extra_derives,
192 enum_extra_attributes,
193 seaography,
194 );
195 let output = EntityTransformer::transform(table_stmts)?.generate(&writer_context);
196
197 let dir = Path::new(&output_dir);
198 fs::create_dir_all(dir)?;
199
200 for OutputFile { name, content } in output.files.iter() {
201 let file_path = dir.join(name);
202 println!("Writing {}", file_path.display());
203 let mut file = fs::File::create(file_path)?;
204 file.write_all(content.as_bytes())?;
205 }
206
207 for OutputFile { name, .. } in output.files.iter() {
209 let exit_status = Command::new("rustfmt").arg(dir.join(name)).status()?; if !exit_status.success() {
211 return Err(format!("Fail to format file `{name}`").into());
213 }
214 }
215
216 println!("... Done.");
217 }
218 }
219
220 Ok(())
221}
222
223async fn sqlx_connect<DB>(
224 max_connections: u32,
225 url: &str,
226 schema: Option<&str>,
227) -> Result<sqlx::Pool<DB>, Box<dyn Error>>
228where
229 DB: sqlx::Database,
230 for<'a> &'a mut <DB as sqlx::Database>::Connection: sqlx::Executor<'a>,
231{
232 let mut pool_options = sqlx::pool::PoolOptions::<DB>::new().max_connections(max_connections);
233 if let Some(schema) = schema {
236 let sql = format!("SET search_path = '{schema}'");
237 pool_options = pool_options.after_connect(move |conn, _| {
238 let sql = sql.clone();
239 Box::pin(async move {
240 sqlx::Executor::execute(conn, sql.as_str())
241 .await
242 .map(|_| ())
243 })
244 });
245 }
246 pool_options.connect(url).await.map_err(Into::into)
247}
248
249impl From<DateTimeCrate> for CodegenDateTimeCrate {
250 fn from(date_time_crate: DateTimeCrate) -> CodegenDateTimeCrate {
251 match date_time_crate {
252 DateTimeCrate::Chrono => CodegenDateTimeCrate::Chrono,
253 DateTimeCrate::Time => CodegenDateTimeCrate::Time,
254 }
255 }
256}
257
#[cfg(test)]
mod tests {
    use clap::Parser;

    use super::*;
    use crate::{Cli, Commands};

    /// Parses `sea-orm-cli generate entity --database-url <database_url>` and runs the
    /// resulting command to completion, unwrapping its result so that any error
    /// surfaces as a panic the tests can match on.
    fn run_entity_generation(database_url: &str) {
        let cli = Cli::parse_from([
            "sea-orm-cli",
            "generate",
            "entity",
            "--database-url",
            database_url,
        ]);
        let verbose = cli.verbose;

        if let Commands::Generate { command } = cli.command {
            smol::block_on(run_generate_command(command, verbose)).unwrap();
        } else {
            unreachable!();
        }
    }

    /// A URL without a scheme is rejected while parsing.
    #[test]
    #[should_panic(
        expected = "called `Result::unwrap()` on an `Err` value: RelativeUrlWithoutBase"
    )]
    fn test_generate_entity_no_protocol() {
        run_entity_generation("://root:root@localhost:3306/database");
    }

    /// A URL with no path at all has no database name.
    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: postgresql://root:root@localhost:3306"
    )]
    fn test_generate_entity_no_database_section() {
        run_entity_generation("postgresql://root:root@localhost:3306");
    }

    /// A URL whose path is just `/` has an empty database name.
    #[test]
    #[should_panic(
        expected = "There is no database name as part of the url path: mysql://root:root@localhost:3306/"
    )]
    fn test_generate_entity_no_database_path() {
        run_entity_generation("mysql://root:root@localhost:3306/");
    }

    /// An empty password gets past URL validation; the expected failure is the pool timing out.
    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: PoolTimedOut")]
    fn test_generate_entity_no_password() {
        run_entity_generation("mysql://root:@localhost:3306/database");
    }

    /// A URL with credentials but no host is rejected while parsing.
    #[test]
    #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: EmptyHost")]
    fn test_generate_entity_no_host() {
        run_entity_generation("postgres://root:root@/database");
    }
}