Skip to main content

scythe_codegen/
lib.rs

1pub mod backend_trait;
2pub mod backends;
3pub mod resolve;
4pub mod validation;
5
6pub use backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
7pub use backends::get_backend;
8
9use scythe_backend::manifest::BackendManifest;
10use scythe_backend::naming::{row_struct_name, to_pascal_case};
11
12use scythe_core::analyzer::{AnalyzedQuery, EnumInfo};
13use scythe_core::catalog::Catalog;
14use scythe_core::errors::ScytheError;
15use scythe_core::parser::QueryCommand;
16
17// ---------------------------------------------------------------------------
18// Output types
19// ---------------------------------------------------------------------------
20
21#[derive(Debug, Default)]
22pub struct GeneratedCode {
23    pub query_fn: Option<String>,
24    pub row_struct: Option<String>,
25    pub model_struct: Option<String>,
26    pub enum_def: Option<String>,
27}
28
29// ---------------------------------------------------------------------------
30// Utility (shared across backends)
31// ---------------------------------------------------------------------------
32
/// Naive, suffix-based English singularization.
///
/// Rules are tried in order and the first match wins:
///
/// 1. a trailing `ies` is replaced by `y` (`categories` -> `category`);
/// 2. names ending in `sses`, `shes`, `ches`, `xes`, `zes` or `ses` lose their
///    last two characters (`addresses` -> `address`, `boxes` -> `box`);
/// 3. names ending in a single `s` (but not `ss`) lose that `s`
///    (`users` -> `user`);
/// 4. anything else is returned unchanged (`boss` -> `boss`).
///
/// No dictionary is consulted, so singular words that merely end in `s` are
/// shortened as well (`status` -> `statu`).
pub(crate) fn singularize(name: &str) -> String {
    // Plural endings for which the trailing "es" is dropped.
    const ES_ENDINGS: [&str; 6] = ["sses", "shes", "ches", "xes", "zes", "ses"];

    if let Some(stem) = name.strip_suffix("ies") {
        return format!("{stem}y");
    }

    if ES_ENDINGS.iter().any(|ending| name.ends_with(ending)) {
        // Every ending is ASCII, so cutting two bytes stays on a char boundary.
        return name[..name.len() - 2].to_owned();
    }

    match name.strip_suffix('s') {
        // A stem that still ends in 's' means the name ended in "ss": keep it.
        Some(stem) if !stem.ends_with('s') => stem.to_owned(),
        _ => name.to_owned(),
    }
}
51
52// ---------------------------------------------------------------------------
53// Manifest helpers
54// ---------------------------------------------------------------------------
55
56/// Get the manifest for a backend. Defaults to PostgreSQL engine.
57pub fn get_manifest_for_backend(backend_name: &str) -> Result<BackendManifest, ScytheError> {
58    let backend = get_backend(backend_name, "postgresql")?;
59    Ok(backend.manifest().clone())
60}
61
62/// Determine the struct name for a query (model struct or row struct).
63fn determine_struct_name(analyzed: &AnalyzedQuery, manifest: &BackendManifest) -> String {
64    if let Some(ref table_name) = analyzed.source_table {
65        let singular = singularize(table_name);
66        to_pascal_case(&singular).into_owned()
67    } else {
68        row_struct_name(&analyzed.name, &manifest.naming)
69    }
70}
71
72// ---------------------------------------------------------------------------
73// Public API
74// ---------------------------------------------------------------------------
75
76/// Generate code using a specific backend.
77pub fn generate_with_backend(
78    analyzed: &AnalyzedQuery,
79    backend: &dyn CodegenBackend,
80) -> Result<GeneratedCode, ScytheError> {
81    let manifest = backend.manifest();
82    let columns = resolve::resolve_columns(&analyzed.columns, manifest)?;
83    let params = resolve::resolve_params(&analyzed.params, manifest)?;
84
85    let mut result = GeneratedCode::default();
86
87    // Generate enum definitions for any enum-typed columns
88    // Use the backend-specific enum generation for proper derives
89    let enum_def = generate_enum_defs_via_backend(analyzed, backend)?;
90    if !enum_def.is_empty() {
91        result.enum_def = Some(enum_def);
92    }
93
94    // Generate row/model struct for :one and :many commands
95    let needs_row_struct = matches!(analyzed.command, QueryCommand::One | QueryCommand::Many);
96    if needs_row_struct && !analyzed.columns.is_empty() {
97        if let Some(ref table_name) = analyzed.source_table {
98            result.model_struct = Some(backend.generate_model_struct(table_name, &columns)?);
99        } else {
100            result.row_struct = Some(backend.generate_row_struct(&analyzed.name, &columns)?);
101        }
102    }
103
104    // Generate composite type definitions
105    if !analyzed.composites.is_empty() {
106        let mut comp_defs = String::new();
107        for (i, comp) in analyzed.composites.iter().enumerate() {
108            if i > 0 {
109                comp_defs.push_str("\n\n");
110            }
111            comp_defs.push_str(&backend.generate_composite_def(comp)?);
112        }
113        if !comp_defs.is_empty() {
114            if let Some(ref mut existing) = result.model_struct {
115                existing.push_str("\n\n");
116                existing.push_str(&comp_defs);
117            } else {
118                result.model_struct = Some(comp_defs);
119            }
120        }
121    }
122
123    // Generate query function
124    let struct_name = determine_struct_name(analyzed, manifest);
125    result.query_fn = Some(backend.generate_query_fn(analyzed, &struct_name, &columns, &params)?);
126
127    Ok(result)
128}
129
130/// Generate enum definitions via the backend trait.
131fn generate_enum_defs_via_backend(
132    analyzed: &AnalyzedQuery,
133    backend: &dyn CodegenBackend,
134) -> Result<String, ScytheError> {
135    use ahash::AHashSet;
136    use std::fmt::Write;
137
138    let mut out = String::new();
139    let mut seen_enums: AHashSet<String> = AHashSet::new();
140
141    let enum_sources: Vec<&str> = analyzed
142        .columns
143        .iter()
144        .filter_map(|col| col.neutral_type.strip_prefix("enum::"))
145        .chain(
146            analyzed
147                .params
148                .iter()
149                .filter_map(|p| p.neutral_type.strip_prefix("enum::")),
150        )
151        .collect();
152
153    for sql_name in enum_sources {
154        if !seen_enums.insert(sql_name.to_string()) {
155            continue;
156        }
157
158        if !out.is_empty() {
159            let _ = writeln!(out);
160        }
161
162        if let Some(enum_info) = analyzed.enums.iter().find(|e| e.sql_name == sql_name) {
163            out.push_str(&backend.generate_enum_def(enum_info)?);
164        } else {
165            // Generate a stub enum with no variants (for enum types referenced but
166            // not fully defined in the query's EnumInfo list).
167            let stub_info = EnumInfo {
168                sql_name: sql_name.to_string(),
169                values: vec![],
170            };
171            out.push_str(&backend.generate_enum_def(&stub_info)?);
172        }
173    }
174
175    Ok(out)
176}
177
178/// Backward-compatible: generate code using the default sqlx backend.
179pub fn generate(analyzed: &AnalyzedQuery) -> Result<GeneratedCode, ScytheError> {
180    let backend = get_backend("rust-sqlx", "postgresql")?;
181    generate_with_backend(analyzed, &*backend)
182}
183
184/// Stub for catalog-level codegen. Returns default for now.
185pub fn generate_from_catalog(_catalog: &Catalog) -> Result<GeneratedCode, ScytheError> {
186    Ok(GeneratedCode::default())
187}
188
189/// Generate a single enum definition using a specific backend.
190pub fn generate_single_enum_def_with_backend(
191    enum_info: &EnumInfo,
192    backend: &dyn CodegenBackend,
193) -> Result<String, ScytheError> {
194    backend.generate_enum_def(enum_info)
195}
196
197/// Backward-compatible: generate a single enum definition (sqlx backend).
198/// Uses the manifest directly for backward compatibility with existing callers.
199pub fn generate_single_enum_def(enum_info: &EnumInfo, manifest: &BackendManifest) -> String {
200    // Reproduce the old behavior exactly using the sqlx backend's logic
201    use scythe_backend::naming::{enum_type_name, enum_variant_name};
202    use std::fmt::Write;
203
204    let mut out = String::with_capacity(256);
205    let type_name = enum_type_name(&enum_info.sql_name, &manifest.naming);
206
207    let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq, sqlx::Type)]");
208    let _ = writeln!(
209        out,
210        "#[sqlx(type_name = \"{}\", rename_all = \"snake_case\")]",
211        enum_info.sql_name
212    );
213    let _ = writeln!(out, "pub enum {type_name} {{");
214
215    for value in &enum_info.values {
216        let variant = enum_variant_name(value, &manifest.naming);
217        let _ = writeln!(out, "    {variant},");
218    }
219
220    let _ = write!(out, "}}");
221    out
222}
223
224/// Backward-compatible: load the default sqlx manifest.
225pub fn load_or_default_manifest() -> Result<BackendManifest, ScytheError> {
226    let b = backends::sqlx::SqlxBackend::new("postgresql")?;
227    Ok(b.manifest().clone())
228}
229
230// ---------------------------------------------------------------------------
231// Tests
232// ---------------------------------------------------------------------------
233
#[cfg(test)]
mod tests {
    use super::*;
    use scythe_core::analyzer::{AnalyzedColumn, AnalyzedParam, AnalyzedQuery};
    use scythe_core::parser::QueryCommand;

