1pub mod backend_trait;
2pub mod backends;
3pub mod resolve;
4pub mod validation;
5
6pub use backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
7pub use backends::get_backend;
8
9use scythe_backend::manifest::BackendManifest;
10use scythe_backend::naming::{row_struct_name, to_pascal_case};
11
12use scythe_core::analyzer::{AnalyzedQuery, EnumInfo};
13use scythe_core::catalog::Catalog;
14use scythe_core::errors::ScytheError;
15use scythe_core::parser::QueryCommand;
16
/// Rendered source fragments produced for a single analyzed query.
///
/// Each field is `None` when the corresponding fragment was not generated;
/// `Default` yields a value with every field `None`.
#[derive(Debug, Default)]
pub struct GeneratedCode {
    /// The generated query function.
    pub query_fn: Option<String>,
    /// Row struct for `QueryCommand::One` / `QueryCommand::Many` queries that
    /// have no source table.
    pub row_struct: Option<String>,
    /// Model struct for queries that have a source table, followed by any
    /// composite type definitions the query references.
    pub model_struct: Option<String>,
    /// Definitions of the enum types referenced by the query's columns or params.
    pub enum_def: Option<String>,
}
28
29pub(crate) fn singularize(name: &str) -> String {
35 if let Some(stem) = name.strip_suffix("ies") {
36 format!("{stem}y")
37 } else if name.ends_with("sses")
38 || name.ends_with("shes")
39 || name.ends_with("ches")
40 || name.ends_with("xes")
41 || name.ends_with("zes")
42 || name.ends_with("ses")
43 {
44 name[..name.len() - 2].to_string()
45 } else if name.ends_with('s') && !name.ends_with("ss") {
46 name[..name.len() - 1].to_string()
47 } else {
48 name.to_string()
49 }
50}
51
52pub fn get_manifest_for_backend(backend_name: &str) -> Result<BackendManifest, ScytheError> {
58 let backend = get_backend(backend_name, "postgresql")?;
59 Ok(backend.manifest().clone())
60}
61
62fn determine_struct_name(analyzed: &AnalyzedQuery, manifest: &BackendManifest) -> String {
64 if let Some(ref table_name) = analyzed.source_table {
65 let singular = singularize(table_name);
66 to_pascal_case(&singular).into_owned()
67 } else {
68 row_struct_name(&analyzed.name, &manifest.naming)
69 }
70}
71
72pub fn generate_with_backend(
78 analyzed: &AnalyzedQuery,
79 backend: &dyn CodegenBackend,
80) -> Result<GeneratedCode, ScytheError> {
81 let manifest = backend.manifest();
82 let columns = resolve::resolve_columns(&analyzed.columns, manifest)?;
83 let params = resolve::resolve_params(&analyzed.params, manifest)?;
84
85 let mut result = GeneratedCode::default();
86
87 let enum_def = generate_enum_defs_via_backend(analyzed, backend)?;
90 if !enum_def.is_empty() {
91 result.enum_def = Some(enum_def);
92 }
93
94 let needs_row_struct = matches!(analyzed.command, QueryCommand::One | QueryCommand::Many);
96 if needs_row_struct && !analyzed.columns.is_empty() {
97 if let Some(ref table_name) = analyzed.source_table {
98 result.model_struct = Some(backend.generate_model_struct(table_name, &columns)?);
99 } else {
100 result.row_struct = Some(backend.generate_row_struct(&analyzed.name, &columns)?);
101 }
102 }
103
104 if !analyzed.composites.is_empty() {
106 let mut comp_defs = String::new();
107 for (i, comp) in analyzed.composites.iter().enumerate() {
108 if i > 0 {
109 comp_defs.push_str("\n\n");
110 }
111 comp_defs.push_str(&backend.generate_composite_def(comp)?);
112 }
113 if !comp_defs.is_empty() {
114 if let Some(ref mut existing) = result.model_struct {
115 existing.push_str("\n\n");
116 existing.push_str(&comp_defs);
117 } else {
118 result.model_struct = Some(comp_defs);
119 }
120 }
121 }
122
123 let struct_name = determine_struct_name(analyzed, manifest);
125 result.query_fn = Some(backend.generate_query_fn(analyzed, &struct_name, &columns, ¶ms)?);
126
127 Ok(result)
128}
129
130fn generate_enum_defs_via_backend(
132 analyzed: &AnalyzedQuery,
133 backend: &dyn CodegenBackend,
134) -> Result<String, ScytheError> {
135 use ahash::AHashSet;
136 use std::fmt::Write;
137
138 let mut out = String::new();
139 let mut seen_enums: AHashSet<String> = AHashSet::new();
140
141 let enum_sources: Vec<&str> = analyzed
142 .columns
143 .iter()
144 .filter_map(|col| col.neutral_type.strip_prefix("enum::"))
145 .chain(
146 analyzed
147 .params
148 .iter()
149 .filter_map(|p| p.neutral_type.strip_prefix("enum::")),
150 )
151 .collect();
152
153 for sql_name in enum_sources {
154 if !seen_enums.insert(sql_name.to_string()) {
155 continue;
156 }
157
158 if !out.is_empty() {
159 let _ = writeln!(out);
160 }
161
162 if let Some(enum_info) = analyzed.enums.iter().find(|e| e.sql_name == sql_name) {
163 out.push_str(&backend.generate_enum_def(enum_info)?);
164 } else {
165 let stub_info = EnumInfo {
168 sql_name: sql_name.to_string(),
169 values: vec![],
170 };
171 out.push_str(&backend.generate_enum_def(&stub_info)?);
172 }
173 }
174
175 Ok(out)
176}
177
178pub fn generate(analyzed: &AnalyzedQuery) -> Result<GeneratedCode, ScytheError> {
180 let backend = get_backend("rust-sqlx", "postgresql")?;
181 generate_with_backend(analyzed, &*backend)
182}
183
184pub fn generate_from_catalog(_catalog: &Catalog) -> Result<GeneratedCode, ScytheError> {
186 Ok(GeneratedCode::default())
187}
188
189pub fn generate_single_enum_def_with_backend(
191 enum_info: &EnumInfo,
192 backend: &dyn CodegenBackend,
193) -> Result<String, ScytheError> {
194 backend.generate_enum_def(enum_info)
195}
196
197pub fn generate_single_enum_def(enum_info: &EnumInfo, manifest: &BackendManifest) -> String {
200 use scythe_backend::naming::{enum_type_name, enum_variant_name};
202 use std::fmt::Write;
203
204 let mut out = String::with_capacity(256);
205 let type_name = enum_type_name(&enum_info.sql_name, &manifest.naming);
206
207 let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq, sqlx::Type)]");
208 let _ = writeln!(
209 out,
210 "#[sqlx(type_name = \"{}\", rename_all = \"snake_case\")]",
211 enum_info.sql_name
212 );
213 let _ = writeln!(out, "pub enum {type_name} {{");
214
215 for value in &enum_info.values {
216 let variant = enum_variant_name(value, &manifest.naming);
217 let _ = writeln!(out, " {variant},");
218 }
219
220 let _ = write!(out, "}}");
221 out
222}
223
224pub fn load_or_default_manifest() -> Result<BackendManifest, ScytheError> {
226 let b = backends::sqlx::SqlxBackend::new("postgresql")?;
227 Ok(b.manifest().clone())
228}
229
// Unit tests for the codegen entry points. They assert on substrings of the
// rendered code produced by the `rust-sqlx` (default) and `tokio-postgres`
// backends, and on the `singularize` heuristic.
#[cfg(test)]
mod tests {
    use super::*;
    use scythe_core::analyzer::{AnalyzedColumn, AnalyzedParam, AnalyzedQuery};
    use scythe_core::parser::QueryCommand;

