Skip to main content

scythe_codegen/
lib.rs

1pub mod backend_trait;
2pub mod backends;
3pub mod resolve;
4pub mod validation;
5
6pub use backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
7pub use backends::get_backend;
8
9use scythe_backend::manifest::BackendManifest;
10use scythe_backend::naming::{row_struct_name, to_pascal_case};
11
12use scythe_core::analyzer::{AnalyzedQuery, EnumInfo};
13use scythe_core::catalog::Catalog;
14use scythe_core::errors::ScytheError;
15use scythe_core::parser::QueryCommand;
16
17// ---------------------------------------------------------------------------
18// Output types
19// ---------------------------------------------------------------------------
20
/// Source fragments produced for one analyzed query.
///
/// Every field is optional; [`Default`] yields a value with all fields `None`.
/// The field notes below describe how `generate_with_backend` fills them in.
#[derive(Default, Debug)]
pub struct GeneratedCode {
    /// Query function source returned by the backend's `generate_query_fn`.
    pub query_fn: Option<String>,
    /// Row struct source; set for `:one`/`:many` queries that return columns
    /// and have no source table.
    pub row_struct: Option<String>,
    /// Model struct source (when the query has a source table), followed by
    /// any composite type definitions. Holds only the composite definitions
    /// when no model struct was generated.
    pub model_struct: Option<String>,
    /// Enum definitions for enum-typed columns and parameters; `None` when
    /// the generated text is empty.
    pub enum_def: Option<String>,
}
28
29// ---------------------------------------------------------------------------
30// Utility (shared across backends)
31// ---------------------------------------------------------------------------
32
/// Heuristically convert a plural English noun into its singular form.
///
/// Rules are purely suffix based and applied in order; the first match wins:
///
/// 1. `…ies` becomes `…y` (`categories` -> `category`).
/// 2. `…sses`, `…shes`, `…ches`, `…xes`, `…zes` and `…ses` lose their final
///    `es` (`addresses` -> `address`, `boxes` -> `box`).
/// 3. Any other trailing `s` that is not part of `ss` is dropped
///    (`users` -> `user`).
/// 4. Anything else is returned unchanged (`boss` -> `boss`).
///
/// Known limitation: a noun that is already singular but ends in a single
/// `s` is still trimmed by rule 3 (`status` -> `statu`).
pub(crate) fn singularize(name: &str) -> String {
    const ES_PLURAL_SUFFIXES: [&str; 6] = ["sses", "shes", "ches", "xes", "zes", "ses"];

    if let Some(stem) = name.strip_suffix("ies") {
        return format!("{stem}y");
    }

    if ES_PLURAL_SUFFIXES
        .iter()
        .any(|&suffix| name.ends_with(suffix))
    {
        // Every listed suffix ends in the ASCII bytes "es"; only those two
        // bytes are removed, so the slice always ends on a char boundary.
        return name[..name.len() - 2].to_string();
    }

    match name.strip_suffix('s') {
        // `stem` ending in 's' means the original ended in "ss": keep it.
        Some(stem) if !stem.ends_with('s') => stem.to_string(),
        _ => name.to_string(),
    }
}
51
52// ---------------------------------------------------------------------------
53// Manifest helpers
54// ---------------------------------------------------------------------------
55
56fn get_manifest_for_backend(backend_name: &str) -> Result<BackendManifest, ScytheError> {
57    match backend_name {
58        "rust-sqlx" | "sqlx" => {
59            let b = backends::sqlx::SqlxBackend::new()?;
60            Ok(b.manifest().clone())
61        }
62        "rust-tokio-postgres" | "tokio-postgres" => {
63            let b = backends::tokio_postgres::TokioPostgresBackend::new()?;
64            Ok(b.manifest().clone())
65        }
66        "go-pgx" => {
67            let b = backends::go_pgx::GoPgxBackend::new()?;
68            Ok(b.manifest().clone())
69        }
70        "java-jdbc" => {
71            let b = backends::java_jdbc::JavaJdbcBackend::new()?;
72            Ok(b.manifest().clone())
73        }
74        "kotlin-jdbc" => {
75            let b = backends::kotlin_jdbc::KotlinJdbcBackend::new()?;
76            Ok(b.manifest().clone())
77        }
78        "python-psycopg3" => {
79            let b = backends::python_psycopg3::PythonPsycopg3Backend::new()?;
80            Ok(b.manifest().clone())
81        }
82        "python-asyncpg" => {
83            let b = backends::python_asyncpg::PythonAsyncpgBackend::new()?;
84            Ok(b.manifest().clone())
85        }
86        "typescript-postgres" => {
87            let b = backends::typescript_postgres::TypescriptPostgresBackend::new()?;
88            Ok(b.manifest().clone())
89        }
90        "typescript-pg" => {
91            let b = backends::typescript_pg::TypescriptPgBackend::new()?;
92            Ok(b.manifest().clone())
93        }
94        "csharp-npgsql" => {
95            let b = backends::csharp_npgsql::CsharpNpgsqlBackend::new()?;
96            Ok(b.manifest().clone())
97        }
98        "elixir-postgrex" => {
99            let b = backends::elixir_postgrex::ElixirPostgrexBackend::new()?;
100            Ok(b.manifest().clone())
101        }
102        "ruby-pg" => {
103            let b = backends::ruby_pg::RubyPgBackend::new()?;
104            Ok(b.manifest().clone())
105        }
106        "php-pdo" => {
107            let b = backends::php_pdo::PhpPdoBackend::new()?;
108            Ok(b.manifest().clone())
109        }
110        _ => {
111            use scythe_core::errors::ErrorCode;
112            Err(ScytheError::new(
113                ErrorCode::InternalError,
114                format!("unknown backend: {}", backend_name),
115            ))
116        }
117    }
118}
119
120/// Determine the struct name for a query (model struct or row struct).
121fn determine_struct_name(analyzed: &AnalyzedQuery, manifest: &BackendManifest) -> String {
122    if let Some(ref table_name) = analyzed.source_table {
123        let singular = singularize(table_name);
124        to_pascal_case(&singular).into_owned()
125    } else {
126        row_struct_name(&analyzed.name, &manifest.naming)
127    }
128}
129
130// ---------------------------------------------------------------------------
131// Public API
132// ---------------------------------------------------------------------------
133
134/// Generate code using a specific backend.
135pub fn generate_with_backend(
136    analyzed: &AnalyzedQuery,
137    backend: &dyn CodegenBackend,
138) -> Result<GeneratedCode, ScytheError> {
139    let manifest = get_manifest_for_backend(backend.name())?;
140    let columns = resolve::resolve_columns(&analyzed.columns, &manifest)?;
141    let params = resolve::resolve_params(&analyzed.params, &manifest)?;
142
143    let mut result = GeneratedCode::default();
144
145    // Generate enum definitions for any enum-typed columns
146    // Use the backend-specific enum generation for proper derives
147    let enum_def = generate_enum_defs_via_backend(analyzed, backend)?;
148    if !enum_def.is_empty() {
149        result.enum_def = Some(enum_def);
150    }
151
152    // Generate row/model struct for :one and :many commands
153    let needs_row_struct = matches!(analyzed.command, QueryCommand::One | QueryCommand::Many);
154    if needs_row_struct && !analyzed.columns.is_empty() {
155        if let Some(ref table_name) = analyzed.source_table {
156            result.model_struct = Some(backend.generate_model_struct(table_name, &columns)?);
157        } else {
158            result.row_struct = Some(backend.generate_row_struct(&analyzed.name, &columns)?);
159        }
160    }
161
162    // Generate composite type definitions
163    if !analyzed.composites.is_empty() {
164        let mut comp_defs = String::new();
165        for (i, comp) in analyzed.composites.iter().enumerate() {
166            if i > 0 {
167                comp_defs.push_str("\n\n");
168            }
169            comp_defs.push_str(&backend.generate_composite_def(comp)?);
170        }
171        if !comp_defs.is_empty() {
172            if let Some(ref mut existing) = result.model_struct {
173                existing.push_str("\n\n");
174                existing.push_str(&comp_defs);
175            } else {
176                result.model_struct = Some(comp_defs);
177            }
178        }
179    }
180
181    // Generate query function
182    let struct_name = determine_struct_name(analyzed, &manifest);
183    result.query_fn = Some(backend.generate_query_fn(analyzed, &struct_name, &columns, &params)?);
184
185    Ok(result)
186}
187
188/// Generate enum definitions via the backend trait.
189fn generate_enum_defs_via_backend(
190    analyzed: &AnalyzedQuery,
191    backend: &dyn CodegenBackend,
192) -> Result<String, ScytheError> {
193    use ahash::AHashSet;
194    use std::fmt::Write;
195
196    let mut out = String::new();
197    let mut seen_enums: AHashSet<String> = AHashSet::new();
198
199    let enum_sources: Vec<&str> = analyzed
200        .columns
201        .iter()
202        .filter_map(|col| col.neutral_type.strip_prefix("enum::"))
203        .chain(
204            analyzed
205                .params
206                .iter()
207                .filter_map(|p| p.neutral_type.strip_prefix("enum::")),
208        )
209        .collect();
210
211    for sql_name in enum_sources {
212        if !seen_enums.insert(sql_name.to_string()) {
213            continue;
214        }
215
216        if !out.is_empty() {
217            let _ = writeln!(out);
218        }
219
220        if let Some(enum_info) = analyzed.enums.iter().find(|e| e.sql_name == sql_name) {
221            out.push_str(&backend.generate_enum_def(enum_info)?);
222        } else {
223            // Generate a stub enum with no variants (for enum types referenced but
224            // not fully defined in the query's EnumInfo list).
225            let stub_info = EnumInfo {
226                sql_name: sql_name.to_string(),
227                values: vec![],
228            };
229            out.push_str(&backend.generate_enum_def(&stub_info)?);
230        }
231    }
232
233    Ok(out)
234}
235
236/// Backward-compatible: generate code using the default sqlx backend.
237pub fn generate(analyzed: &AnalyzedQuery) -> Result<GeneratedCode, ScytheError> {
238    let backend = get_backend("rust-sqlx")?;
239    generate_with_backend(analyzed, &*backend)
240}
241
242/// Stub for catalog-level codegen. Returns default for now.
243pub fn generate_from_catalog(_catalog: &Catalog) -> Result<GeneratedCode, ScytheError> {
244    Ok(GeneratedCode::default())
245}
246
247/// Generate a single enum definition using a specific backend.
248pub fn generate_single_enum_def_with_backend(
249    enum_info: &EnumInfo,
250    backend: &dyn CodegenBackend,
251) -> Result<String, ScytheError> {
252    backend.generate_enum_def(enum_info)
253}
254
255/// Backward-compatible: generate a single enum definition (sqlx backend).
256/// Uses the manifest directly for backward compatibility with existing callers.
257pub fn generate_single_enum_def(enum_info: &EnumInfo, manifest: &BackendManifest) -> String {
258    // Reproduce the old behavior exactly using the sqlx backend's logic
259    use scythe_backend::naming::{enum_type_name, enum_variant_name};
260    use std::fmt::Write;
261
262    let mut out = String::with_capacity(256);
263    let type_name = enum_type_name(&enum_info.sql_name, &manifest.naming);
264
265    let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq, sqlx::Type)]");
266    let _ = writeln!(
267        out,
268        "#[sqlx(type_name = \"{}\", rename_all = \"snake_case\")]",
269        enum_info.sql_name
270    );
271    let _ = writeln!(out, "pub enum {type_name} {{");
272
273    for value in &enum_info.values {
274        let variant = enum_variant_name(value, &manifest.naming);
275        let _ = writeln!(out, "    {variant},");
276    }
277
278    let _ = write!(out, "}}");
279    out
280}
281
282/// Backward-compatible: load the default sqlx manifest.
283pub fn load_or_default_manifest() -> Result<BackendManifest, ScytheError> {
284    let b = backends::sqlx::SqlxBackend::new()?;
285    Ok(b.manifest().clone())
286}
287
288// ---------------------------------------------------------------------------
289// Tests
290// ---------------------------------------------------------------------------
291
#[cfg(test)]
mod tests {
    use super::*;
    use scythe_core::analyzer::{AnalyzedColumn, AnalyzedParam, AnalyzedQuery};
    use scythe_core::parser::QueryCommand;

