Skip to main content

scythe_codegen/backends/
typescript_pg.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::backends::typescript_common::{TsRowType, generate_zod_enum, generate_zod_row_struct};
16use crate::singularize;
17
18const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/typescript-pg.toml");
19
20pub struct TypescriptPgBackend {
21    manifest: BackendManifest,
22    row_type: TsRowType,
23}
24
25impl TypescriptPgBackend {
26    pub fn new(engine: &str) -> Result<Self, ScytheError> {
27        match engine {
28            "postgresql" | "postgres" | "pg" => {}
29            _ => {
30                return Err(ScytheError::new(
31                    ErrorCode::InternalError,
32                    format!(
33                        "typescript-pg only supports PostgreSQL, got engine '{}'",
34                        engine
35                    ),
36                ));
37            }
38        }
39        let manifest_path = Path::new("backends/typescript-pg/manifest.toml");
40        let manifest = if manifest_path.exists() {
41            load_manifest(manifest_path)
42                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
43        } else {
44            toml::from_str(DEFAULT_MANIFEST_TOML)
45                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
46        };
47        Ok(Self {
48            manifest,
49            row_type: TsRowType::default(),
50        })
51    }
52}
53
54impl CodegenBackend for TypescriptPgBackend {
55    fn name(&self) -> &str {
56        "typescript-pg"
57    }
58
59    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
60        &self.manifest
61    }
62
63    fn file_header(&self) -> String {
64        let mut header =
65            "/** Auto-generated by scythe. Do not edit. */\n\nimport type { PoolClient } from \"pg\";\n"
66                .to_string();
67        if self.row_type == TsRowType::Zod {
68            header.push_str("import { z } from \"zod\";\n");
69        }
70        header
71    }
72
73    fn generate_row_struct(
74        &self,
75        query_name: &str,
76        columns: &[ResolvedColumn],
77    ) -> Result<String, ScytheError> {
78        let struct_name = row_struct_name(query_name, &self.manifest.naming);
79        if self.row_type == TsRowType::Zod {
80            return Ok(generate_zod_row_struct(&struct_name, query_name, columns));
81        }
82        let mut out = String::new();
83        let _ = writeln!(out, "/** Row type for {} queries. */", query_name);
84        let _ = writeln!(out, "export interface {} {{", struct_name);
85        for col in columns {
86            let _ = writeln!(out, "\t{}: {};", col.field_name, col.full_type);
87        }
88        let _ = write!(out, "}}");
89        Ok(out)
90    }
91
92    fn generate_model_struct(
93        &self,
94        table_name: &str,
95        columns: &[ResolvedColumn],
96    ) -> Result<String, ScytheError> {
97        let singular = singularize(table_name);
98        let name = to_pascal_case(&singular);
99        self.generate_row_struct(&name, columns)
100    }
101
102    fn generate_query_fn(
103        &self,
104        analyzed: &AnalyzedQuery,
105        struct_name: &str,
106        _columns: &[ResolvedColumn],
107        params: &[ResolvedParam],
108    ) -> Result<String, ScytheError> {
109        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
110        let mut out = String::new();
111
112        // Build parameter list
113        let param_list = params
114            .iter()
115            .map(|p| format!("{}: {}", p.field_name, p.full_type))
116            .collect::<Vec<_>>()
117            .join(", ");
118        let _sep = if param_list.is_empty() { "" } else { ", " };
119
120        // Clean SQL — pg uses $1, $2 positional params natively
121        let sql = super::clean_sql_with_optional(
122            &analyzed.sql,
123            &analyzed.optional_params,
124            &analyzed.params,
125        );
126
127        // Build array of param values
128        let _param_array: String = if params.is_empty() {
129            String::new()
130        } else {
131            let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
132            format!(", [{}]", args.join(", "))
133        };
134
135        // Build inline params string for line-length checking
136        let inline_params = if params.is_empty() {
137            "client: PoolClient".to_string()
138        } else {
139            format!("client: PoolClient, {}", param_list)
140        };
141
142        // Helper: write a typed query call (with generic type annotation).
143        // Biome always breaks `client.query<T>(...)` to multi-line.
144        let write_typed_query = |out: &mut String,
145                                 prefix: &str,
146                                 type_name: &str,
147                                 sql: &str,
148                                 params: &[ResolvedParam]| {
149            let _ = writeln!(out, "{}client.query<{}>(", prefix, type_name);
150            let _ = writeln!(out, "\t\t`{}`,", sql);
151            if !params.is_empty() {
152                let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
153                let _ = writeln!(out, "\t\t[{}],", args.join(", "));
154            }
155            let _ = writeln!(out, "\t);");
156        };
157
158        // Helper: write an untyped query call. Inline if short, multi-line if long.
159        let write_untyped_query =
160            |out: &mut String, prefix: &str, sql: &str, params: &[ResolvedParam]| {
161                let param_str = if params.is_empty() {
162                    String::new()
163                } else {
164                    let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
165                    format!(", [{}]", args.join(", "))
166                };
167                let oneliner = format!("{}client.query(`{}`{});", prefix, sql, param_str);
168                // Use tab width of 4 for line length estimation
169                let estimated_len = oneliner.replace('\t', "    ").len();
170                if estimated_len <= 80 {
171                    let _ = writeln!(out, "{}", oneliner);
172                } else {
173                    let _ = writeln!(out, "{}client.query(", prefix);
174                    let _ = writeln!(out, "\t\t`{}`,", sql);
175                    if !params.is_empty() {
176                        let args: Vec<String> =
177                            params.iter().map(|p| p.field_name.clone()).collect();
178                        let _ = writeln!(out, "\t\t[{}],", args.join(", "));
179                    }
180                    let _ = writeln!(out, "\t);");
181                }
182            };
183
184        // Helper: write function signature, inline or multi-line based on length
185        let write_fn_sig = |out: &mut String, name: &str, params_inline: &str, ret: &str| {
186            let oneliner = format!(
187                "export async function {}({}): {} {{",
188                name, params_inline, ret
189            );
190            if oneliner.len() <= 80 {
191                let _ = writeln!(out, "{}", oneliner);
192            } else {
193                let mut parts = vec!["\tclient: PoolClient".to_string()];
194                for p in params {
195                    parts.push(format!("\t{}: {}", p.field_name, p.full_type));
196                }
197                let _ = writeln!(out, "export async function {}(", name);
198                for part in &parts {
199                    let _ = writeln!(out, "{},", part);
200                }
201                let _ = writeln!(out, "): {} {{", ret);
202            }
203        };
204
205        match &analyzed.command {
206            QueryCommand::One => {
207                let _ = writeln!(out, "/** Fetch a single {} or null. */", struct_name);
208                let ret = format!("Promise<{} | null>", struct_name);
209                write_fn_sig(&mut out, &func_name, &inline_params, &ret);
210                write_typed_query(
211                    &mut out,
212                    "\tconst { rows } = await ",
213                    struct_name,
214                    &sql,
215                    params,
216                );
217                let _ = writeln!(out, "\treturn rows[0] ?? null;");
218                let _ = write!(out, "}}");
219            }
220            QueryCommand::Many => {
221                let _ = writeln!(out, "/** Fetch all {} rows. */", struct_name);
222                let ret = format!("Promise<{}[]>", struct_name);
223                write_fn_sig(&mut out, &func_name, &inline_params, &ret);
224                write_typed_query(
225                    &mut out,
226                    "\tconst { rows } = await ",
227                    struct_name,
228                    &sql,
229                    params,
230                );
231                let _ = writeln!(out, "\treturn rows;");
232                let _ = write!(out, "}}");
233            }
234            QueryCommand::Batch => {
235                let batch_fn_name = format!("{}Batch", func_name);
236                // Build params interface
237                if params.len() > 1 {
238                    let params_type_name = format!("{}BatchParams", struct_name);
239                    let _ = writeln!(out, "/** Params for {} batch operation. */", struct_name);
240                    let _ = writeln!(out, "export interface {} {{", params_type_name);
241                    for p in params {
242                        let _ = writeln!(out, "\t{}: {};", p.field_name, p.full_type);
243                    }
244                    let _ = writeln!(out, "}}");
245                    let _ = writeln!(out);
246                    let _ = writeln!(
247                        out,
248                        "/** Execute {} for each item in the batch within a transaction. */",
249                        analyzed.name
250                    );
251                    let batch_params = format!("client: PoolClient, items: {}[]", params_type_name);
252                    write_fn_sig(&mut out, &batch_fn_name, &batch_params, "Promise<void>");
253                    let _ = writeln!(out, "\ttry {{");
254                    let _ = writeln!(out, "\t\tawait client.query(\"BEGIN\");");
255                    let _ = writeln!(out, "\t\tfor (const item of items) {{");
256                    let _ = writeln!(out, "\t\t\tawait client.query(");
257                    let _ = writeln!(out, "\t\t\t\t`{}`,", sql);
258                    let args: Vec<String> = params
259                        .iter()
260                        .map(|p| format!("item.{}", p.field_name))
261                        .collect();
262                    let _ = writeln!(out, "\t\t\t\t[{}],", args.join(", "));
263                    let _ = writeln!(out, "\t\t\t);");
264                    let _ = writeln!(out, "\t\t}}");
265                    let _ = writeln!(out, "\t\tawait client.query(\"COMMIT\");");
266                    let _ = writeln!(out, "\t}} catch (error) {{");
267                    let _ = writeln!(out, "\t\tawait client.query(\"ROLLBACK\");");
268                    let _ = writeln!(out, "\t\tthrow error;");
269                    let _ = writeln!(out, "\t}}");
270                    let _ = write!(out, "}}");
271                } else if params.len() == 1 {
272                    let _ = writeln!(
273                        out,
274                        "/** Execute {} for each item in the batch within a transaction. */",
275                        analyzed.name
276                    );
277                    let batch_params =
278                        format!("client: PoolClient, items: {}[]", params[0].full_type);
279                    write_fn_sig(&mut out, &batch_fn_name, &batch_params, "Promise<void>");
280                    let _ = writeln!(out, "\ttry {{");
281                    let _ = writeln!(out, "\t\tawait client.query(\"BEGIN\");");
282                    let _ = writeln!(out, "\t\tfor (const item of items) {{");
283                    let _ = writeln!(out, "\t\t\tawait client.query(`{}`, [item]);", sql);
284                    let _ = writeln!(out, "\t\t}}");
285                    let _ = writeln!(out, "\t\tawait client.query(\"COMMIT\");");
286                    let _ = writeln!(out, "\t}} catch (error) {{");
287                    let _ = writeln!(out, "\t\tawait client.query(\"ROLLBACK\");");
288                    let _ = writeln!(out, "\t\tthrow error;");
289                    let _ = writeln!(out, "\t}}");
290                    let _ = write!(out, "}}");
291                } else {
292                    let _ = writeln!(
293                        out,
294                        "/** Execute {} for each item in the batch within a transaction. */",
295                        analyzed.name
296                    );
297                    write_fn_sig(
298                        &mut out,
299                        &batch_fn_name,
300                        "client: PoolClient, count: number",
301                        "Promise<void>",
302                    );
303                    let _ = writeln!(out, "\ttry {{");
304                    let _ = writeln!(out, "\t\tawait client.query(\"BEGIN\");");
305                    let _ = writeln!(out, "\t\tfor (let i = 0; i < count; i++) {{");
306                    let _ = writeln!(out, "\t\t\tawait client.query(`{}`);", sql);
307                    let _ = writeln!(out, "\t\t}}");
308                    let _ = writeln!(out, "\t\tawait client.query(\"COMMIT\");");
309                    let _ = writeln!(out, "\t}} catch (error) {{");
310                    let _ = writeln!(out, "\t\tawait client.query(\"ROLLBACK\");");
311                    let _ = writeln!(out, "\t\tthrow error;");
312                    let _ = writeln!(out, "\t}}");
313                    let _ = write!(out, "}}");
314                }
315            }
316            QueryCommand::Exec => {
317                let _ = writeln!(out, "/** Execute a query returning no rows. */");
318                write_fn_sig(&mut out, &func_name, &inline_params, "Promise<void>");
319                write_untyped_query(&mut out, "\tawait ", &sql, params);
320                let _ = write!(out, "}}");
321            }
322            QueryCommand::ExecResult | QueryCommand::ExecRows => {
323                let _ = writeln!(
324                    out,
325                    "/** Execute a query and return the number of affected rows. */"
326                );
327                write_fn_sig(&mut out, &func_name, &inline_params, "Promise<number>");
328                write_untyped_query(&mut out, "\tconst result = await ", &sql, params);
329                let _ = writeln!(out, "\treturn result.rowCount ?? 0;");
330                let _ = write!(out, "}}");
331            }
332        }
333
334        Ok(out)
335    }
336
337    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
338        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
339        if self.row_type == TsRowType::Zod {
340            return Ok(generate_zod_enum(&type_name, &enum_info.values));
341        }
342        let mut out = String::new();
343        let _ = writeln!(out, "export enum {} {{", type_name);
344        for value in &enum_info.values {
345            let variant = enum_variant_name(value, &self.manifest.naming);
346            let _ = writeln!(out, "\t{} = \"{}\",", variant, value);
347        }
348        let _ = write!(out, "}}");
349        Ok(out)
350    }
351
352    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
353        let name = to_pascal_case(&composite.sql_name);
354        let mut out = String::new();
355        let _ = writeln!(out, "/** Composite type {}. */", composite.sql_name);
356        let _ = writeln!(out, "export interface {} {{", name);
357        if composite.fields.is_empty() {
358            // empty interface
359        } else {
360            for field in &composite.fields {
361                let ts_type = resolve_type(&field.neutral_type, &self.manifest, false)
362                    .map(|t| t.into_owned())
363                    .map_err(|e| {
364                        ScytheError::new(
365                            ErrorCode::InternalError,
366                            format!("composite field type error: {}", e),
367                        )
368                    })?;
369                let _ = writeln!(out, "\t{}: {};", to_camel_case(&field.name), ts_type);
370            }
371        }
372        let _ = write!(out, "}}");
373        Ok(out)
374    }
375
376    fn apply_options(
377        &mut self,
378        options: &std::collections::HashMap<String, String>,
379    ) -> Result<(), ScytheError> {
380        if let Some(value) = options.get("row_type") {
381            self.row_type = TsRowType::from_option(value)?;
382        }
383        Ok(())
384    }
385}