Skip to main content

scythe_codegen/backends/
typescript_pg.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
/// Default backend manifest, embedded into the binary at compile time from
/// `manifests/typescript-pg.toml` (path relative to this source file).
/// `TypescriptPgBackend::new` parses it when no on-disk override manifest exists.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/typescript-pg.toml");
18
/// Code generation backend that emits TypeScript targeting the `pg` package
/// (`PoolClient`), restricted to PostgreSQL engines.
pub struct TypescriptPgBackend {
    /// Backend manifest loaded once in `new`; supplies the naming rules and the
    /// type mappings used while generating code.
    manifest: BackendManifest,
}
22
23impl TypescriptPgBackend {
24    pub fn new(engine: &str) -> Result<Self, ScytheError> {
25        match engine {
26            "postgresql" | "postgres" | "pg" => {}
27            _ => {
28                return Err(ScytheError::new(
29                    ErrorCode::InternalError,
30                    format!(
31                        "typescript-pg only supports PostgreSQL, got engine '{}'",
32                        engine
33                    ),
34                ));
35            }
36        }
37        let manifest_path = Path::new("backends/typescript-pg/manifest.toml");
38        let manifest = if manifest_path.exists() {
39            load_manifest(manifest_path)
40                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
41        } else {
42            toml::from_str(DEFAULT_MANIFEST_TOML)
43                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
44        };
45        Ok(Self { manifest })
46    }
47}
48
49impl CodegenBackend for TypescriptPgBackend {
50    fn name(&self) -> &str {
51        "typescript-pg"
52    }
53
54    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
55        &self.manifest
56    }
57
58    fn file_header(&self) -> String {
59        "/** Auto-generated by scythe. Do not edit. */\n\nimport type { PoolClient } from \"pg\";\n"
60            .to_string()
61    }
62
63    fn generate_row_struct(
64        &self,
65        query_name: &str,
66        columns: &[ResolvedColumn],
67    ) -> Result<String, ScytheError> {
68        let struct_name = row_struct_name(query_name, &self.manifest.naming);
69        let mut out = String::new();
70        let _ = writeln!(out, "/** Row type for {} queries. */", query_name);
71        let _ = writeln!(out, "export interface {} {{", struct_name);
72        for col in columns {
73            let _ = writeln!(out, "\t{}: {};", col.field_name, col.full_type);
74        }
75        let _ = write!(out, "}}");
76        Ok(out)
77    }
78
79    fn generate_model_struct(
80        &self,
81        table_name: &str,
82        columns: &[ResolvedColumn],
83    ) -> Result<String, ScytheError> {
84        let singular = singularize(table_name);
85        let name = to_pascal_case(&singular);
86        self.generate_row_struct(&name, columns)
87    }
88
89    fn generate_query_fn(
90        &self,
91        analyzed: &AnalyzedQuery,
92        struct_name: &str,
93        _columns: &[ResolvedColumn],
94        params: &[ResolvedParam],
95    ) -> Result<String, ScytheError> {
96        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
97        let mut out = String::new();
98
99        // Build parameter list
100        let param_list = params
101            .iter()
102            .map(|p| format!("{}: {}", p.field_name, p.full_type))
103            .collect::<Vec<_>>()
104            .join(", ");
105        let _sep = if param_list.is_empty() { "" } else { ", " };
106
107        // Clean SQL — pg uses $1, $2 positional params natively
108        let sql = super::clean_sql(&analyzed.sql);
109
110        // Build array of param values
111        let _param_array: String = if params.is_empty() {
112            String::new()
113        } else {
114            let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
115            format!(", [{}]", args.join(", "))
116        };
117
118        // Build inline params string for line-length checking
119        let inline_params = if params.is_empty() {
120            "client: PoolClient".to_string()
121        } else {
122            format!("client: PoolClient, {}", param_list)
123        };
124
125        // Helper: write a typed query call (with generic type annotation).
126        // Biome always breaks `client.query<T>(...)` to multi-line.
127        let write_typed_query = |out: &mut String,
128                                 prefix: &str,
129                                 type_name: &str,
130                                 sql: &str,
131                                 params: &[ResolvedParam]| {
132            let _ = writeln!(out, "{}client.query<{}>(", prefix, type_name);
133            let _ = writeln!(out, "\t\t`{}`,", sql);
134            if !params.is_empty() {
135                let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
136                let _ = writeln!(out, "\t\t[{}],", args.join(", "));
137            }
138            let _ = writeln!(out, "\t);");
139        };
140
141        // Helper: write an untyped query call. Inline if short, multi-line if long.
142        let write_untyped_query =
143            |out: &mut String, prefix: &str, sql: &str, params: &[ResolvedParam]| {
144                let param_str = if params.is_empty() {
145                    String::new()
146                } else {
147                    let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
148                    format!(", [{}]", args.join(", "))
149                };
150                let oneliner = format!("{}client.query(`{}`{});", prefix, sql, param_str);
151                // Use tab width of 4 for line length estimation
152                let estimated_len = oneliner.replace('\t', "    ").len();
153                if estimated_len <= 80 {
154                    let _ = writeln!(out, "{}", oneliner);
155                } else {
156                    let _ = writeln!(out, "{}client.query(", prefix);
157                    let _ = writeln!(out, "\t\t`{}`,", sql);
158                    if !params.is_empty() {
159                        let args: Vec<String> =
160                            params.iter().map(|p| p.field_name.clone()).collect();
161                        let _ = writeln!(out, "\t\t[{}],", args.join(", "));
162                    }
163                    let _ = writeln!(out, "\t);");
164                }
165            };
166
167        // Helper: write function signature, inline or multi-line based on length
168        let write_fn_sig = |out: &mut String, name: &str, params_inline: &str, ret: &str| {
169            let oneliner = format!(
170                "export async function {}({}): {} {{",
171                name, params_inline, ret
172            );
173            if oneliner.len() <= 80 {
174                let _ = writeln!(out, "{}", oneliner);
175            } else {
176                let mut parts = vec!["\tclient: PoolClient".to_string()];
177                for p in params {
178                    parts.push(format!("\t{}: {}", p.field_name, p.full_type));
179                }
180                let _ = writeln!(out, "export async function {}(", name);
181                for part in &parts {
182                    let _ = writeln!(out, "{},", part);
183                }
184                let _ = writeln!(out, "): {} {{", ret);
185            }
186        };
187
188        match &analyzed.command {
189            QueryCommand::One => {
190                let _ = writeln!(out, "/** Fetch a single {} or null. */", struct_name);
191                let ret = format!("Promise<{} | null>", struct_name);
192                write_fn_sig(&mut out, &func_name, &inline_params, &ret);
193                write_typed_query(
194                    &mut out,
195                    "\tconst { rows } = await ",
196                    struct_name,
197                    &sql,
198                    params,
199                );
200                let _ = writeln!(out, "\treturn rows[0] ?? null;");
201                let _ = write!(out, "}}");
202            }
203            QueryCommand::Many | QueryCommand::Batch => {
204                let _ = writeln!(out, "/** Fetch all {} rows. */", struct_name);
205                let ret = format!("Promise<{}[]>", struct_name);
206                write_fn_sig(&mut out, &func_name, &inline_params, &ret);
207                write_typed_query(
208                    &mut out,
209                    "\tconst { rows } = await ",
210                    struct_name,
211                    &sql,
212                    params,
213                );
214                let _ = writeln!(out, "\treturn rows;");
215                let _ = write!(out, "}}");
216            }
217            QueryCommand::Exec => {
218                let _ = writeln!(out, "/** Execute a query returning no rows. */");
219                write_fn_sig(&mut out, &func_name, &inline_params, "Promise<void>");
220                write_untyped_query(&mut out, "\tawait ", &sql, params);
221                let _ = write!(out, "}}");
222            }
223            QueryCommand::ExecResult | QueryCommand::ExecRows => {
224                let _ = writeln!(
225                    out,
226                    "/** Execute a query and return the number of affected rows. */"
227                );
228                write_fn_sig(&mut out, &func_name, &inline_params, "Promise<number>");
229                write_untyped_query(&mut out, "\tconst result = await ", &sql, params);
230                let _ = writeln!(out, "\treturn result.rowCount ?? 0;");
231                let _ = write!(out, "}}");
232            }
233        }
234
235        Ok(out)
236    }
237
238    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
239        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
240        let mut out = String::new();
241        let _ = writeln!(out, "export enum {} {{", type_name);
242        for value in &enum_info.values {
243            let variant = enum_variant_name(value, &self.manifest.naming);
244            let _ = writeln!(out, "\t{} = \"{}\",", variant, value);
245        }
246        let _ = write!(out, "}}");
247        Ok(out)
248    }
249
250    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
251        let name = to_pascal_case(&composite.sql_name);
252        let mut out = String::new();
253        let _ = writeln!(out, "/** Composite type {}. */", composite.sql_name);
254        let _ = writeln!(out, "export interface {} {{", name);
255        if composite.fields.is_empty() {
256            // empty interface
257        } else {
258            for field in &composite.fields {
259                let ts_type = resolve_type(&field.neutral_type, &self.manifest, false)
260                    .map(|t| t.into_owned())
261                    .map_err(|e| {
262                        ScytheError::new(
263                            ErrorCode::InternalError,
264                            format!("composite field type error: {}", e),
265                        )
266                    })?;
267                let _ = writeln!(out, "\t{}: {};", to_camel_case(&field.name), ts_type);
268            }
269        }
270        let _ = write!(out, "}}");
271        Ok(out)
272    }
273}