1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::backends::typescript_common::{TsRowType, generate_zod_enum, generate_zod_row_struct};
16use crate::singularize;
17
18const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/typescript-postgres.toml");
19
20pub struct TypescriptPostgresBackend {
21 manifest: BackendManifest,
22 row_type: TsRowType,
23}
24
25impl TypescriptPostgresBackend {
26 pub fn new(engine: &str) -> Result<Self, ScytheError> {
27 match engine {
28 "postgresql" | "postgres" | "pg" => {}
29 _ => {
30 return Err(ScytheError::new(
31 ErrorCode::InternalError,
32 format!(
33 "typescript-postgres only supports PostgreSQL, got engine '{}'",
34 engine
35 ),
36 ));
37 }
38 }
39 let manifest_path = Path::new("backends/typescript-postgres/manifest.toml");
40 let manifest = if manifest_path.exists() {
41 load_manifest(manifest_path)
42 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
43 } else {
44 toml::from_str(DEFAULT_MANIFEST_TOML)
45 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
46 };
47 Ok(Self {
48 manifest,
49 row_type: TsRowType::default(),
50 })
51 }
52}
53
54impl CodegenBackend for TypescriptPostgresBackend {
55 fn name(&self) -> &str {
56 "typescript-postgres"
57 }
58
59 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
60 &self.manifest
61 }
62
63 fn file_header(&self) -> String {
64 let mut header =
65 "/** Auto-generated by scythe. Do not edit. */\n\nimport type { Sql } from \"postgres\";\n"
66 .to_string();
67 if self.row_type == TsRowType::Zod {
68 header.push_str("import { z } from \"zod\";\n");
69 }
70 header
71 }
72
73 fn generate_row_struct(
74 &self,
75 query_name: &str,
76 columns: &[ResolvedColumn],
77 ) -> Result<String, ScytheError> {
78 let struct_name = row_struct_name(query_name, &self.manifest.naming);
79 if self.row_type == TsRowType::Zod {
80 return Ok(generate_zod_row_struct(&struct_name, query_name, columns));
81 }
82 let mut out = String::new();
83 let _ = writeln!(out, "/** Row type for {} queries. */", query_name);
84 let _ = writeln!(out, "export interface {} {{", struct_name);
85 for col in columns {
86 let _ = writeln!(out, "\t{}: {};", col.field_name, col.full_type);
87 }
88 let _ = write!(out, "}}");
89 Ok(out)
90 }
91
92 fn generate_model_struct(
93 &self,
94 table_name: &str,
95 columns: &[ResolvedColumn],
96 ) -> Result<String, ScytheError> {
97 let singular = singularize(table_name);
98 let name = to_pascal_case(&singular);
99 self.generate_row_struct(&name, columns)
100 }
101
102 fn generate_query_fn(
103 &self,
104 analyzed: &AnalyzedQuery,
105 struct_name: &str,
106 _columns: &[ResolvedColumn],
107 params: &[ResolvedParam],
108 ) -> Result<String, ScytheError> {
109 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
110 let mut out = String::new();
111
112 let param_list = params
114 .iter()
115 .map(|p| format!("{}: {}", p.field_name, p.full_type))
116 .collect::<Vec<_>>()
117 .join(", ");
118 let _sep = if param_list.is_empty() { "" } else { ", " };
119
120 let sql_clean = super::clean_sql_with_optional(
122 &analyzed.sql,
123 &analyzed.optional_params,
124 &analyzed.params,
125 );
126 let sql_template = rewrite_params_template(&sql_clean, analyzed, params);
127
128 let inline_params = if params.is_empty() {
130 "sql: Sql".to_string()
131 } else {
132 format!("sql: Sql, {}", param_list)
133 };
134
135 let write_fn_sig = |out: &mut String, name: &str, params_inline: &str, ret: &str| {
139 let oneliner = format!(
140 "export async function {}({}): {} {{",
141 name, params_inline, ret
142 );
143 if oneliner.len() <= 80 {
144 let _ = writeln!(out, "{}", oneliner);
145 } else {
146 let mut parts = vec!["\tsql: Sql".to_string()];
148 for p in params {
149 parts.push(format!("\t{}: {}", p.field_name, p.full_type));
150 }
151 let _ = writeln!(out, "export async function {}(", name);
152 for part in &parts {
153 let _ = writeln!(out, "{},", part);
154 }
155 let _ = writeln!(out, "): {} {{", ret);
156 }
157 };
158
159 match &analyzed.command {
160 QueryCommand::One => {
161 let _ = writeln!(out, "/** Fetch a single {} or null. */", struct_name);
162 let ret = format!("Promise<{} | null>", struct_name);
163 write_fn_sig(&mut out, &func_name, &inline_params, &ret);
164 let _ = writeln!(out, "\tconst rows = await sql<{}[]>`", struct_name);
165 let _ = writeln!(out, " {}", sql_template);
166 let _ = writeln!(out, " `;");
167 let _ = writeln!(out, "\treturn rows[0] ?? null;");
168 let _ = write!(out, "}}");
169 }
170 QueryCommand::Batch => {
171 let batch_fn_name = format!("{}Batch", func_name);
172 if params.len() > 1 {
173 let params_type_name = format!("{}BatchParams", struct_name);
174 let _ = writeln!(out, "/** Params for {} batch operation. */", struct_name);
175 let _ = writeln!(out, "export interface {} {{", params_type_name);
176 for p in params {
177 let _ = writeln!(out, "\t{}: {};", p.field_name, p.full_type);
178 }
179 let _ = writeln!(out, "}}");
180 let _ = writeln!(out);
181 let _ = writeln!(
182 out,
183 "/** Execute {} for each item in the batch within a transaction. */",
184 analyzed.name
185 );
186 let batch_params = format!("sql: Sql, items: {}[]", params_type_name);
187 write_fn_sig(&mut out, &batch_fn_name, &batch_params, "Promise<void>");
188 let _ = writeln!(out, "\tawait sql.begin(async (tx) => {{");
189 let _ = writeln!(out, "\t\tfor (const item of items) {{");
190 let batch_sql = {
192 let mut s = sql_clean.clone();
193 let mut indexed: Vec<(i64, &str)> = analyzed
194 .params
195 .iter()
196 .zip(params.iter())
197 .map(|(ap, rp)| (ap.position, rp.field_name.as_str()))
198 .collect();
199 indexed.sort_by(|a, b| b.0.cmp(&a.0));
200 for (pos, field_name) in indexed {
201 let placeholder = format!("${}", pos);
202 let replacement = format!("${{item.{}}}", field_name);
203 s = s.replace(&placeholder, &replacement);
204 }
205 s
206 };
207 let _ = writeln!(out, "\t\t\tawait tx`");
208 let _ = writeln!(out, " {}", batch_sql);
209 let _ = writeln!(out, " `;");
210 let _ = writeln!(out, "\t\t}}");
211 let _ = writeln!(out, "\t}});");
212 let _ = write!(out, "}}");
213 } else if params.len() == 1 {
214 let _ = writeln!(
215 out,
216 "/** Execute {} for each item in the batch within a transaction. */",
217 analyzed.name
218 );
219 let batch_params = format!("sql: Sql, items: {}[]", params[0].full_type);
220 write_fn_sig(&mut out, &batch_fn_name, &batch_params, "Promise<void>");
221 let _ = writeln!(out, "\tawait sql.begin(async (tx) => {{");
222 let _ = writeln!(out, "\t\tfor (const item of items) {{");
223 let batch_sql =
224 sql_template.replace(&format!("${{{}}}", params[0].field_name), "${item}");
225 let _ = writeln!(out, "\t\t\tawait tx`");
226 let _ = writeln!(out, " {}", batch_sql);
227 let _ = writeln!(out, " `;");
228 let _ = writeln!(out, "\t\t}}");
229 let _ = writeln!(out, "\t}});");
230 let _ = write!(out, "}}");
231 } else {
232 let _ = writeln!(
233 out,
234 "/** Execute {} for each item in the batch within a transaction. */",
235 analyzed.name
236 );
237 write_fn_sig(
238 &mut out,
239 &batch_fn_name,
240 "sql: Sql, count: number",
241 "Promise<void>",
242 );
243 let _ = writeln!(out, "\tawait sql.begin(async (tx) => {{");
244 let _ = writeln!(out, "\t\tfor (let i = 0; i < count; i++) {{");
245 let _ = writeln!(out, "\t\t\tawait tx`");
246 let _ = writeln!(out, " {}", sql_template);
247 let _ = writeln!(out, " `;");
248 let _ = writeln!(out, "\t\t}}");
249 let _ = writeln!(out, "\t}});");
250 let _ = write!(out, "}}");
251 }
252 }
253 QueryCommand::Many => {
254 let _ = writeln!(out, "/** Fetch all {} rows. */", struct_name);
255 let ret = format!("Promise<{}[]>", struct_name);
256 write_fn_sig(&mut out, &func_name, &inline_params, &ret);
257 let _ = writeln!(out, "\tconst rows = await sql<{}[]>`", struct_name);
258 let _ = writeln!(out, " {}", sql_template);
259 let _ = writeln!(out, " `;");
260 let _ = writeln!(out, "\treturn rows;");
261 let _ = write!(out, "}}");
262 }
263 QueryCommand::Exec => {
264 let _ = writeln!(out, "/** Execute a query returning no rows. */");
265 write_fn_sig(&mut out, &func_name, &inline_params, "Promise<void>");
266 let _ = writeln!(out, "\tawait sql`");
267 let _ = writeln!(out, " {}", sql_template);
268 let _ = writeln!(out, " `;");
269 let _ = write!(out, "}}");
270 }
271 QueryCommand::ExecResult | QueryCommand::ExecRows => {
272 let _ = writeln!(
273 out,
274 "/** Execute a query and return the number of affected rows. */"
275 );
276 write_fn_sig(&mut out, &func_name, &inline_params, "Promise<number>");
277 let _ = writeln!(out, "\tconst result = await sql`");
278 let _ = writeln!(out, " {}", sql_template);
279 let _ = writeln!(out, " `;");
280 let _ = writeln!(out, "\treturn result.count;");
281 let _ = write!(out, "}}");
282 }
283 }
284
285 Ok(out)
286 }
287
288 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
289 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
290 if self.row_type == TsRowType::Zod {
291 return Ok(generate_zod_enum(&type_name, &enum_info.values));
292 }
293 let mut out = String::new();
294 let _ = writeln!(out, "export enum {} {{", type_name);
295 for value in &enum_info.values {
296 let variant = enum_variant_name(value, &self.manifest.naming);
297 let _ = writeln!(out, "\t{} = \"{}\",", variant, value);
298 }
299 let _ = write!(out, "}}");
300 Ok(out)
301 }
302
303 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
304 let name = to_pascal_case(&composite.sql_name);
305 let mut out = String::new();
306 let _ = writeln!(out, "/** Composite type {}. */", composite.sql_name);
307 let _ = writeln!(out, "export interface {} {{", name);
308 if composite.fields.is_empty() {
309 } else {
311 for field in &composite.fields {
312 let ts_type = resolve_type(&field.neutral_type, &self.manifest, false)
313 .map(|t| t.into_owned())
314 .map_err(|e| {
315 ScytheError::new(
316 ErrorCode::InternalError,
317 format!("composite field type error: {}", e),
318 )
319 })?;
320 let _ = writeln!(out, "\t{}: {};", to_camel_case(&field.name), ts_type);
321 }
322 }
323 let _ = write!(out, "}}");
324 Ok(out)
325 }
326
327 fn apply_options(
328 &mut self,
329 options: &std::collections::HashMap<String, String>,
330 ) -> Result<(), ScytheError> {
331 if let Some(value) = options.get("row_type") {
332 self.row_type = TsRowType::from_option(value)?;
333 }
334 Ok(())
335 }
336}
337
338fn rewrite_params_template(
340 sql: &str,
341 analyzed: &AnalyzedQuery,
342 params: &[ResolvedParam],
343) -> String {
344 let mut result = sql.to_string();
345 let mut indexed: Vec<(i64, &str)> = analyzed
347 .params
348 .iter()
349 .zip(params.iter())
350 .map(|(ap, rp)| (ap.position, rp.field_name.as_str()))
351 .collect();
352 indexed.sort_by(|a, b| b.0.cmp(&a.0));
353 for (pos, field_name) in indexed {
354 let placeholder = format!("${}", pos);
355 let replacement = format!("${{{}}}", field_name);
356 result = result.replace(&placeholder, &replacement);
357 }
358 result
359}