Skip to main content

scythe_codegen/backends/
typescript_pg.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
/// Manifest text embedded into the binary at compile time via `include_str!`.
///
/// `TypescriptPgBackend::new` parses this with `toml::from_str` when the
/// on-disk manifest file does not exist.
const DEFAULT_MANIFEST_TOML: &str =
    include_str!("../../manifests/typescript-pg.toml");
19
/// Code generation backend that reports itself as `typescript-pg`.
///
/// The TypeScript it emits takes a `PoolClient` (type-imported from the `pg`
/// package) as the first argument of every generated query function.
pub struct TypescriptPgBackend {
    /// Backend manifest. Its `naming` section is handed to the naming helpers
    /// and the whole manifest is handed to `resolve_type`.
    manifest: BackendManifest,
}
23
24impl TypescriptPgBackend {
25    pub fn new() -> Result<Self, ScytheError> {
26        let manifest_path = Path::new("backends/typescript-pg/manifest.toml");
27        let manifest = if manifest_path.exists() {
28            load_manifest(manifest_path)
29                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
30        } else {
31            toml::from_str(DEFAULT_MANIFEST_TOML)
32                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
33        };
34        Ok(Self { manifest })
35    }
36
37    pub fn manifest(&self) -> &BackendManifest {
38        &self.manifest
39    }
40}
41
42impl CodegenBackend for TypescriptPgBackend {
43    fn name(&self) -> &str {
44        "typescript-pg"
45    }
46
47    fn file_header(&self) -> String {
48        "/** Auto-generated by scythe. Do not edit. */\n\nimport type { PoolClient } from \"pg\";\n"
49            .to_string()
50    }
51
52    fn generate_row_struct(
53        &self,
54        query_name: &str,
55        columns: &[ResolvedColumn],
56    ) -> Result<String, ScytheError> {
57        let struct_name = row_struct_name(query_name, &self.manifest.naming);
58        let mut out = String::new();
59        let _ = writeln!(out, "/** Row type for {} queries. */", query_name);
60        let _ = writeln!(out, "export interface {} {{", struct_name);
61        for col in columns {
62            let _ = writeln!(out, "\t{}: {};", col.field_name, col.full_type);
63        }
64        let _ = write!(out, "}}");
65        Ok(out)
66    }
67
68    fn generate_model_struct(
69        &self,
70        table_name: &str,
71        columns: &[ResolvedColumn],
72    ) -> Result<String, ScytheError> {
73        let singular = singularize(table_name);
74        let name = to_pascal_case(&singular);
75        self.generate_row_struct(&name, columns)
76    }
77
78    fn generate_query_fn(
79        &self,
80        analyzed: &AnalyzedQuery,
81        struct_name: &str,
82        _columns: &[ResolvedColumn],
83        params: &[ResolvedParam],
84    ) -> Result<String, ScytheError> {
85        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
86        let mut out = String::new();
87
88        // Build parameter list
89        let param_list = params
90            .iter()
91            .map(|p| format!("{}: {}", p.field_name, p.full_type))
92            .collect::<Vec<_>>()
93            .join(", ");
94        let _sep = if param_list.is_empty() { "" } else { ", " };
95
96        // Clean SQL — pg uses $1, $2 positional params natively
97        let sql = super::clean_sql(&analyzed.sql);
98
99        // Build array of param values
100        let _param_array: String = if params.is_empty() {
101            String::new()
102        } else {
103            let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
104            format!(", [{}]", args.join(", "))
105        };
106
107        // Build inline params string for line-length checking
108        let inline_params = if params.is_empty() {
109            "client: PoolClient".to_string()
110        } else {
111            format!("client: PoolClient, {}", param_list)
112        };
113
114        // Helper: write a typed query call (with generic type annotation).
115        // Biome always breaks `client.query<T>(...)` to multi-line.
116        let write_typed_query = |out: &mut String,
117                                 prefix: &str,
118                                 type_name: &str,
119                                 sql: &str,
120                                 params: &[ResolvedParam]| {
121            let _ = writeln!(out, "{}client.query<{}>(", prefix, type_name);
122            let _ = writeln!(out, "\t\t\"{}\",", sql);
123            if !params.is_empty() {
124                let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
125                let _ = writeln!(out, "\t\t[{}],", args.join(", "));
126            }
127            let _ = writeln!(out, "\t);");
128        };
129
130        // Helper: write an untyped query call. Inline if short, multi-line if long.
131        let write_untyped_query =
132            |out: &mut String, prefix: &str, sql: &str, params: &[ResolvedParam]| {
133                let param_str = if params.is_empty() {
134                    String::new()
135                } else {
136                    let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
137                    format!(", [{}]", args.join(", "))
138                };
139                let oneliner = format!("{}client.query(\"{}\"{});", prefix, sql, param_str);
140                // Use tab width of 4 for line length estimation
141                let estimated_len = oneliner.replace('\t', "    ").len();
142                if estimated_len <= 80 {
143                    let _ = writeln!(out, "{}", oneliner);
144                } else {
145                    let _ = writeln!(out, "{}client.query(", prefix);
146                    let _ = writeln!(out, "\t\t\"{}\",", sql);
147                    if !params.is_empty() {
148                        let args: Vec<String> =
149                            params.iter().map(|p| p.field_name.clone()).collect();
150                        let _ = writeln!(out, "\t\t[{}],", args.join(", "));
151                    }
152                    let _ = writeln!(out, "\t);");
153                }
154            };
155
156        // Helper: write function signature, inline or multi-line based on length
157        let write_fn_sig = |out: &mut String, name: &str, params_inline: &str, ret: &str| {
158            let oneliner = format!(
159                "export async function {}({}): {} {{",
160                name, params_inline, ret
161            );
162            if oneliner.len() <= 80 {
163                let _ = writeln!(out, "{}", oneliner);
164            } else {
165                let mut parts = vec!["\tclient: PoolClient".to_string()];
166                for p in params {
167                    parts.push(format!("\t{}: {}", p.field_name, p.full_type));
168                }
169                let _ = writeln!(out, "export async function {}(", name);
170                for part in &parts {
171                    let _ = writeln!(out, "{},", part);
172                }
173                let _ = writeln!(out, "): {} {{", ret);
174            }
175        };
176
177        match &analyzed.command {
178            QueryCommand::One => {
179                let _ = writeln!(out, "/** Fetch a single {} or null. */", struct_name);
180                let ret = format!("Promise<{} | null>", struct_name);
181                write_fn_sig(&mut out, &func_name, &inline_params, &ret);
182                write_typed_query(
183                    &mut out,
184                    "\tconst { rows } = await ",
185                    struct_name,
186                    &sql,
187                    params,
188                );
189                let _ = writeln!(out, "\treturn rows[0] ?? null;");
190                let _ = write!(out, "}}");
191            }
192            QueryCommand::Many | QueryCommand::Batch => {
193                let _ = writeln!(out, "/** Fetch all {} rows. */", struct_name);
194                let ret = format!("Promise<{}[]>", struct_name);
195                write_fn_sig(&mut out, &func_name, &inline_params, &ret);
196                write_typed_query(
197                    &mut out,
198                    "\tconst { rows } = await ",
199                    struct_name,
200                    &sql,
201                    params,
202                );
203                let _ = writeln!(out, "\treturn rows;");
204                let _ = write!(out, "}}");
205            }
206            QueryCommand::Exec => {
207                let _ = writeln!(out, "/** Execute a query returning no rows. */");
208                write_fn_sig(&mut out, &func_name, &inline_params, "Promise<void>");
209                write_untyped_query(&mut out, "\tawait ", &sql, params);
210                let _ = write!(out, "}}");
211            }
212            QueryCommand::ExecResult | QueryCommand::ExecRows => {
213                let _ = writeln!(
214                    out,
215                    "/** Execute a query and return the number of affected rows. */"
216                );
217                write_fn_sig(&mut out, &func_name, &inline_params, "Promise<number>");
218                write_untyped_query(&mut out, "\tconst result = await ", &sql, params);
219                let _ = writeln!(out, "\treturn result.rowCount ?? 0;");
220                let _ = write!(out, "}}");
221            }
222        }
223
224        Ok(out)
225    }
226
227    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
228        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
229        let mut out = String::new();
230        let _ = writeln!(out, "export enum {} {{", type_name);
231        for value in &enum_info.values {
232            let variant = enum_variant_name(value, &self.manifest.naming);
233            let _ = writeln!(out, "\t{} = \"{}\",", variant, value);
234        }
235        let _ = write!(out, "}}");
236        Ok(out)
237    }
238
239    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
240        let name = to_pascal_case(&composite.sql_name);
241        let mut out = String::new();
242        let _ = writeln!(out, "/** Composite type {}. */", composite.sql_name);
243        let _ = writeln!(out, "export interface {} {{", name);
244        if composite.fields.is_empty() {
245            // empty interface
246        } else {
247            for field in &composite.fields {
248                let ts_type = resolve_type(&field.neutral_type, &self.manifest, false)
249                    .map(|t| t.into_owned())
250                    .map_err(|e| {
251                        ScytheError::new(
252                            ErrorCode::InternalError,
253                            format!("composite field type error: {}", e),
254                        )
255                    })?;
256                let _ = writeln!(out, "\t{}: {};", to_camel_case(&field.name), ts_type);
257            }
258        }
259        let _ = write!(out, "}}");
260        Ok(out)
261    }
262}