Skip to main content

scythe_codegen/backends/
typescript_pg.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
17const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/typescript-pg.toml");
18
19pub struct TypescriptPgBackend {
20    manifest: BackendManifest,
21}
22
23impl TypescriptPgBackend {
24    pub fn new() -> Result<Self, ScytheError> {
25        let manifest_path = Path::new("backends/typescript-pg/manifest.toml");
26        let manifest = if manifest_path.exists() {
27            load_manifest(manifest_path)
28                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
29        } else {
30            toml::from_str(DEFAULT_MANIFEST_TOML)
31                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
32        };
33        Ok(Self { manifest })
34    }
35
36    pub fn manifest(&self) -> &BackendManifest {
37        &self.manifest
38    }
39}
40
41impl CodegenBackend for TypescriptPgBackend {
42    fn name(&self) -> &str {
43        "typescript-pg"
44    }
45
46    fn file_header(&self) -> String {
47        "/** Auto-generated by scythe. Do not edit. */\n\nimport type { PoolClient } from \"pg\";\n"
48            .to_string()
49    }
50
51    fn generate_row_struct(
52        &self,
53        query_name: &str,
54        columns: &[ResolvedColumn],
55    ) -> Result<String, ScytheError> {
56        let struct_name = row_struct_name(query_name, &self.manifest.naming);
57        let mut out = String::new();
58        let _ = writeln!(out, "/** Row type for {} queries. */", query_name);
59        let _ = writeln!(out, "export interface {} {{", struct_name);
60        for col in columns {
61            let _ = writeln!(out, "\t{}: {};", col.field_name, col.full_type);
62        }
63        let _ = write!(out, "}}");
64        Ok(out)
65    }
66
67    fn generate_model_struct(
68        &self,
69        table_name: &str,
70        columns: &[ResolvedColumn],
71    ) -> Result<String, ScytheError> {
72        let singular = singularize(table_name);
73        let name = to_pascal_case(&singular);
74        self.generate_row_struct(&name, columns)
75    }
76
77    fn generate_query_fn(
78        &self,
79        analyzed: &AnalyzedQuery,
80        struct_name: &str,
81        _columns: &[ResolvedColumn],
82        params: &[ResolvedParam],
83    ) -> Result<String, ScytheError> {
84        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
85        let mut out = String::new();
86
87        // Build parameter list
88        let param_list = params
89            .iter()
90            .map(|p| format!("{}: {}", p.field_name, p.full_type))
91            .collect::<Vec<_>>()
92            .join(", ");
93        let _sep = if param_list.is_empty() { "" } else { ", " };
94
95        // Clean SQL — pg uses $1, $2 positional params natively
96        let sql = super::clean_sql(&analyzed.sql);
97
98        // Build array of param values
99        let _param_array: String = if params.is_empty() {
100            String::new()
101        } else {
102            let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
103            format!(", [{}]", args.join(", "))
104        };
105
106        // Build inline params string for line-length checking
107        let inline_params = if params.is_empty() {
108            "client: PoolClient".to_string()
109        } else {
110            format!("client: PoolClient, {}", param_list)
111        };
112
113        // Helper: write a typed query call (with generic type annotation).
114        // Biome always breaks `client.query<T>(...)` to multi-line.
115        let write_typed_query = |out: &mut String,
116                                 prefix: &str,
117                                 type_name: &str,
118                                 sql: &str,
119                                 params: &[ResolvedParam]| {
120            let _ = writeln!(out, "{}client.query<{}>(", prefix, type_name);
121            let _ = writeln!(out, "\t\t\"{}\",", sql);
122            if !params.is_empty() {
123                let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
124                let _ = writeln!(out, "\t\t[{}],", args.join(", "));
125            }
126            let _ = writeln!(out, "\t);");
127        };
128
129        // Helper: write an untyped query call. Inline if short, multi-line if long.
130        let write_untyped_query =
131            |out: &mut String, prefix: &str, sql: &str, params: &[ResolvedParam]| {
132                let param_str = if params.is_empty() {
133                    String::new()
134                } else {
135                    let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
136                    format!(", [{}]", args.join(", "))
137                };
138                let oneliner = format!("{}client.query(\"{}\"{});", prefix, sql, param_str);
139                // Use tab width of 4 for line length estimation
140                let estimated_len = oneliner.replace('\t', "    ").len();
141                if estimated_len <= 80 {
142                    let _ = writeln!(out, "{}", oneliner);
143                } else {
144                    let _ = writeln!(out, "{}client.query(", prefix);
145                    let _ = writeln!(out, "\t\t\"{}\",", sql);
146                    if !params.is_empty() {
147                        let args: Vec<String> =
148                            params.iter().map(|p| p.field_name.clone()).collect();
149                        let _ = writeln!(out, "\t\t[{}],", args.join(", "));
150                    }
151                    let _ = writeln!(out, "\t);");
152                }
153            };
154
155        // Helper: write function signature, inline or multi-line based on length
156        let write_fn_sig = |out: &mut String, name: &str, params_inline: &str, ret: &str| {
157            let oneliner = format!(
158                "export async function {}({}): {} {{",
159                name, params_inline, ret
160            );
161            if oneliner.len() <= 80 {
162                let _ = writeln!(out, "{}", oneliner);
163            } else {
164                let mut parts = vec!["\tclient: PoolClient".to_string()];
165                for p in params {
166                    parts.push(format!("\t{}: {}", p.field_name, p.full_type));
167                }
168                let _ = writeln!(out, "export async function {}(", name);
169                for part in &parts {
170                    let _ = writeln!(out, "{},", part);
171                }
172                let _ = writeln!(out, "): {} {{", ret);
173            }
174        };
175
176        match &analyzed.command {
177            QueryCommand::One => {
178                let _ = writeln!(out, "/** Fetch a single {} or null. */", struct_name);
179                let ret = format!("Promise<{} | null>", struct_name);
180                write_fn_sig(&mut out, &func_name, &inline_params, &ret);
181                write_typed_query(
182                    &mut out,
183                    "\tconst { rows } = await ",
184                    struct_name,
185                    &sql,
186                    params,
187                );
188                let _ = writeln!(out, "\treturn rows[0] ?? null;");
189                let _ = write!(out, "}}");
190            }
191            QueryCommand::Many | QueryCommand::Batch => {
192                let _ = writeln!(out, "/** Fetch all {} rows. */", struct_name);
193                let ret = format!("Promise<{}[]>", struct_name);
194                write_fn_sig(&mut out, &func_name, &inline_params, &ret);
195                write_typed_query(
196                    &mut out,
197                    "\tconst { rows } = await ",
198                    struct_name,
199                    &sql,
200                    params,
201                );
202                let _ = writeln!(out, "\treturn rows;");
203                let _ = write!(out, "}}");
204            }
205            QueryCommand::Exec => {
206                let _ = writeln!(out, "/** Execute a query returning no rows. */");
207                write_fn_sig(&mut out, &func_name, &inline_params, "Promise<void>");
208                write_untyped_query(&mut out, "\tawait ", &sql, params);
209                let _ = write!(out, "}}");
210            }
211            QueryCommand::ExecResult | QueryCommand::ExecRows => {
212                let _ = writeln!(
213                    out,
214                    "/** Execute a query and return the number of affected rows. */"
215                );
216                write_fn_sig(&mut out, &func_name, &inline_params, "Promise<number>");
217                write_untyped_query(&mut out, "\tconst result = await ", &sql, params);
218                let _ = writeln!(out, "\treturn result.rowCount ?? 0;");
219                let _ = write!(out, "}}");
220            }
221        }
222
223        Ok(out)
224    }
225
226    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
227        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
228        let mut out = String::new();
229        let _ = writeln!(out, "export enum {} {{", type_name);
230        for value in &enum_info.values {
231            let variant = enum_variant_name(value, &self.manifest.naming);
232            let _ = writeln!(out, "\t{} = \"{}\",", variant, value);
233        }
234        let _ = write!(out, "}}");
235        Ok(out)
236    }
237
238    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
239        let name = to_pascal_case(&composite.sql_name);
240        let mut out = String::new();
241        let _ = writeln!(out, "/** Composite type {}. */", composite.sql_name);
242        let _ = writeln!(out, "export interface {} {{", name);
243        if composite.fields.is_empty() {
244            // empty interface
245        } else {
246            for field in &composite.fields {
247                let ts_type = resolve_type(&field.neutral_type, &self.manifest, false)
248                    .map(|t| t.into_owned())
249                    .map_err(|e| {
250                        ScytheError::new(
251                            ErrorCode::InternalError,
252                            format!("composite field type error: {}", e),
253                        )
254                    })?;
255                let _ = writeln!(out, "\t{}: {};", to_camel_case(&field.name), ts_type);
256            }
257        }
258        let _ = write!(out, "}}");
259        Ok(out)
260    }
261}