1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
17const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/typescript-postgres.toml");
18
19pub struct TypescriptPostgresBackend {
20 manifest: BackendManifest,
21}
22
23impl TypescriptPostgresBackend {
24 pub fn new(engine: &str) -> Result<Self, ScytheError> {
25 match engine {
26 "postgresql" | "postgres" | "pg" => {}
27 _ => {
28 return Err(ScytheError::new(
29 ErrorCode::InternalError,
30 format!(
31 "typescript-postgres only supports PostgreSQL, got engine '{}'",
32 engine
33 ),
34 ));
35 }
36 }
37 let manifest_path = Path::new("backends/typescript-postgres/manifest.toml");
38 let manifest = if manifest_path.exists() {
39 load_manifest(manifest_path)
40 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
41 } else {
42 toml::from_str(DEFAULT_MANIFEST_TOML)
43 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
44 };
45 Ok(Self { manifest })
46 }
47}
48
49impl CodegenBackend for TypescriptPostgresBackend {
50 fn name(&self) -> &str {
51 "typescript-postgres"
52 }
53
54 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
55 &self.manifest
56 }
57
58 fn file_header(&self) -> String {
59 "/** Auto-generated by scythe. Do not edit. */\n\nimport type { Sql } from \"postgres\";\n"
60 .to_string()
61 }
62
63 fn generate_row_struct(
64 &self,
65 query_name: &str,
66 columns: &[ResolvedColumn],
67 ) -> Result<String, ScytheError> {
68 let struct_name = row_struct_name(query_name, &self.manifest.naming);
69 let mut out = String::new();
70 let _ = writeln!(out, "/** Row type for {} queries. */", query_name);
71 let _ = writeln!(out, "export interface {} {{", struct_name);
72 for col in columns {
73 let _ = writeln!(out, "\t{}: {};", col.field_name, col.full_type);
74 }
75 let _ = write!(out, "}}");
76 Ok(out)
77 }
78
79 fn generate_model_struct(
80 &self,
81 table_name: &str,
82 columns: &[ResolvedColumn],
83 ) -> Result<String, ScytheError> {
84 let singular = singularize(table_name);
85 let name = to_pascal_case(&singular);
86 self.generate_row_struct(&name, columns)
87 }
88
89 fn generate_query_fn(
90 &self,
91 analyzed: &AnalyzedQuery,
92 struct_name: &str,
93 _columns: &[ResolvedColumn],
94 params: &[ResolvedParam],
95 ) -> Result<String, ScytheError> {
96 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
97 let mut out = String::new();
98
99 let param_list = params
101 .iter()
102 .map(|p| format!("{}: {}", p.field_name, p.full_type))
103 .collect::<Vec<_>>()
104 .join(", ");
105 let _sep = if param_list.is_empty() { "" } else { ", " };
106
107 let sql_clean = super::clean_sql(&analyzed.sql);
109 let sql_template = rewrite_params_template(&sql_clean, analyzed, params);
110
111 let inline_params = if params.is_empty() {
113 "sql: Sql".to_string()
114 } else {
115 format!("sql: Sql, {}", param_list)
116 };
117
118 let write_fn_sig = |out: &mut String, name: &str, params_inline: &str, ret: &str| {
122 let oneliner = format!(
123 "export async function {}({}): {} {{",
124 name, params_inline, ret
125 );
126 if oneliner.len() <= 80 {
127 let _ = writeln!(out, "{}", oneliner);
128 } else {
129 let mut parts = vec!["\tsql: Sql".to_string()];
131 for p in params {
132 parts.push(format!("\t{}: {}", p.field_name, p.full_type));
133 }
134 let _ = writeln!(out, "export async function {}(", name);
135 for part in &parts {
136 let _ = writeln!(out, "{},", part);
137 }
138 let _ = writeln!(out, "): {} {{", ret);
139 }
140 };
141
142 match &analyzed.command {
143 QueryCommand::One => {
144 let _ = writeln!(out, "/** Fetch a single {} or null. */", struct_name);
145 let ret = format!("Promise<{} | null>", struct_name);
146 write_fn_sig(&mut out, &func_name, &inline_params, &ret);
147 let _ = writeln!(out, "\tconst rows = await sql<{}[]>`", struct_name);
148 let _ = writeln!(out, " {}", sql_template);
149 let _ = writeln!(out, " `;");
150 let _ = writeln!(out, "\treturn rows[0] ?? null;");
151 let _ = write!(out, "}}");
152 }
153 QueryCommand::Many | QueryCommand::Batch => {
154 let _ = writeln!(out, "/** Fetch all {} rows. */", struct_name);
155 let ret = format!("Promise<{}[]>", struct_name);
156 write_fn_sig(&mut out, &func_name, &inline_params, &ret);
157 let _ = writeln!(out, "\tconst rows = await sql<{}[]>`", struct_name);
158 let _ = writeln!(out, " {}", sql_template);
159 let _ = writeln!(out, " `;");
160 let _ = writeln!(out, "\treturn rows;");
161 let _ = write!(out, "}}");
162 }
163 QueryCommand::Exec => {
164 let _ = writeln!(out, "/** Execute a query returning no rows. */");
165 write_fn_sig(&mut out, &func_name, &inline_params, "Promise<void>");
166 let _ = writeln!(out, "\tawait sql`");
167 let _ = writeln!(out, " {}", sql_template);
168 let _ = writeln!(out, " `;");
169 let _ = write!(out, "}}");
170 }
171 QueryCommand::ExecResult | QueryCommand::ExecRows => {
172 let _ = writeln!(
173 out,
174 "/** Execute a query and return the number of affected rows. */"
175 );
176 write_fn_sig(&mut out, &func_name, &inline_params, "Promise<number>");
177 let _ = writeln!(out, "\tconst result = await sql`");
178 let _ = writeln!(out, " {}", sql_template);
179 let _ = writeln!(out, " `;");
180 let _ = writeln!(out, "\treturn result.count;");
181 let _ = write!(out, "}}");
182 }
183 }
184
185 Ok(out)
186 }
187
188 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
189 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
190 let mut out = String::new();
191 let _ = writeln!(out, "export enum {} {{", type_name);
192 for value in &enum_info.values {
193 let variant = enum_variant_name(value, &self.manifest.naming);
194 let _ = writeln!(out, "\t{} = \"{}\",", variant, value);
195 }
196 let _ = write!(out, "}}");
197 Ok(out)
198 }
199
200 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
201 let name = to_pascal_case(&composite.sql_name);
202 let mut out = String::new();
203 let _ = writeln!(out, "/** Composite type {}. */", composite.sql_name);
204 let _ = writeln!(out, "export interface {} {{", name);
205 if composite.fields.is_empty() {
206 } else {
208 for field in &composite.fields {
209 let ts_type = resolve_type(&field.neutral_type, &self.manifest, false)
210 .map(|t| t.into_owned())
211 .map_err(|e| {
212 ScytheError::new(
213 ErrorCode::InternalError,
214 format!("composite field type error: {}", e),
215 )
216 })?;
217 let _ = writeln!(out, "\t{}: {};", to_camel_case(&field.name), ts_type);
218 }
219 }
220 let _ = write!(out, "}}");
221 Ok(out)
222 }
223}
224
225fn rewrite_params_template(
227 sql: &str,
228 analyzed: &AnalyzedQuery,
229 params: &[ResolvedParam],
230) -> String {
231 let mut result = sql.to_string();
232 let mut indexed: Vec<(i64, &str)> = analyzed
234 .params
235 .iter()
236 .zip(params.iter())
237 .map(|(ap, rp)| (ap.position, rp.field_name.as_str()))
238 .collect();
239 indexed.sort_by(|a, b| b.0.cmp(&a.0));
240 for (pos, field_name) in indexed {
241 let placeholder = format!("${}", pos);
242 let replacement = format!("${{{}}}", field_name);
243 result = result.replace(&placeholder, &replacement);
244 }
245 result
246}