Skip to main content

scythe_codegen/backends/typescript_postgres.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
17const DEFAULT_MANIFEST_TOML: &str =
18    include_str!("../../manifests/typescript-postgres.toml");
19
20pub struct TypescriptPostgresBackend {
21    manifest: BackendManifest,
22}
23
24impl TypescriptPostgresBackend {
25    pub fn new() -> Result<Self, ScytheError> {
26        let manifest_path = Path::new("backends/typescript-postgres/manifest.toml");
27        let manifest = if manifest_path.exists() {
28            load_manifest(manifest_path)
29                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
30        } else {
31            toml::from_str(DEFAULT_MANIFEST_TOML)
32                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
33        };
34        Ok(Self { manifest })
35    }
36
37    pub fn manifest(&self) -> &BackendManifest {
38        &self.manifest
39    }
40}
41
42impl CodegenBackend for TypescriptPostgresBackend {
43    fn name(&self) -> &str {
44        "typescript-postgres"
45    }
46
47    fn file_header(&self) -> String {
48        "/** Auto-generated by scythe. Do not edit. */\n\nimport type { Sql } from \"postgres\";\n"
49            .to_string()
50    }
51
52    fn generate_row_struct(
53        &self,
54        query_name: &str,
55        columns: &[ResolvedColumn],
56    ) -> Result<String, ScytheError> {
57        let struct_name = row_struct_name(query_name, &self.manifest.naming);
58        let mut out = String::new();
59        let _ = writeln!(out, "/** Row type for {} queries. */", query_name);
60        let _ = writeln!(out, "export interface {} {{", struct_name);
61        for col in columns {
62            let _ = writeln!(out, "\t{}: {};", col.field_name, col.full_type);
63        }
64        let _ = write!(out, "}}");
65        Ok(out)
66    }
67
68    fn generate_model_struct(
69        &self,
70        table_name: &str,
71        columns: &[ResolvedColumn],
72    ) -> Result<String, ScytheError> {
73        let singular = singularize(table_name);
74        let name = to_pascal_case(&singular);
75        self.generate_row_struct(&name, columns)
76    }
77
78    fn generate_query_fn(
79        &self,
80        analyzed: &AnalyzedQuery,
81        struct_name: &str,
82        _columns: &[ResolvedColumn],
83        params: &[ResolvedParam],
84    ) -> Result<String, ScytheError> {
85        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
86        let mut out = String::new();
87
88        // Build parameter list
89        let param_list = params
90            .iter()
91            .map(|p| format!("{}: {}", p.field_name, p.full_type))
92            .collect::<Vec<_>>()
93            .join(", ");
94        let _sep = if param_list.is_empty() { "" } else { ", " };
95
96        // Clean SQL and rewrite $1, $2 to ${paramName} for postgres.js tagged template
97        let sql_clean = super::clean_sql(&analyzed.sql);
98        let sql_template = rewrite_params_template(&sql_clean, analyzed, params);
99
100        // Build function params: inline if short, multi-line if long (biome compliance)
101        let inline_params = if params.is_empty() {
102            "sql: Sql".to_string()
103        } else {
104            format!("sql: Sql, {}", param_list)
105        };
106
107        // We'll decide inline vs multi-line per call site based on the full signature length
108
109        // Helper: write function signature, inline or multi-line based on length
110        let write_fn_sig = |out: &mut String, name: &str, params_inline: &str, ret: &str| {
111            let oneliner = format!(
112                "export async function {}({}): {} {{",
113                name, params_inline, ret
114            );
115            if oneliner.len() <= 80 {
116                let _ = writeln!(out, "{}", oneliner);
117            } else {
118                // Multi-line params
119                let mut parts = vec!["\tsql: Sql".to_string()];
120                for p in params {
121                    parts.push(format!("\t{}: {}", p.field_name, p.full_type));
122                }
123                let _ = writeln!(out, "export async function {}(", name);
124                for part in &parts {
125                    let _ = writeln!(out, "{},", part);
126                }
127                let _ = writeln!(out, "): {} {{", ret);
128            }
129        };
130
131        match &analyzed.command {
132            QueryCommand::One => {
133                let _ = writeln!(out, "/** Fetch a single {} or null. */", struct_name);
134                let ret = format!("Promise<{} | null>", struct_name);
135                write_fn_sig(&mut out, &func_name, &inline_params, &ret);
136                let _ = writeln!(out, "\tconst rows = await sql<{}[]>`", struct_name);
137                let _ = writeln!(out, "    {}", sql_template);
138                let _ = writeln!(out, "  `;");
139                let _ = writeln!(out, "\treturn rows[0] ?? null;");
140                let _ = write!(out, "}}");
141            }
142            QueryCommand::Many | QueryCommand::Batch => {
143                let _ = writeln!(out, "/** Fetch all {} rows. */", struct_name);
144                let ret = format!("Promise<{}[]>", struct_name);
145                write_fn_sig(&mut out, &func_name, &inline_params, &ret);
146                let _ = writeln!(out, "\tconst rows = await sql<{}[]>`", struct_name);
147                let _ = writeln!(out, "    {}", sql_template);
148                let _ = writeln!(out, "  `;");
149                let _ = writeln!(out, "\treturn rows;");
150                let _ = write!(out, "}}");
151            }
152            QueryCommand::Exec => {
153                let _ = writeln!(out, "/** Execute a query returning no rows. */");
154                write_fn_sig(&mut out, &func_name, &inline_params, "Promise<void>");
155                let _ = writeln!(out, "\tawait sql`");
156                let _ = writeln!(out, "    {}", sql_template);
157                let _ = writeln!(out, "  `;");
158                let _ = write!(out, "}}");
159            }
160            QueryCommand::ExecResult | QueryCommand::ExecRows => {
161                let _ = writeln!(
162                    out,
163                    "/** Execute a query and return the number of affected rows. */"
164                );
165                write_fn_sig(&mut out, &func_name, &inline_params, "Promise<number>");
166                let _ = writeln!(out, "\tconst result = await sql`");
167                let _ = writeln!(out, "    {}", sql_template);
168                let _ = writeln!(out, "  `;");
169                let _ = writeln!(out, "\treturn result.count;");
170                let _ = write!(out, "}}");
171            }
172        }
173
174        Ok(out)
175    }
176
177    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
178        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
179        let mut out = String::new();
180        let _ = writeln!(out, "export enum {} {{", type_name);
181        for value in &enum_info.values {
182            let variant = enum_variant_name(value, &self.manifest.naming);
183            let _ = writeln!(out, "\t{} = \"{}\",", variant, value);
184        }
185        let _ = write!(out, "}}");
186        Ok(out)
187    }
188
189    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
190        let name = to_pascal_case(&composite.sql_name);
191        let mut out = String::new();
192        let _ = writeln!(out, "/** Composite type {}. */", composite.sql_name);
193        let _ = writeln!(out, "export interface {} {{", name);
194        if composite.fields.is_empty() {
195            // empty interface
196        } else {
197            for field in &composite.fields {
198                let ts_type = resolve_type(&field.neutral_type, &self.manifest, false)
199                    .map(|t| t.into_owned())
200                    .map_err(|e| {
201                        ScytheError::new(
202                            ErrorCode::InternalError,
203                            format!("composite field type error: {}", e),
204                        )
205                    })?;
206                let _ = writeln!(out, "\t{}: {};", to_camel_case(&field.name), ts_type);
207            }
208        }
209        let _ = write!(out, "}}");
210        Ok(out)
211    }
212}
213
214/// Rewrite `$1`, `$2`, ... positional params to `${paramName}` for postgres.js tagged templates.
215fn rewrite_params_template(
216    sql: &str,
217    analyzed: &AnalyzedQuery,
218    params: &[ResolvedParam],
219) -> String {
220    let mut result = sql.to_string();
221    // Replace in reverse order so positions don't shift
222    let mut indexed: Vec<(i64, &str)> = analyzed
223        .params
224        .iter()
225        .zip(params.iter())
226        .map(|(ap, rp)| (ap.position, rp.field_name.as_str()))
227        .collect();
228    indexed.sort_by(|a, b| b.0.cmp(&a.0));
229    for (pos, field_name) in indexed {
230        let placeholder = format!("${}", pos);
231        let replacement = format!("${{{}}}", field_name);
232        result = result.replace(&placeholder, &replacement);
233    }
234    result
235}