Skip to main content

scythe_codegen/backends/
typescript_postgres.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::backends::typescript_common::{TsRowType, generate_zod_enum, generate_zod_row_struct};
16use crate::singularize;
17
18const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/typescript-postgres.toml");
19
20pub struct TypescriptPostgresBackend {
21    manifest: BackendManifest,
22    row_type: TsRowType,
23}
24
25impl TypescriptPostgresBackend {
26    pub fn new(engine: &str) -> Result<Self, ScytheError> {
27        match engine {
28            "postgresql" | "postgres" | "pg" => {}
29            _ => {
30                return Err(ScytheError::new(
31                    ErrorCode::InternalError,
32                    format!(
33                        "typescript-postgres only supports PostgreSQL, got engine '{}'",
34                        engine
35                    ),
36                ));
37            }
38        }
39        let manifest_path = Path::new("backends/typescript-postgres/manifest.toml");
40        let manifest = if manifest_path.exists() {
41            load_manifest(manifest_path)
42                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
43        } else {
44            toml::from_str(DEFAULT_MANIFEST_TOML)
45                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
46        };
47        Ok(Self {
48            manifest,
49            row_type: TsRowType::default(),
50        })
51    }
52}
53
54impl CodegenBackend for TypescriptPostgresBackend {
55    fn name(&self) -> &str {
56        "typescript-postgres"
57    }
58
59    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
60        &self.manifest
61    }
62
63    fn file_header(&self) -> String {
64        let mut header =
65            "/** Auto-generated by scythe. Do not edit. */\n\nimport type { Sql } from \"postgres\";\n"
66                .to_string();
67        if self.row_type == TsRowType::Zod {
68            header.push_str("import { z } from \"zod\";\n");
69        }
70        header
71    }
72
73    fn generate_row_struct(
74        &self,
75        query_name: &str,
76        columns: &[ResolvedColumn],
77    ) -> Result<String, ScytheError> {
78        let struct_name = row_struct_name(query_name, &self.manifest.naming);
79        if self.row_type == TsRowType::Zod {
80            return Ok(generate_zod_row_struct(&struct_name, query_name, columns));
81        }
82        let mut out = String::new();
83        let _ = writeln!(out, "/** Row type for {} queries. */", query_name);
84        let _ = writeln!(out, "export interface {} {{", struct_name);
85        for col in columns {
86            let _ = writeln!(out, "\t{}: {};", col.field_name, col.full_type);
87        }
88        let _ = write!(out, "}}");
89        Ok(out)
90    }
91
92    fn generate_model_struct(
93        &self,
94        table_name: &str,
95        columns: &[ResolvedColumn],
96    ) -> Result<String, ScytheError> {
97        let singular = singularize(table_name);
98        let name = to_pascal_case(&singular);
99        self.generate_row_struct(&name, columns)
100    }
101
102    fn generate_query_fn(
103        &self,
104        analyzed: &AnalyzedQuery,
105        struct_name: &str,
106        _columns: &[ResolvedColumn],
107        params: &[ResolvedParam],
108    ) -> Result<String, ScytheError> {
109        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
110        let mut out = String::new();
111
112        // Build parameter list
113        let param_list = params
114            .iter()
115            .map(|p| format!("{}: {}", p.field_name, p.full_type))
116            .collect::<Vec<_>>()
117            .join(", ");
118        let _sep = if param_list.is_empty() { "" } else { ", " };
119
120        // Clean SQL and rewrite $1, $2 to ${paramName} for postgres.js tagged template
121        let sql_clean = super::clean_sql_with_optional(
122            &analyzed.sql,
123            &analyzed.optional_params,
124            &analyzed.params,
125        );
126        let sql_template = rewrite_params_template(&sql_clean, analyzed, params);
127
128        // Build function params: inline if short, multi-line if long (biome compliance)
129        let inline_params = if params.is_empty() {
130            "sql: Sql".to_string()
131        } else {
132            format!("sql: Sql, {}", param_list)
133        };
134
135        // We'll decide inline vs multi-line per call site based on the full signature length
136
137        // Helper: write function signature, inline or multi-line based on length
138        let write_fn_sig = |out: &mut String, name: &str, params_inline: &str, ret: &str| {
139            let oneliner = format!(
140                "export async function {}({}): {} {{",
141                name, params_inline, ret
142            );
143            if oneliner.len() <= 80 {
144                let _ = writeln!(out, "{}", oneliner);
145            } else {
146                // Multi-line params
147                let mut parts = vec!["\tsql: Sql".to_string()];
148                for p in params {
149                    parts.push(format!("\t{}: {}", p.field_name, p.full_type));
150                }
151                let _ = writeln!(out, "export async function {}(", name);
152                for part in &parts {
153                    let _ = writeln!(out, "{},", part);
154                }
155                let _ = writeln!(out, "): {} {{", ret);
156            }
157        };
158
159        match &analyzed.command {
160            QueryCommand::One => {
161                let _ = writeln!(out, "/** Fetch a single {} or null. */", struct_name);
162                let ret = format!("Promise<{} | null>", struct_name);
163                write_fn_sig(&mut out, &func_name, &inline_params, &ret);
164                let _ = writeln!(out, "\tconst rows = await sql<{}[]>`", struct_name);
165                let _ = writeln!(out, "    {}", sql_template);
166                let _ = writeln!(out, "  `;");
167                let _ = writeln!(out, "\treturn rows[0] ?? null;");
168                let _ = write!(out, "}}");
169            }
170            QueryCommand::Batch => {
171                let batch_fn_name = format!("{}Batch", func_name);
172                if params.len() > 1 {
173                    let params_type_name = format!("{}BatchParams", struct_name);
174                    let _ = writeln!(out, "/** Params for {} batch operation. */", struct_name);
175                    let _ = writeln!(out, "export interface {} {{", params_type_name);
176                    for p in params {
177                        let _ = writeln!(out, "\t{}: {};", p.field_name, p.full_type);
178                    }
179                    let _ = writeln!(out, "}}");
180                    let _ = writeln!(out);
181                    let _ = writeln!(
182                        out,
183                        "/** Execute {} for each item in the batch within a transaction. */",
184                        analyzed.name
185                    );
186                    let batch_params = format!("sql: Sql, items: {}[]", params_type_name);
187                    write_fn_sig(&mut out, &batch_fn_name, &batch_params, "Promise<void>");
188                    let _ = writeln!(out, "\tawait sql.begin(async (tx) => {{");
189                    let _ = writeln!(out, "\t\tfor (const item of items) {{");
190                    // Build template with item.fieldName
191                    let batch_sql = {
192                        let mut s = sql_clean.clone();
193                        let mut indexed: Vec<(i64, &str)> = analyzed
194                            .params
195                            .iter()
196                            .zip(params.iter())
197                            .map(|(ap, rp)| (ap.position, rp.field_name.as_str()))
198                            .collect();
199                        indexed.sort_by(|a, b| b.0.cmp(&a.0));
200                        for (pos, field_name) in indexed {
201                            let placeholder = format!("${}", pos);
202                            let replacement = format!("${{item.{}}}", field_name);
203                            s = s.replace(&placeholder, &replacement);
204                        }
205                        s
206                    };
207                    let _ = writeln!(out, "\t\t\tawait tx`");
208                    let _ = writeln!(out, "    {}", batch_sql);
209                    let _ = writeln!(out, "  `;");
210                    let _ = writeln!(out, "\t\t}}");
211                    let _ = writeln!(out, "\t}});");
212                    let _ = write!(out, "}}");
213                } else if params.len() == 1 {
214                    let _ = writeln!(
215                        out,
216                        "/** Execute {} for each item in the batch within a transaction. */",
217                        analyzed.name
218                    );
219                    let batch_params = format!("sql: Sql, items: {}[]", params[0].full_type);
220                    write_fn_sig(&mut out, &batch_fn_name, &batch_params, "Promise<void>");
221                    let _ = writeln!(out, "\tawait sql.begin(async (tx) => {{");
222                    let _ = writeln!(out, "\t\tfor (const item of items) {{");
223                    let batch_sql =
224                        sql_template.replace(&format!("${{{}}}", params[0].field_name), "${item}");
225                    let _ = writeln!(out, "\t\t\tawait tx`");
226                    let _ = writeln!(out, "    {}", batch_sql);
227                    let _ = writeln!(out, "  `;");
228                    let _ = writeln!(out, "\t\t}}");
229                    let _ = writeln!(out, "\t}});");
230                    let _ = write!(out, "}}");
231                } else {
232                    let _ = writeln!(
233                        out,
234                        "/** Execute {} for each item in the batch within a transaction. */",
235                        analyzed.name
236                    );
237                    write_fn_sig(
238                        &mut out,
239                        &batch_fn_name,
240                        "sql: Sql, count: number",
241                        "Promise<void>",
242                    );
243                    let _ = writeln!(out, "\tawait sql.begin(async (tx) => {{");
244                    let _ = writeln!(out, "\t\tfor (let i = 0; i < count; i++) {{");
245                    let _ = writeln!(out, "\t\t\tawait tx`");
246                    let _ = writeln!(out, "    {}", sql_template);
247                    let _ = writeln!(out, "  `;");
248                    let _ = writeln!(out, "\t\t}}");
249                    let _ = writeln!(out, "\t}});");
250                    let _ = write!(out, "}}");
251                }
252            }
253            QueryCommand::Many => {
254                let _ = writeln!(out, "/** Fetch all {} rows. */", struct_name);
255                let ret = format!("Promise<{}[]>", struct_name);
256                write_fn_sig(&mut out, &func_name, &inline_params, &ret);
257                let _ = writeln!(out, "\tconst rows = await sql<{}[]>`", struct_name);
258                let _ = writeln!(out, "    {}", sql_template);
259                let _ = writeln!(out, "  `;");
260                let _ = writeln!(out, "\treturn rows;");
261                let _ = write!(out, "}}");
262            }
263            QueryCommand::Exec => {
264                let _ = writeln!(out, "/** Execute a query returning no rows. */");
265                write_fn_sig(&mut out, &func_name, &inline_params, "Promise<void>");
266                let _ = writeln!(out, "\tawait sql`");
267                let _ = writeln!(out, "    {}", sql_template);
268                let _ = writeln!(out, "  `;");
269                let _ = write!(out, "}}");
270            }
271            QueryCommand::ExecResult | QueryCommand::ExecRows => {
272                let _ = writeln!(
273                    out,
274                    "/** Execute a query and return the number of affected rows. */"
275                );
276                write_fn_sig(&mut out, &func_name, &inline_params, "Promise<number>");
277                let _ = writeln!(out, "\tconst result = await sql`");
278                let _ = writeln!(out, "    {}", sql_template);
279                let _ = writeln!(out, "  `;");
280                let _ = writeln!(out, "\treturn result.count;");
281                let _ = write!(out, "}}");
282            }
283        }
284
285        Ok(out)
286    }
287
288    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
289        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
290        if self.row_type == TsRowType::Zod {
291            return Ok(generate_zod_enum(&type_name, &enum_info.values));
292        }
293        let mut out = String::new();
294        let _ = writeln!(out, "export enum {} {{", type_name);
295        for value in &enum_info.values {
296            let variant = enum_variant_name(value, &self.manifest.naming);
297            let _ = writeln!(out, "\t{} = \"{}\",", variant, value);
298        }
299        let _ = write!(out, "}}");
300        Ok(out)
301    }
302
303    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
304        let name = to_pascal_case(&composite.sql_name);
305        let mut out = String::new();
306        let _ = writeln!(out, "/** Composite type {}. */", composite.sql_name);
307        let _ = writeln!(out, "export interface {} {{", name);
308        if composite.fields.is_empty() {
309            // empty interface
310        } else {
311            for field in &composite.fields {
312                let ts_type = resolve_type(&field.neutral_type, &self.manifest, false)
313                    .map(|t| t.into_owned())
314                    .map_err(|e| {
315                        ScytheError::new(
316                            ErrorCode::InternalError,
317                            format!("composite field type error: {}", e),
318                        )
319                    })?;
320                let _ = writeln!(out, "\t{}: {};", to_camel_case(&field.name), ts_type);
321            }
322        }
323        let _ = write!(out, "}}");
324        Ok(out)
325    }
326
327    fn apply_options(
328        &mut self,
329        options: &std::collections::HashMap<String, String>,
330    ) -> Result<(), ScytheError> {
331        if let Some(value) = options.get("row_type") {
332            self.row_type = TsRowType::from_option(value)?;
333        }
334        Ok(())
335    }
336}
337
338/// Rewrite `$1`, `$2`, ... positional params to `${paramName}` for postgres.js tagged templates.
339fn rewrite_params_template(
340    sql: &str,
341    analyzed: &AnalyzedQuery,
342    params: &[ResolvedParam],
343) -> String {
344    let mut result = sql.to_string();
345    // Replace in reverse order so positions don't shift
346    let mut indexed: Vec<(i64, &str)> = analyzed
347        .params
348        .iter()
349        .zip(params.iter())
350        .map(|(ap, rp)| (ap.position, rp.field_name.as_str()))
351        .collect();
352    indexed.sort_by(|a, b| b.0.cmp(&a.0));
353    for (pos, field_name) in indexed {
354        let placeholder = format!("${}", pos);
355        let replacement = format!("${{{}}}", field_name);
356        result = result.replace(&placeholder, &replacement);
357    }
358    result
359}