    /// Builds an `AnalyzedQuery` with no source table, composites, enums or
    /// deprecation note, so struct naming falls back to `<Name>Row`.
    fn make_query(
        name: &str,
        command: QueryCommand,
        sql: &str,
        columns: Vec<AnalyzedColumn>,
        params: Vec<AnalyzedParam>,
    ) -> AnalyzedQuery {
        AnalyzedQuery {
            name: name.to_string(),
            command,
            sql: sql.to_string(),
            columns,
            params,
            deprecated: None,
            source_table: None,
            composites: Vec::new(),
            enums: Vec::new(),
        }
    }

    // `Many` query: row struct with a nullable column mapped to `Option`, and a
    // `fetch_all` query fn returning `Vec<Row>`.
    #[test]
    fn test_generate_select_many() {
        let query = make_query(
            "ListUsers",
            QueryCommand::Many,
            "SELECT id, name, email FROM users",
            vec![
                AnalyzedColumn {
                    name: "id".to_string(),
                    neutral_type: "int32".to_string(),
                    nullable: false,
                },
                AnalyzedColumn {
                    name: "name".to_string(),
                    neutral_type: "string".to_string(),
                    nullable: false,
                },
                AnalyzedColumn {
                    name: "email".to_string(),
                    neutral_type: "string".to_string(),
                    nullable: true,
                },
            ],
            vec![],
        );

        let result = generate(&query).unwrap();

        let row_struct = result.row_struct.unwrap();
        assert!(row_struct.contains("pub struct ListUsersRow"));
        assert!(row_struct.contains("pub id: i32"));
        assert!(row_struct.contains("pub name: String"));
        assert!(row_struct.contains("pub email: Option<String>"));

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn list_users("));
        assert!(query_fn.contains("-> Result<Vec<ListUsersRow>, sqlx::Error>"));
        assert!(query_fn.contains(".fetch_all(pool)"));
    }

    // `One` query with a parameter: the param appears in the fn signature and
    // the fn uses `fetch_one`.
    #[test]
    fn test_generate_select_one_with_param() {
        let query = make_query(
            "GetUser",
            QueryCommand::One,
            "SELECT id, name FROM users WHERE id = $1",
            vec![
                AnalyzedColumn {
                    name: "id".to_string(),
                    neutral_type: "int32".to_string(),
                    nullable: false,
                },
                AnalyzedColumn {
                    name: "name".to_string(),
                    neutral_type: "string".to_string(),
                    nullable: false,
                },
            ],
            vec![AnalyzedParam {
                name: "id".to_string(),
                neutral_type: "int32".to_string(),
                nullable: false,
                position: 1,
            }],
        );

        let result = generate(&query).unwrap();

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn get_user("));
        assert!(query_fn.contains("id: i32"));
        assert!(query_fn.contains("-> Result<GetUserRow, sqlx::Error>"));
        assert!(query_fn.contains(".fetch_one(pool)"));
    }

    // `Exec` query: no row struct is generated and the fn returns `()`.
    #[test]
    fn test_generate_exec() {
        let query = make_query(
            "DeleteUser",
            QueryCommand::Exec,
            "DELETE FROM users WHERE id = $1",
            vec![],
            vec![AnalyzedParam {
                name: "id".to_string(),
                neutral_type: "int32".to_string(),
                nullable: false,
                position: 1,
            }],
        );

        let result = generate(&query).unwrap();

        assert!(result.row_struct.is_none());

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn delete_user("));
        assert!(query_fn.contains("-> Result<(), sqlx::Error>"));
        assert!(query_fn.contains(".execute(pool)"));
    }

    // An `enum::` column with no matching `EnumInfo` still yields an enum
    // definition (from the stub path) and a typed row field.
    #[test]
    fn test_generate_with_enum_column() {
        let query = make_query(
            "GetUserStatus",
            QueryCommand::One,
            "SELECT id, status FROM users WHERE id = $1",
            vec![
                AnalyzedColumn {
                    name: "id".to_string(),
                    neutral_type: "int32".to_string(),
                    nullable: false,
                },
                AnalyzedColumn {
                    name: "status".to_string(),
                    neutral_type: "enum::user_status".to_string(),
                    nullable: false,
                },
            ],
            vec![AnalyzedParam {
                name: "id".to_string(),
                neutral_type: "int32".to_string(),
                nullable: false,
                position: 1,
            }],
        );

        let result = generate(&query).unwrap();

        assert!(result.enum_def.is_some());
        let enum_def = result.enum_def.unwrap();
        assert!(enum_def.contains("pub enum UserStatus"));
        assert!(enum_def.contains("type_name = \"user_status\""));

        let row_struct = result.row_struct.unwrap();
        assert!(row_struct.contains("pub status: UserStatus"));
    }

    // `generate_from_catalog` is a placeholder and returns an empty result.
    #[test]
    fn test_generate_from_catalog_returns_default() {
        let catalog = Catalog::from_ddl(&["CREATE TABLE t (id INTEGER);"]).unwrap();
        let result = generate_from_catalog(&catalog).unwrap();
        assert!(result.query_fn.is_none());
        assert!(result.row_struct.is_none());
    }

    #[test]
    fn test_singularize_basic() {
        assert_eq!(singularize("users"), "user");
        assert_eq!(singularize("orders"), "order");
        assert_eq!(singularize("posts"), "post");
    }

    #[test]
    fn test_singularize_ies() {
        assert_eq!(singularize("categories"), "category");
        assert_eq!(singularize("entries"), "entry");
    }

    #[test]
    fn test_singularize_sses() {
        assert_eq!(singularize("addresses"), "address");
        assert_eq!(singularize("classes"), "class");
    }

    #[test]
    fn test_singularize_no_change() {
        // NOTE(review): despite this test's name, "status" IS changed. This line
        // pins the heuristic's known limitation of stripping a single trailing
        // `s` from singular words.
        assert_eq!(singularize("status"), "statu");
        assert_eq!(singularize("boss"), "boss");
        assert_eq!(singularize("address"), "address");
    }

    #[test]
    fn test_singularize_shes_ches_xes() {
        assert_eq!(singularize("batches"), "batch");
        assert_eq!(singularize("boxes"), "box");
        assert_eq!(singularize("wishes"), "wish");
    }

    // The tokio-postgres backend must not leak any sqlx-specific code.
    #[test]
    fn test_tokio_postgres_backend_basic() {
        let backend = get_backend("tokio-postgres", "postgresql").unwrap();

        let query = make_query(
            "ListUsers",
            QueryCommand::Many,
            "SELECT id, name FROM users",
            vec![
                AnalyzedColumn {
                    name: "id".to_string(),
                    neutral_type: "int32".to_string(),
                    nullable: false,
                },
                AnalyzedColumn {
                    name: "name".to_string(),
                    neutral_type: "string".to_string(),
                    nullable: false,
                },
            ],
            vec![],
        );

        let result = generate_with_backend(&query, &*backend).unwrap();

        let row_struct = result.row_struct.unwrap();
        assert!(row_struct.contains("pub struct ListUsersRow"));
        assert!(row_struct.contains("pub id: i32"));
        assert!(row_struct.contains("pub name: String"));
        assert!(row_struct.contains("from_row"));
        assert!(row_struct.contains("tokio_postgres::Row"));
        assert!(!row_struct.contains("sqlx"));

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn list_users("));
        assert!(query_fn.contains("tokio_postgres::Client"));
        assert!(query_fn.contains("tokio_postgres::Error"));
        assert!(!query_fn.contains("sqlx"));
    }

    // tokio-postgres enums get hand-written Display/FromStr impls instead of
    // an sqlx derive.
    #[test]
    fn test_tokio_postgres_enum() {
        let backend = get_backend("tokio-postgres", "postgresql").unwrap();

        let enum_info = scythe_core::analyzer::EnumInfo {
            sql_name: "user_status".to_string(),
            values: vec!["active".to_string(), "inactive".to_string()],
        };

        let def = backend.generate_enum_def(&enum_info).unwrap();
        assert!(def.contains("pub enum UserStatus"));
        assert!(def.contains("Active"));
        assert!(def.contains("Inactive"));
        assert!(def.contains("impl std::fmt::Display"));
        assert!(def.contains("impl std::str::FromStr"));
        assert!(!def.contains("sqlx"));
    }
}