    /// Build an `AnalyzedQuery` with no source table, composites, or enums.
    fn make_query(
        name: &str,
        command: QueryCommand,
        sql: &str,
        columns: Vec<AnalyzedColumn>,
        params: Vec<AnalyzedParam>,
    ) -> AnalyzedQuery {
        AnalyzedQuery {
            name: name.to_string(),
            command,
            sql: sql.to_string(),
            columns,
            params,
            deprecated: None,
            source_table: None,
            composites: Vec::new(),
            enums: Vec::new(),
        }
    }

    /// Shorthand for a result column.
    fn column(name: &str, neutral_type: &str, nullable: bool) -> AnalyzedColumn {
        AnalyzedColumn {
            name: name.to_string(),
            neutral_type: neutral_type.to_string(),
            nullable,
        }
    }

    /// The non-nullable `id: int32` parameter at position 1 used by several tests.
    fn id_param() -> AnalyzedParam {
        AnalyzedParam {
            name: "id".to_string(),
            neutral_type: "int32".to_string(),
            nullable: false,
            position: 1,
        }
    }

    #[test]
    fn test_generate_select_many() {
        let query = make_query(
            "ListUsers",
            QueryCommand::Many,
            "SELECT id, name, email FROM users",
            vec![
                column("id", "int32", false),
                column("name", "string", false),
                column("email", "string", true),
            ],
            vec![],
        );

        let result = generate(&query).unwrap();

        let row_struct = result.row_struct.unwrap();
        assert!(row_struct.contains("pub struct ListUsersRow"));
        assert!(row_struct.contains("pub id: i32"));
        assert!(row_struct.contains("pub name: String"));
        assert!(row_struct.contains("pub email: Option<String>"));

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn list_users("));
        assert!(query_fn.contains("-> Result<Vec<ListUsersRow>, sqlx::Error>"));
        assert!(query_fn.contains(".fetch_all(pool)"));
    }

    #[test]
    fn test_generate_select_one_with_param() {
        let query = make_query(
            "GetUser",
            QueryCommand::One,
            "SELECT id, name FROM users WHERE id = $1",
            vec![
                column("id", "int32", false),
                column("name", "string", false),
            ],
            vec![id_param()],
        );

        let result = generate(&query).unwrap();

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn get_user("));
        assert!(query_fn.contains("id: i32"));
        assert!(query_fn.contains("-> Result<GetUserRow, sqlx::Error>"));
        assert!(query_fn.contains(".fetch_one(pool)"));
    }

    #[test]
    fn test_generate_exec() {
        let query = make_query(
            "DeleteUser",
            QueryCommand::Exec,
            "DELETE FROM users WHERE id = $1",
            vec![],
            vec![id_param()],
        );

        let result = generate(&query).unwrap();

        assert!(result.row_struct.is_none());

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn delete_user("));
        assert!(query_fn.contains("-> Result<(), sqlx::Error>"));
        assert!(query_fn.contains(".execute(pool)"));
    }

    #[test]
    fn test_generate_with_enum_column() {
        let query = make_query(
            "GetUserStatus",
            QueryCommand::One,
            "SELECT id, status FROM users WHERE id = $1",
            vec![
                column("id", "int32", false),
                column("status", "enum::user_status", false),
            ],
            vec![id_param()],
        );

        let result = generate(&query).unwrap();

        assert!(result.enum_def.is_some());
        let enum_def = result.enum_def.unwrap();
        assert!(enum_def.contains("pub enum UserStatus"));
        assert!(enum_def.contains("type_name = \"user_status\""));

        let row_struct = result.row_struct.unwrap();
        assert!(row_struct.contains("pub status: UserStatus"));
    }

    #[test]
    fn test_generate_from_catalog_returns_default() {
        let catalog = Catalog::from_ddl(&["CREATE TABLE t (id INTEGER);"]).unwrap();
        let result = generate_from_catalog(&catalog).unwrap();
        assert!(result.query_fn.is_none());
        assert!(result.row_struct.is_none());
    }

    #[test]
    fn test_singularize_basic() {
        assert_eq!(singularize("users"), "user");
        assert_eq!(singularize("orders"), "order");
        assert_eq!(singularize("posts"), "post");
    }

    #[test]
    fn test_singularize_ies() {
        assert_eq!(singularize("categories"), "category");
        assert_eq!(singularize("entries"), "entry");
    }

    #[test]
    fn test_singularize_sses() {
        assert_eq!(singularize("addresses"), "address");
        assert_eq!(singularize("classes"), "class");
    }

    #[test]
    fn test_singularize_double_s_unchanged() {
        assert_eq!(singularize("boss"), "boss");
        assert_eq!(singularize("address"), "address");
    }

    #[test]
    fn test_singularize_trims_single_trailing_s_of_singular_noun() {
        // Known limitation of the suffix heuristic: an already-singular noun
        // ending in one `s` still loses it. Pinned here so that any change to
        // this behavior is a deliberate one.
        assert_eq!(singularize("status"), "statu");
    }

    #[test]
    fn test_singularize_shes_ches_xes() {
        assert_eq!(singularize("batches"), "batch");
        assert_eq!(singularize("boxes"), "box");
        assert_eq!(singularize("wishes"), "wish");
    }

    #[test]
    fn test_tokio_postgres_backend_basic() {
        let backend = get_backend("tokio-postgres").unwrap();

        let query = make_query(
            "ListUsers",
            QueryCommand::Many,
            "SELECT id, name FROM users",
            vec![
                column("id", "int32", false),
                column("name", "string", false),
            ],
            vec![],
        );

        let result = generate_with_backend(&query, &*backend).unwrap();

        let row_struct = result.row_struct.unwrap();
        assert!(row_struct.contains("pub struct ListUsersRow"));
        assert!(row_struct.contains("pub id: i32"));
        assert!(row_struct.contains("pub name: String"));
        assert!(row_struct.contains("from_row"));
        assert!(row_struct.contains("tokio_postgres::Row"));
        // Should NOT contain sqlx
        assert!(!row_struct.contains("sqlx"));

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn list_users("));
        assert!(query_fn.contains("tokio_postgres::Client"));
        assert!(query_fn.contains("tokio_postgres::Error"));
        assert!(!query_fn.contains("sqlx"));
    }

    #[test]
    fn test_tokio_postgres_enum() {
        let backend = get_backend("tokio-postgres").unwrap();

        let enum_info = scythe_core::analyzer::EnumInfo {
            sql_name: "user_status".to_string(),
            values: vec!["active".to_string(), "inactive".to_string()],
        };

        let def = backend.generate_enum_def(&enum_info).unwrap();
        assert!(def.contains("pub enum UserStatus"));
        assert!(def.contains("Active"));
        assert!(def.contains("Inactive"));
        assert!(def.contains("impl std::fmt::Display"));
        assert!(def.contains("impl std::str::FromStr"));
        // Should NOT contain sqlx
        assert!(!def.contains("sqlx"));
    }
}