    /// Build an `AnalyzedQuery` with no source table, composites, enums or
    /// deprecation marker.
    fn make_query(
        name: &str,
        command: QueryCommand,
        sql: &str,
        columns: Vec<AnalyzedColumn>,
        params: Vec<AnalyzedParam>,
    ) -> AnalyzedQuery {
        AnalyzedQuery {
            name: name.to_string(),
            command,
            sql: sql.to_string(),
            columns,
            params,
            deprecated: None,
            source_table: None,
            composites: Vec::new(),
            enums: Vec::new(),
        }
    }

    /// Shorthand for a result column.
    fn column(name: &str, neutral_type: &str, nullable: bool) -> AnalyzedColumn {
        AnalyzedColumn {
            name: name.to_string(),
            neutral_type: neutral_type.to_string(),
            nullable,
        }
    }

    /// The non-nullable `id: int32` parameter at position 1 shared by several tests.
    fn id_param() -> AnalyzedParam {
        AnalyzedParam {
            name: "id".to_string(),
            neutral_type: "int32".to_string(),
            nullable: false,
            position: 1,
        }
    }

    #[test]
    fn test_generate_select_many() {
        let query = make_query(
            "ListUsers",
            QueryCommand::Many,
            "SELECT id, name, email FROM users",
            vec![
                column("id", "int32", false),
                column("name", "string", false),
                column("email", "string", true),
            ],
            vec![],
        );

        let result = generate(&query).unwrap();

        let row_struct = result.row_struct.unwrap();
        assert!(row_struct.contains("pub struct ListUsersRow"));
        assert!(row_struct.contains("pub id: i32"));
        assert!(row_struct.contains("pub name: String"));
        assert!(row_struct.contains("pub email: Option<String>"));

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn list_users("));
        assert!(query_fn.contains("-> Result<Vec<ListUsersRow>, sqlx::Error>"));
        assert!(query_fn.contains(".fetch_all(pool)"));
    }

    #[test]
    fn test_generate_select_one_with_param() {
        let query = make_query(
            "GetUser",
            QueryCommand::One,
            "SELECT id, name FROM users WHERE id = $1",
            vec![
                column("id", "int32", false),
                column("name", "string", false),
            ],
            vec![id_param()],
        );

        let result = generate(&query).unwrap();

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn get_user("));
        assert!(query_fn.contains("id: i32"));
        assert!(query_fn.contains("-> Result<GetUserRow, sqlx::Error>"));
        assert!(query_fn.contains(".fetch_one(pool)"));
    }

    #[test]
    fn test_generate_exec() {
        let query = make_query(
            "DeleteUser",
            QueryCommand::Exec,
            "DELETE FROM users WHERE id = $1",
            vec![],
            vec![id_param()],
        );

        let result = generate(&query).unwrap();

        // :exec queries return no rows, so no row struct is generated.
        assert!(result.row_struct.is_none());

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn delete_user("));
        assert!(query_fn.contains("-> Result<(), sqlx::Error>"));
        assert!(query_fn.contains(".execute(pool)"));
    }

    #[test]
    fn test_generate_with_enum_column() {
        // The query's `enums` list is empty, so this exercises the stub-enum path.
        let query = make_query(
            "GetUserStatus",
            QueryCommand::One,
            "SELECT id, status FROM users WHERE id = $1",
            vec![
                column("id", "int32", false),
                column("status", "enum::user_status", false),
            ],
            vec![id_param()],
        );

        let result = generate(&query).unwrap();

        assert!(result.enum_def.is_some());
        let enum_def = result.enum_def.unwrap();
        assert!(enum_def.contains("pub enum UserStatus"));
        assert!(enum_def.contains("type_name = \"user_status\""));

        let row_struct = result.row_struct.unwrap();
        assert!(row_struct.contains("pub status: UserStatus"));
    }

    #[test]
    fn test_generate_from_catalog_returns_default() {
        let catalog = Catalog::from_ddl(&["CREATE TABLE t (id INTEGER);"]).unwrap();
        let result = generate_from_catalog(&catalog).unwrap();
        assert!(result.query_fn.is_none());
        assert!(result.row_struct.is_none());
    }

    #[test]
    fn test_singularize_basic() {
        assert_eq!(singularize("users"), "user");
        assert_eq!(singularize("orders"), "order");
        assert_eq!(singularize("posts"), "post");
    }

    #[test]
    fn test_singularize_ies() {
        assert_eq!(singularize("categories"), "category");
        assert_eq!(singularize("entries"), "entry");
    }

    #[test]
    fn test_singularize_sses() {
        assert_eq!(singularize("addresses"), "address");
        assert_eq!(singularize("classes"), "class");
    }

    #[test]
    fn test_singularize_no_change() {
        assert_eq!(singularize("boss"), "boss");
        assert_eq!(singularize("address"), "address");
    }

    /// Pins a known limitation: `singularize` is purely suffix-based, so a
    /// singular word ending in a single `s` still loses it.
    #[test]
    fn test_singularize_strips_s_from_singular_word() {
        assert_eq!(singularize("status"), "statu");
    }

    #[test]
    fn test_singularize_shes_ches_xes() {
        assert_eq!(singularize("batches"), "batch");
        assert_eq!(singularize("boxes"), "box");
        assert_eq!(singularize("wishes"), "wish");
    }

    #[test]
    fn test_tokio_postgres_backend_basic() {
        let backend = get_backend("tokio-postgres", "postgresql").unwrap();

        let query = make_query(
            "ListUsers",
            QueryCommand::Many,
            "SELECT id, name FROM users",
            vec![
                column("id", "int32", false),
                column("name", "string", false),
            ],
            vec![],
        );

        let result = generate_with_backend(&query, &*backend).unwrap();

        let row_struct = result.row_struct.unwrap();
        assert!(row_struct.contains("pub struct ListUsersRow"));
        assert!(row_struct.contains("pub id: i32"));
        assert!(row_struct.contains("pub name: String"));
        assert!(row_struct.contains("from_row"));
        assert!(row_struct.contains("tokio_postgres::Row"));
        // Should NOT contain sqlx
        assert!(!row_struct.contains("sqlx"));

        let query_fn = result.query_fn.unwrap();
        assert!(query_fn.contains("pub async fn list_users("));
        assert!(query_fn.contains("tokio_postgres::Client"));
        assert!(query_fn.contains("tokio_postgres::Error"));
        assert!(!query_fn.contains("sqlx"));
    }

    #[test]
    fn test_tokio_postgres_enum() {
        let backend = get_backend("tokio-postgres", "postgresql").unwrap();

        let enum_info = scythe_core::analyzer::EnumInfo {
            sql_name: "user_status".to_string(),
            values: vec!["active".to_string(), "inactive".to_string()],
        };

        let def = backend.generate_enum_def(&enum_info).unwrap();
        assert!(def.contains("pub enum UserStatus"));
        assert!(def.contains("Active"));
        assert!(def.contains("Inactive"));
        assert!(def.contains("impl std::fmt::Display"));
        assert!(def.contains("impl std::str::FromStr"));
        // Should NOT contain sqlx
        assert!(!def.contains("sqlx"));
    }
}