// scythe_codegen/backends/typescript_postgres.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
/// Manifest TOML embedded into the binary at compile time.
///
/// `TypescriptPostgresBackend::new` parses this text when no manifest file is found at
/// its on-disk lookup path.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/typescript-postgres.toml");
18
/// Code generation backend registered under the name `typescript-postgres`.
///
/// The generated TypeScript imports the `Sql` type from the `postgres` package.
pub struct TypescriptPostgresBackend {
    /// Backend manifest; supplies the naming rules and type mappings used during generation.
    manifest: BackendManifest,
}
22
23impl TypescriptPostgresBackend {
24    pub fn new() -> Result<Self, ScytheError> {
25        let manifest_path = Path::new("backends/typescript-postgres/manifest.toml");
26        let manifest = if manifest_path.exists() {
27            load_manifest(manifest_path)
28                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
29        } else {
30            toml::from_str(DEFAULT_MANIFEST_TOML)
31                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
32        };
33        Ok(Self { manifest })
34    }
35
36    pub fn manifest(&self) -> &BackendManifest {
37        &self.manifest
38    }
39}
40
41impl CodegenBackend for TypescriptPostgresBackend {
42    fn name(&self) -> &str {
43        "typescript-postgres"
44    }
45
46    fn file_header(&self) -> String {
47        "/** Auto-generated by scythe. Do not edit. */\n\nimport type { Sql } from \"postgres\";\n"
48            .to_string()
49    }
50
51    fn generate_row_struct(
52        &self,
53        query_name: &str,
54        columns: &[ResolvedColumn],
55    ) -> Result<String, ScytheError> {
56        let struct_name = row_struct_name(query_name, &self.manifest.naming);
57        let mut out = String::new();
58        let _ = writeln!(out, "/** Row type for {} queries. */", query_name);
59        let _ = writeln!(out, "export interface {} {{", struct_name);
60        for col in columns {
61            let _ = writeln!(out, "\t{}: {};", col.field_name, col.full_type);
62        }
63        let _ = write!(out, "}}");
64        Ok(out)
65    }
66
67    fn generate_model_struct(
68        &self,
69        table_name: &str,
70        columns: &[ResolvedColumn],
71    ) -> Result<String, ScytheError> {
72        let singular = singularize(table_name);
73        let name = to_pascal_case(&singular);
74        self.generate_row_struct(&name, columns)
75    }
76
77    fn generate_query_fn(
78        &self,
79        analyzed: &AnalyzedQuery,
80        struct_name: &str,
81        _columns: &[ResolvedColumn],
82        params: &[ResolvedParam],
83    ) -> Result<String, ScytheError> {
84        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
85        let mut out = String::new();
86
87        // Build parameter list
88        let param_list = params
89            .iter()
90            .map(|p| format!("{}: {}", p.field_name, p.full_type))
91            .collect::<Vec<_>>()
92            .join(", ");
93        let _sep = if param_list.is_empty() { "" } else { ", " };
94
95        // Clean SQL and rewrite $1, $2 to ${paramName} for postgres.js tagged template
96        let sql_clean = super::clean_sql(&analyzed.sql);
97        let sql_template = rewrite_params_template(&sql_clean, analyzed, params);
98
99        // Build function params: inline if short, multi-line if long (biome compliance)
100        let inline_params = if params.is_empty() {
101            "sql: Sql".to_string()
102        } else {
103            format!("sql: Sql, {}", param_list)
104        };
105
106        // We'll decide inline vs multi-line per call site based on the full signature length
107
108        // Helper: write function signature, inline or multi-line based on length
109        let write_fn_sig = |out: &mut String, name: &str, params_inline: &str, ret: &str| {
110            let oneliner = format!(
111                "export async function {}({}): {} {{",
112                name, params_inline, ret
113            );
114            if oneliner.len() <= 80 {
115                let _ = writeln!(out, "{}", oneliner);
116            } else {
117                // Multi-line params
118                let mut parts = vec!["\tsql: Sql".to_string()];
119                for p in params {
120                    parts.push(format!("\t{}: {}", p.field_name, p.full_type));
121                }
122                let _ = writeln!(out, "export async function {}(", name);
123                for part in &parts {
124                    let _ = writeln!(out, "{},", part);
125                }
126                let _ = writeln!(out, "): {} {{", ret);
127            }
128        };
129
130        match &analyzed.command {
131            QueryCommand::One => {
132                let _ = writeln!(out, "/** Fetch a single {} or null. */", struct_name);
133                let ret = format!("Promise<{} | null>", struct_name);
134                write_fn_sig(&mut out, &func_name, &inline_params, &ret);
135                let _ = writeln!(out, "\tconst rows = await sql<{}[]>`", struct_name);
136                let _ = writeln!(out, "    {}", sql_template);
137                let _ = writeln!(out, "  `;");
138                let _ = writeln!(out, "\treturn rows[0] ?? null;");
139                let _ = write!(out, "}}");
140            }
141            QueryCommand::Many | QueryCommand::Batch => {
142                let _ = writeln!(out, "/** Fetch all {} rows. */", struct_name);
143                let ret = format!("Promise<{}[]>", struct_name);
144                write_fn_sig(&mut out, &func_name, &inline_params, &ret);
145                let _ = writeln!(out, "\tconst rows = await sql<{}[]>`", struct_name);
146                let _ = writeln!(out, "    {}", sql_template);
147                let _ = writeln!(out, "  `;");
148                let _ = writeln!(out, "\treturn rows;");
149                let _ = write!(out, "}}");
150            }
151            QueryCommand::Exec => {
152                let _ = writeln!(out, "/** Execute a query returning no rows. */");
153                write_fn_sig(&mut out, &func_name, &inline_params, "Promise<void>");
154                let _ = writeln!(out, "\tawait sql`");
155                let _ = writeln!(out, "    {}", sql_template);
156                let _ = writeln!(out, "  `;");
157                let _ = write!(out, "}}");
158            }
159            QueryCommand::ExecResult | QueryCommand::ExecRows => {
160                let _ = writeln!(
161                    out,
162                    "/** Execute a query and return the number of affected rows. */"
163                );
164                write_fn_sig(&mut out, &func_name, &inline_params, "Promise<number>");
165                let _ = writeln!(out, "\tconst result = await sql`");
166                let _ = writeln!(out, "    {}", sql_template);
167                let _ = writeln!(out, "  `;");
168                let _ = writeln!(out, "\treturn result.count;");
169                let _ = write!(out, "}}");
170            }
171        }
172
173        Ok(out)
174    }
175
176    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
177        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
178        let mut out = String::new();
179        let _ = writeln!(out, "export enum {} {{", type_name);
180        for value in &enum_info.values {
181            let variant = enum_variant_name(value, &self.manifest.naming);
182            let _ = writeln!(out, "\t{} = \"{}\",", variant, value);
183        }
184        let _ = write!(out, "}}");
185        Ok(out)
186    }
187
188    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
189        let name = to_pascal_case(&composite.sql_name);
190        let mut out = String::new();
191        let _ = writeln!(out, "/** Composite type {}. */", composite.sql_name);
192        let _ = writeln!(out, "export interface {} {{", name);
193        if composite.fields.is_empty() {
194            // empty interface
195        } else {
196            for field in &composite.fields {
197                let ts_type = resolve_type(&field.neutral_type, &self.manifest, false)
198                    .map(|t| t.into_owned())
199                    .map_err(|e| {
200                        ScytheError::new(
201                            ErrorCode::InternalError,
202                            format!("composite field type error: {}", e),
203                        )
204                    })?;
205                let _ = writeln!(out, "\t{}: {};", to_camel_case(&field.name), ts_type);
206            }
207        }
208        let _ = write!(out, "}}");
209        Ok(out)
210    }
211}
212
213/// Rewrite `$1`, `$2`, ... positional params to `${paramName}` for postgres.js tagged templates.
214fn rewrite_params_template(
215    sql: &str,
216    analyzed: &AnalyzedQuery,
217    params: &[ResolvedParam],
218) -> String {
219    let mut result = sql.to_string();
220    // Replace in reverse order so positions don't shift
221    let mut indexed: Vec<(i64, &str)> = analyzed
222        .params
223        .iter()
224        .zip(params.iter())
225        .map(|(ap, rp)| (ap.position, rp.field_name.as_str()))
226        .collect();
227    indexed.sort_by(|a, b| b.0.cmp(&a.0));
228    for (pos, field_name) in indexed {
229        let placeholder = format!("${}", pos);
230        let replacement = format!("${{{}}}", field_name);
231        result = result.replace(&placeholder, &replacement);
232    }
233    result
234}