1pub mod backend_trait;
2pub mod backends;
3pub mod resolve;
4pub mod validation;
5
6pub use backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
7pub use backends::get_backend;
8
9use scythe_backend::manifest::BackendManifest;
10use scythe_backend::naming::{row_struct_name, to_pascal_case};
11
12use scythe_core::analyzer::{AnalyzedQuery, EnumInfo};
13use scythe_core::catalog::Catalog;
14use scythe_core::errors::ScytheError;
15use scythe_core::parser::QueryCommand;
16
/// Source fragments produced for a single analyzed query.
///
/// Every field is `None` when the corresponding fragment was not generated;
/// for example `row_struct` stays `None` for queries that are neither `One`
/// nor `Many`.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct GeneratedCode {
    /// The function (or method) that executes the query.
    pub query_fn: Option<String>,
    /// A per-query row type, emitted for row-returning queries that have no
    /// single source table.
    pub row_struct: Option<String>,
    /// A table-level model type and/or composite-type definitions.
    pub model_struct: Option<String>,
    /// Enum definitions referenced by the query's columns or parameters.
    pub enum_def: Option<String>,
}
28
/// Converts a (usually plural) table name into the singular form used for
/// model struct names.
///
/// Examples: `users` -> `user`, `categories` -> `category`,
/// `addresses` -> `address`, `batches` -> `batch`, `responses` -> `response`.
///
/// This is a small suffix-based heuristic, not a full English inflector.
/// Known limitations:
/// - irregular plurals are not recognised;
/// - any word ending in a single `s` loses it (`status` -> `statu`).
///
/// Words ending in `ss` are returned unchanged, and a non-empty input never
/// yields an empty result.
pub(crate) fn singularize(name: &str) -> String {
    // Character immediately before a trailing three-byte ASCII suffix such as
    // "ses"/"zes". Only called after `ends_with` has confirmed that suffix, so
    // the slice boundary is always a valid char boundary.
    let char_before_suffix = || name[..name.len() - 3].chars().next_back();

    let singular = if let Some(stem) = name.strip_suffix("ies") {
        // categories -> category
        format!("{stem}y")
    } else if name.ends_with("sses")
        || name.ends_with("shes")
        || name.ends_with("ches")
        || name.ends_with("xes")
    {
        // addresses -> address, wishes -> wish, batches -> batch, boxes -> box
        name[..name.len() - 2].to_string()
    } else if name.ends_with("ses") {
        // After a consonant, `i`, or `o` the singular ends in "se":
        // responses -> response, courses -> course, premises -> premise,
        // purposes -> purpose. After `a`/`e`/`u`/`y` (or a non-letter) keep
        // stripping "es": statuses -> status, buses -> bus.
        let keeps_e = matches!(
            char_before_suffix(),
            Some(c) if c.is_ascii_alphabetic()
                && !matches!(c.to_ascii_lowercase(), 'a' | 'e' | 'u' | 'y')
        );
        let cut = if keeps_e { 1 } else { 2 };
        name[..name.len() - cut].to_string()
    } else if name.ends_with("zes") {
        // After a vowel the singular ends in "ze": sizes -> size, prizes -> prize.
        // Otherwise strip "es": buzzes -> buzz, waltzes -> waltz.
        let keeps_e = matches!(
            char_before_suffix(),
            Some(c) if matches!(c.to_ascii_lowercase(), 'a' | 'e' | 'i' | 'o' | 'u')
        );
        let cut = if keeps_e { 1 } else { 2 };
        name[..name.len() - cut].to_string()
    } else if name.ends_with('s') && !name.ends_with("ss") {
        // users -> user
        name[..name.len() - 1].to_string()
    } else {
        name.to_string()
    };

    // A lone "s" would otherwise collapse to "" and yield an empty type name.
    if singular.is_empty() {
        name.to_string()
    } else {
        singular
    }
}
51
52fn get_manifest_for_backend(backend_name: &str) -> Result<BackendManifest, ScytheError> {
57 match backend_name {
58 "rust-sqlx" | "sqlx" => {
59 let b = backends::sqlx::SqlxBackend::new()?;
60 Ok(b.manifest().clone())
61 }
62 "rust-tokio-postgres" | "tokio-postgres" => {
63 let b = backends::tokio_postgres::TokioPostgresBackend::new()?;
64 Ok(b.manifest().clone())
65 }
66 "go-pgx" => {
67 let b = backends::go_pgx::GoPgxBackend::new()?;
68 Ok(b.manifest().clone())
69 }
70 "java-jdbc" => {
71 let b = backends::java_jdbc::JavaJdbcBackend::new()?;
72 Ok(b.manifest().clone())
73 }
74 "kotlin-jdbc" => {
75 let b = backends::kotlin_jdbc::KotlinJdbcBackend::new()?;
76 Ok(b.manifest().clone())
77 }
78 "python-psycopg3" => {
79 let b = backends::python_psycopg3::PythonPsycopg3Backend::new()?;
80 Ok(b.manifest().clone())
81 }
82 "python-asyncpg" => {
83 let b = backends::python_asyncpg::PythonAsyncpgBackend::new()?;
84 Ok(b.manifest().clone())
85 }
86 "typescript-postgres" => {
87 let b = backends::typescript_postgres::TypescriptPostgresBackend::new()?;
88 Ok(b.manifest().clone())
89 }
90 "typescript-pg" => {
91 let b = backends::typescript_pg::TypescriptPgBackend::new()?;
92 Ok(b.manifest().clone())
93 }
94 "csharp-npgsql" => {
95 let b = backends::csharp_npgsql::CsharpNpgsqlBackend::new()?;
96 Ok(b.manifest().clone())
97 }
98 "elixir-postgrex" => {
99 let b = backends::elixir_postgrex::ElixirPostgrexBackend::new()?;
100 Ok(b.manifest().clone())
101 }
102 "ruby-pg" => {
103 let b = backends::ruby_pg::RubyPgBackend::new()?;
104 Ok(b.manifest().clone())
105 }
106 "php-pdo" => {
107 let b = backends::php_pdo::PhpPdoBackend::new()?;
108 Ok(b.manifest().clone())
109 }
110 _ => {
111 use scythe_core::errors::ErrorCode;
112 Err(ScytheError::new(
113 ErrorCode::InternalError,
114 format!("unknown backend: {}", backend_name),
115 ))
116 }
117 }
118}
119
120fn determine_struct_name(analyzed: &AnalyzedQuery, manifest: &BackendManifest) -> String {
122 if let Some(ref table_name) = analyzed.source_table {
123 let singular = singularize(table_name);
124 to_pascal_case(&singular).into_owned()
125 } else {
126 row_struct_name(&analyzed.name, &manifest.naming)
127 }
128}
129
130pub fn generate_with_backend(
136 analyzed: &AnalyzedQuery,
137 backend: &dyn CodegenBackend,
138) -> Result<GeneratedCode, ScytheError> {
139 let manifest = get_manifest_for_backend(backend.name())?;
140 let columns = resolve::resolve_columns(&analyzed.columns, &manifest)?;
141 let params = resolve::resolve_params(&analyzed.params, &manifest)?;
142
143 let mut result = GeneratedCode::default();
144
145 let enum_def = generate_enum_defs_via_backend(analyzed, backend)?;
148 if !enum_def.is_empty() {
149 result.enum_def = Some(enum_def);
150 }
151
152 let needs_row_struct = matches!(analyzed.command, QueryCommand::One | QueryCommand::Many);
154 if needs_row_struct && !analyzed.columns.is_empty() {
155 if let Some(ref table_name) = analyzed.source_table {
156 result.model_struct = Some(backend.generate_model_struct(table_name, &columns)?);
157 } else {
158 result.row_struct = Some(backend.generate_row_struct(&analyzed.name, &columns)?);
159 }
160 }
161
162 if !analyzed.composites.is_empty() {
164 let mut comp_defs = String::new();
165 for (i, comp) in analyzed.composites.iter().enumerate() {
166 if i > 0 {
167 comp_defs.push_str("\n\n");
168 }
169 comp_defs.push_str(&backend.generate_composite_def(comp)?);
170 }
171 if !comp_defs.is_empty() {
172 if let Some(ref mut existing) = result.model_struct {
173 existing.push_str("\n\n");
174 existing.push_str(&comp_defs);
175 } else {
176 result.model_struct = Some(comp_defs);
177 }
178 }
179 }
180
181 let struct_name = determine_struct_name(analyzed, &manifest);
183 result.query_fn = Some(backend.generate_query_fn(analyzed, &struct_name, &columns, ¶ms)?);
184
185 Ok(result)
186}
187
188fn generate_enum_defs_via_backend(
190 analyzed: &AnalyzedQuery,
191 backend: &dyn CodegenBackend,
192) -> Result<String, ScytheError> {
193 use ahash::AHashSet;
194 use std::fmt::Write;
195
196 let mut out = String::new();
197 let mut seen_enums: AHashSet<String> = AHashSet::new();
198
199 let enum_sources: Vec<&str> = analyzed
200 .columns
201 .iter()
202 .filter_map(|col| col.neutral_type.strip_prefix("enum::"))
203 .chain(
204 analyzed
205 .params
206 .iter()
207 .filter_map(|p| p.neutral_type.strip_prefix("enum::")),
208 )
209 .collect();
210
211 for sql_name in enum_sources {
212 if !seen_enums.insert(sql_name.to_string()) {
213 continue;
214 }
215
216 if !out.is_empty() {
217 let _ = writeln!(out);
218 }
219
220 if let Some(enum_info) = analyzed.enums.iter().find(|e| e.sql_name == sql_name) {
221 out.push_str(&backend.generate_enum_def(enum_info)?);
222 } else {
223 let stub_info = EnumInfo {
226 sql_name: sql_name.to_string(),
227 values: vec![],
228 };
229 out.push_str(&backend.generate_enum_def(&stub_info)?);
230 }
231 }
232
233 Ok(out)
234}
235
236pub fn generate(analyzed: &AnalyzedQuery) -> Result<GeneratedCode, ScytheError> {
238 let backend = get_backend("rust-sqlx")?;
239 generate_with_backend(analyzed, &*backend)
240}
241
242pub fn generate_from_catalog(_catalog: &Catalog) -> Result<GeneratedCode, ScytheError> {
244 Ok(GeneratedCode::default())
245}
246
247pub fn generate_single_enum_def_with_backend(
249 enum_info: &EnumInfo,
250 backend: &dyn CodegenBackend,
251) -> Result<String, ScytheError> {
252 backend.generate_enum_def(enum_info)
253}
254
255pub fn generate_single_enum_def(enum_info: &EnumInfo, manifest: &BackendManifest) -> String {
258 use scythe_backend::naming::{enum_type_name, enum_variant_name};
260 use std::fmt::Write;
261
262 let mut out = String::with_capacity(256);
263 let type_name = enum_type_name(&enum_info.sql_name, &manifest.naming);
264
265 let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq, sqlx::Type)]");
266 let _ = writeln!(
267 out,
268 "#[sqlx(type_name = \"{}\", rename_all = \"snake_case\")]",
269 enum_info.sql_name
270 );
271 let _ = writeln!(out, "pub enum {type_name} {{");
272
273 for value in &enum_info.values {
274 let variant = enum_variant_name(value, &manifest.naming);
275 let _ = writeln!(out, " {variant},");
276 }
277
278 let _ = write!(out, "}}");
279 out
280}
281
282pub fn load_or_default_manifest() -> Result<BackendManifest, ScytheError> {
284 let b = backends::sqlx::SqlxBackend::new()?;
285 Ok(b.manifest().clone())
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use scythe_core::analyzer::{AnalyzedColumn, AnalyzedParam, AnalyzedQuery};
    use scythe_core::parser::QueryCommand;

    /// Builds a query with no source table, composites, or enums attached.
    fn make_query(
        name: &str,
        command: QueryCommand,
        sql: &str,
        columns: Vec<AnalyzedColumn>,
        params: Vec<AnalyzedParam>,
    ) -> AnalyzedQuery {
        AnalyzedQuery {
            name: name.to_string(),
            command,
            sql: sql.to_string(),
            columns,
            params,
            deprecated: None,
            source_table: None,
            composites: Vec::new(),
            enums: Vec::new(),
        }
    }

    /// Shorthand for an analyzed result column.
    fn col(name: &str, neutral_type: &str, nullable: bool) -> AnalyzedColumn {
        AnalyzedColumn {
            name: name.to_string(),
            neutral_type: neutral_type.to_string(),
            nullable,
        }
    }

    /// The non-nullable `$1` parameter `id: int32` shared by several tests.
    fn id_param() -> AnalyzedParam {
        AnalyzedParam {
            name: "id".to_string(),
            neutral_type: "int32".to_string(),
            nullable: false,
            position: 1,
        }
    }

    /// Asserts that `haystack` contains every snippet in `needles`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for needle in needles {
            assert!(
                haystack.contains(needle),
                "expected `{needle}` in generated code:\n{haystack}"
            );
        }
    }

    #[test]
    fn test_generate_select_many() {
        let query = make_query(
            "ListUsers",
            QueryCommand::Many,
            "SELECT id, name, email FROM users",
            vec![
                col("id", "int32", false),
                col("name", "string", false),
                col("email", "string", true),
            ],
            vec![],
        );

        let generated = generate(&query).unwrap();

        assert_contains_all(
            &generated.row_struct.unwrap(),
            &[
                "pub struct ListUsersRow",
                "pub id: i32",
                "pub name: String",
                "pub email: Option<String>",
            ],
        );
        assert_contains_all(
            &generated.query_fn.unwrap(),
            &[
                "pub async fn list_users(",
                "-> Result<Vec<ListUsersRow>, sqlx::Error>",
                ".fetch_all(pool)",
            ],
        );
    }

    #[test]
    fn test_generate_select_one_with_param() {
        let query = make_query(
            "GetUser",
            QueryCommand::One,
            "SELECT id, name FROM users WHERE id = $1",
            vec![col("id", "int32", false), col("name", "string", false)],
            vec![id_param()],
        );

        let generated = generate(&query).unwrap();

        assert_contains_all(
            &generated.query_fn.unwrap(),
            &[
                "pub async fn get_user(",
                "id: i32",
                "-> Result<GetUserRow, sqlx::Error>",
                ".fetch_one(pool)",
            ],
        );
    }

    #[test]
    fn test_generate_exec() {
        let query = make_query(
            "DeleteUser",
            QueryCommand::Exec,
            "DELETE FROM users WHERE id = $1",
            vec![],
            vec![id_param()],
        );

        let generated = generate(&query).unwrap();

        assert!(generated.row_struct.is_none());
        assert_contains_all(
            &generated.query_fn.unwrap(),
            &[
                "pub async fn delete_user(",
                "-> Result<(), sqlx::Error>",
                ".execute(pool)",
            ],
        );
    }

    #[test]
    fn test_generate_with_enum_column() {
        let query = make_query(
            "GetUserStatus",
            QueryCommand::One,
            "SELECT id, status FROM users WHERE id = $1",
            vec![
                col("id", "int32", false),
                col("status", "enum::user_status", false),
            ],
            vec![id_param()],
        );

        let generated = generate(&query).unwrap();

        let enum_def = generated
            .enum_def
            .expect("an enum column should produce an enum definition");
        assert_contains_all(
            &enum_def,
            &["pub enum UserStatus", "type_name = \"user_status\""],
        );
        assert_contains_all(&generated.row_struct.unwrap(), &["pub status: UserStatus"]);
    }

    #[test]
    fn test_generate_from_catalog_returns_default() {
        let catalog = Catalog::from_ddl(&["CREATE TABLE t (id INTEGER);"]).unwrap();
        let generated = generate_from_catalog(&catalog).unwrap();
        assert!(generated.query_fn.is_none());
        assert!(generated.row_struct.is_none());
    }

    #[test]
    fn test_singularize_basic() {
        for (plural, singular) in [("users", "user"), ("orders", "order"), ("posts", "post")] {
            assert_eq!(singularize(plural), singular);
        }
    }

    #[test]
    fn test_singularize_ies() {
        for (plural, singular) in [("categories", "category"), ("entries", "entry")] {
            assert_eq!(singularize(plural), singular);
        }
    }

    #[test]
    fn test_singularize_sses() {
        for (plural, singular) in [("addresses", "address"), ("classes", "class")] {
            assert_eq!(singularize(plural), singular);
        }
    }

    #[test]
    fn test_singularize_no_change() {
        // A single trailing `s` is always stripped, so `status` becomes `statu`;
        // words ending in `ss` are left untouched.
        assert_eq!(singularize("status"), "statu");
        assert_eq!(singularize("boss"), "boss");
        assert_eq!(singularize("address"), "address");
    }

    #[test]
    fn test_singularize_shes_ches_xes() {
        for (plural, singular) in [("batches", "batch"), ("boxes", "box"), ("wishes", "wish")] {
            assert_eq!(singularize(plural), singular);
        }
    }

    #[test]
    fn test_tokio_postgres_backend_basic() {
        let backend = get_backend("tokio-postgres").unwrap();

        let query = make_query(
            "ListUsers",
            QueryCommand::Many,
            "SELECT id, name FROM users",
            vec![col("id", "int32", false), col("name", "string", false)],
            vec![],
        );

        let generated = generate_with_backend(&query, &*backend).unwrap();

        let row = generated.row_struct.unwrap();
        assert_contains_all(
            &row,
            &[
                "pub struct ListUsersRow",
                "pub id: i32",
                "pub name: String",
                "from_row",
                "tokio_postgres::Row",
            ],
        );
        assert!(!row.contains("sqlx"));

        let query_fn = generated.query_fn.unwrap();
        assert_contains_all(
            &query_fn,
            &[
                "pub async fn list_users(",
                "tokio_postgres::Client",
                "tokio_postgres::Error",
            ],
        );
        assert!(!query_fn.contains("sqlx"));
    }

    #[test]
    fn test_tokio_postgres_enum() {
        let backend = get_backend("tokio-postgres").unwrap();

        let enum_info = EnumInfo {
            sql_name: "user_status".to_string(),
            values: vec!["active".to_string(), "inactive".to_string()],
        };

        let def = backend.generate_enum_def(&enum_info).unwrap();
        assert_contains_all(
            &def,
            &[
                "pub enum UserStatus",
                "Active",
                "Inactive",
                "impl std::fmt::Display",
                "impl std::str::FromStr",
            ],
        );
        assert!(!def.contains("sqlx"));
    }
}