1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_camel_case, to_pascal_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::backends::typescript_common::{TsRowType, generate_zod_enum, generate_zod_row_struct};
16use crate::singularize;
17
/// Backend manifest embedded into the binary at compile time.
///
/// `TypescriptPgBackend::new` parses this when no
/// `backends/typescript-pg/manifest.toml` file exists on disk.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/typescript-pg.toml");
19
20pub struct TypescriptPgBackend {
21 manifest: BackendManifest,
22 row_type: TsRowType,
23}
24
25impl TypescriptPgBackend {
26 pub fn new(engine: &str) -> Result<Self, ScytheError> {
27 match engine {
28 "postgresql" | "postgres" | "pg" => {}
29 _ => {
30 return Err(ScytheError::new(
31 ErrorCode::InternalError,
32 format!(
33 "typescript-pg only supports PostgreSQL, got engine '{}'",
34 engine
35 ),
36 ));
37 }
38 }
39 let manifest_path = Path::new("backends/typescript-pg/manifest.toml");
40 let manifest = if manifest_path.exists() {
41 load_manifest(manifest_path)
42 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
43 } else {
44 toml::from_str(DEFAULT_MANIFEST_TOML)
45 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
46 };
47 Ok(Self {
48 manifest,
49 row_type: TsRowType::default(),
50 })
51 }
52}
53
54impl CodegenBackend for TypescriptPgBackend {
55 fn name(&self) -> &str {
56 "typescript-pg"
57 }
58
59 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
60 &self.manifest
61 }
62
63 fn file_header(&self) -> String {
64 let mut header =
65 "/** Auto-generated by scythe. Do not edit. */\n\nimport type { PoolClient } from \"pg\";\n"
66 .to_string();
67 if self.row_type == TsRowType::Zod {
68 header.push_str("import { z } from \"zod\";\n");
69 }
70 header
71 }
72
73 fn generate_row_struct(
74 &self,
75 query_name: &str,
76 columns: &[ResolvedColumn],
77 ) -> Result<String, ScytheError> {
78 let struct_name = row_struct_name(query_name, &self.manifest.naming);
79 if self.row_type == TsRowType::Zod {
80 return Ok(generate_zod_row_struct(&struct_name, query_name, columns));
81 }
82 let mut out = String::new();
83 let _ = writeln!(out, "/** Row type for {} queries. */", query_name);
84 let _ = writeln!(out, "export interface {} {{", struct_name);
85 for col in columns {
86 let _ = writeln!(out, "\t{}: {};", col.field_name, col.full_type);
87 }
88 let _ = write!(out, "}}");
89 Ok(out)
90 }
91
92 fn generate_model_struct(
93 &self,
94 table_name: &str,
95 columns: &[ResolvedColumn],
96 ) -> Result<String, ScytheError> {
97 let singular = singularize(table_name);
98 let name = to_pascal_case(&singular);
99 self.generate_row_struct(&name, columns)
100 }
101
102 fn generate_query_fn(
103 &self,
104 analyzed: &AnalyzedQuery,
105 struct_name: &str,
106 _columns: &[ResolvedColumn],
107 params: &[ResolvedParam],
108 ) -> Result<String, ScytheError> {
109 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
110 let mut out = String::new();
111
112 let param_list = params
114 .iter()
115 .map(|p| format!("{}: {}", p.field_name, p.full_type))
116 .collect::<Vec<_>>()
117 .join(", ");
118 let _sep = if param_list.is_empty() { "" } else { ", " };
119
120 let sql = super::clean_sql_with_optional(
122 &analyzed.sql,
123 &analyzed.optional_params,
124 &analyzed.params,
125 );
126
127 let _param_array: String = if params.is_empty() {
129 String::new()
130 } else {
131 let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
132 format!(", [{}]", args.join(", "))
133 };
134
135 let inline_params = if params.is_empty() {
137 "client: PoolClient".to_string()
138 } else {
139 format!("client: PoolClient, {}", param_list)
140 };
141
142 let write_typed_query = |out: &mut String,
145 prefix: &str,
146 type_name: &str,
147 sql: &str,
148 params: &[ResolvedParam]| {
149 let _ = writeln!(out, "{}client.query<{}>(", prefix, type_name);
150 let _ = writeln!(out, "\t\t`{}`,", sql);
151 if !params.is_empty() {
152 let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
153 let _ = writeln!(out, "\t\t[{}],", args.join(", "));
154 }
155 let _ = writeln!(out, "\t);");
156 };
157
158 let write_untyped_query =
160 |out: &mut String, prefix: &str, sql: &str, params: &[ResolvedParam]| {
161 let param_str = if params.is_empty() {
162 String::new()
163 } else {
164 let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
165 format!(", [{}]", args.join(", "))
166 };
167 let oneliner = format!("{}client.query(`{}`{});", prefix, sql, param_str);
168 let estimated_len = oneliner.replace('\t', " ").len();
170 if estimated_len <= 80 {
171 let _ = writeln!(out, "{}", oneliner);
172 } else {
173 let _ = writeln!(out, "{}client.query(", prefix);
174 let _ = writeln!(out, "\t\t`{}`,", sql);
175 if !params.is_empty() {
176 let args: Vec<String> =
177 params.iter().map(|p| p.field_name.clone()).collect();
178 let _ = writeln!(out, "\t\t[{}],", args.join(", "));
179 }
180 let _ = writeln!(out, "\t);");
181 }
182 };
183
184 let write_fn_sig = |out: &mut String, name: &str, params_inline: &str, ret: &str| {
186 let oneliner = format!(
187 "export async function {}({}): {} {{",
188 name, params_inline, ret
189 );
190 if oneliner.len() <= 80 {
191 let _ = writeln!(out, "{}", oneliner);
192 } else {
193 let mut parts = vec!["\tclient: PoolClient".to_string()];
194 for p in params {
195 parts.push(format!("\t{}: {}", p.field_name, p.full_type));
196 }
197 let _ = writeln!(out, "export async function {}(", name);
198 for part in &parts {
199 let _ = writeln!(out, "{},", part);
200 }
201 let _ = writeln!(out, "): {} {{", ret);
202 }
203 };
204
205 match &analyzed.command {
206 QueryCommand::One => {
207 let _ = writeln!(out, "/** Fetch a single {} or null. */", struct_name);
208 let ret = format!("Promise<{} | null>", struct_name);
209 write_fn_sig(&mut out, &func_name, &inline_params, &ret);
210 write_typed_query(
211 &mut out,
212 "\tconst { rows } = await ",
213 struct_name,
214 &sql,
215 params,
216 );
217 let _ = writeln!(out, "\treturn rows[0] ?? null;");
218 let _ = write!(out, "}}");
219 }
220 QueryCommand::Many => {
221 let _ = writeln!(out, "/** Fetch all {} rows. */", struct_name);
222 let ret = format!("Promise<{}[]>", struct_name);
223 write_fn_sig(&mut out, &func_name, &inline_params, &ret);
224 write_typed_query(
225 &mut out,
226 "\tconst { rows } = await ",
227 struct_name,
228 &sql,
229 params,
230 );
231 let _ = writeln!(out, "\treturn rows;");
232 let _ = write!(out, "}}");
233 }
234 QueryCommand::Batch => {
235 let batch_fn_name = format!("{}Batch", func_name);
236 if params.len() > 1 {
238 let params_type_name = format!("{}BatchParams", struct_name);
239 let _ = writeln!(out, "/** Params for {} batch operation. */", struct_name);
240 let _ = writeln!(out, "export interface {} {{", params_type_name);
241 for p in params {
242 let _ = writeln!(out, "\t{}: {};", p.field_name, p.full_type);
243 }
244 let _ = writeln!(out, "}}");
245 let _ = writeln!(out);
246 let _ = writeln!(
247 out,
248 "/** Execute {} for each item in the batch within a transaction. */",
249 analyzed.name
250 );
251 let batch_params = format!("client: PoolClient, items: {}[]", params_type_name);
252 write_fn_sig(&mut out, &batch_fn_name, &batch_params, "Promise<void>");
253 let _ = writeln!(out, "\ttry {{");
254 let _ = writeln!(out, "\t\tawait client.query(\"BEGIN\");");
255 let _ = writeln!(out, "\t\tfor (const item of items) {{");
256 let _ = writeln!(out, "\t\t\tawait client.query(");
257 let _ = writeln!(out, "\t\t\t\t`{}`,", sql);
258 let args: Vec<String> = params
259 .iter()
260 .map(|p| format!("item.{}", p.field_name))
261 .collect();
262 let _ = writeln!(out, "\t\t\t\t[{}],", args.join(", "));
263 let _ = writeln!(out, "\t\t\t);");
264 let _ = writeln!(out, "\t\t}}");
265 let _ = writeln!(out, "\t\tawait client.query(\"COMMIT\");");
266 let _ = writeln!(out, "\t}} catch (error) {{");
267 let _ = writeln!(out, "\t\tawait client.query(\"ROLLBACK\");");
268 let _ = writeln!(out, "\t\tthrow error;");
269 let _ = writeln!(out, "\t}}");
270 let _ = write!(out, "}}");
271 } else if params.len() == 1 {
272 let _ = writeln!(
273 out,
274 "/** Execute {} for each item in the batch within a transaction. */",
275 analyzed.name
276 );
277 let batch_params =
278 format!("client: PoolClient, items: {}[]", params[0].full_type);
279 write_fn_sig(&mut out, &batch_fn_name, &batch_params, "Promise<void>");
280 let _ = writeln!(out, "\ttry {{");
281 let _ = writeln!(out, "\t\tawait client.query(\"BEGIN\");");
282 let _ = writeln!(out, "\t\tfor (const item of items) {{");
283 let _ = writeln!(out, "\t\t\tawait client.query(`{}`, [item]);", sql);
284 let _ = writeln!(out, "\t\t}}");
285 let _ = writeln!(out, "\t\tawait client.query(\"COMMIT\");");
286 let _ = writeln!(out, "\t}} catch (error) {{");
287 let _ = writeln!(out, "\t\tawait client.query(\"ROLLBACK\");");
288 let _ = writeln!(out, "\t\tthrow error;");
289 let _ = writeln!(out, "\t}}");
290 let _ = write!(out, "}}");
291 } else {
292 let _ = writeln!(
293 out,
294 "/** Execute {} for each item in the batch within a transaction. */",
295 analyzed.name
296 );
297 write_fn_sig(
298 &mut out,
299 &batch_fn_name,
300 "client: PoolClient, count: number",
301 "Promise<void>",
302 );
303 let _ = writeln!(out, "\ttry {{");
304 let _ = writeln!(out, "\t\tawait client.query(\"BEGIN\");");
305 let _ = writeln!(out, "\t\tfor (let i = 0; i < count; i++) {{");
306 let _ = writeln!(out, "\t\t\tawait client.query(`{}`);", sql);
307 let _ = writeln!(out, "\t\t}}");
308 let _ = writeln!(out, "\t\tawait client.query(\"COMMIT\");");
309 let _ = writeln!(out, "\t}} catch (error) {{");
310 let _ = writeln!(out, "\t\tawait client.query(\"ROLLBACK\");");
311 let _ = writeln!(out, "\t\tthrow error;");
312 let _ = writeln!(out, "\t}}");
313 let _ = write!(out, "}}");
314 }
315 }
316 QueryCommand::Exec => {
317 let _ = writeln!(out, "/** Execute a query returning no rows. */");
318 write_fn_sig(&mut out, &func_name, &inline_params, "Promise<void>");
319 write_untyped_query(&mut out, "\tawait ", &sql, params);
320 let _ = write!(out, "}}");
321 }
322 QueryCommand::ExecResult | QueryCommand::ExecRows => {
323 let _ = writeln!(
324 out,
325 "/** Execute a query and return the number of affected rows. */"
326 );
327 write_fn_sig(&mut out, &func_name, &inline_params, "Promise<number>");
328 write_untyped_query(&mut out, "\tconst result = await ", &sql, params);
329 let _ = writeln!(out, "\treturn result.rowCount ?? 0;");
330 let _ = write!(out, "}}");
331 }
332 }
333
334 Ok(out)
335 }
336
337 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
338 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
339 if self.row_type == TsRowType::Zod {
340 return Ok(generate_zod_enum(&type_name, &enum_info.values));
341 }
342 let mut out = String::new();
343 let _ = writeln!(out, "export enum {} {{", type_name);
344 for value in &enum_info.values {
345 let variant = enum_variant_name(value, &self.manifest.naming);
346 let _ = writeln!(out, "\t{} = \"{}\",", variant, value);
347 }
348 let _ = write!(out, "}}");
349 Ok(out)
350 }
351
352 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
353 let name = to_pascal_case(&composite.sql_name);
354 let mut out = String::new();
355 let _ = writeln!(out, "/** Composite type {}. */", composite.sql_name);
356 let _ = writeln!(out, "export interface {} {{", name);
357 if composite.fields.is_empty() {
358 } else {
360 for field in &composite.fields {
361 let ts_type = resolve_type(&field.neutral_type, &self.manifest, false)
362 .map(|t| t.into_owned())
363 .map_err(|e| {
364 ScytheError::new(
365 ErrorCode::InternalError,
366 format!("composite field type error: {}", e),
367 )
368 })?;
369 let _ = writeln!(out, "\t{}: {};", to_camel_case(&field.name), ts_type);
370 }
371 }
372 let _ = write!(out, "}}");
373 Ok(out)
374 }
375
376 fn apply_options(
377 &mut self,
378 options: &std::collections::HashMap<String, String>,
379 ) -> Result<(), ScytheError> {
380 if let Some(value) = options.get("row_type") {
381 self.row_type = TsRowType::from_option(value)?;
382 }
383 Ok(())
384 }
385}