1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{fn_name, row_struct_name, to_camel_case, to_pascal_case};
6use scythe_backend::types::resolve_type;
7
8use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
9use scythe_core::errors::{ErrorCode, ScytheError};
10use scythe_core::parser::QueryCommand;
11
12use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
13use crate::backends::typescript_common::{TsRowType, generate_zod_row_struct};
14use crate::singularize;
15
16const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/typescript-duckdb.toml");
17
18pub struct TypescriptDuckdbBackend {
19 manifest: BackendManifest,
20 row_type: TsRowType,
21}
22
23impl TypescriptDuckdbBackend {
24 pub fn new(engine: &str) -> Result<Self, ScytheError> {
25 match engine {
26 "duckdb" => {}
27 _ => {
28 return Err(ScytheError::new(
29 ErrorCode::InternalError,
30 format!(
31 "typescript-duckdb only supports DuckDB, got engine '{}'",
32 engine
33 ),
34 ));
35 }
36 }
37 let manifest_path = Path::new("backends/typescript-duckdb/manifest.toml");
38 let manifest = if manifest_path.exists() {
39 load_manifest(manifest_path)
40 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
41 } else {
42 toml::from_str(DEFAULT_MANIFEST_TOML)
43 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
44 };
45 Ok(Self {
46 manifest,
47 row_type: TsRowType::default(),
48 })
49 }
50}
51
52impl CodegenBackend for TypescriptDuckdbBackend {
53 fn name(&self) -> &str {
54 "typescript-duckdb"
55 }
56
57 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
58 &self.manifest
59 }
60
61 fn supported_engines(&self) -> &[&str] {
62 &["duckdb"]
63 }
64
65 fn file_header(&self) -> String {
66 let mut header =
67 "/** Auto-generated by scythe. Do not edit. */\n\nimport type { Connection } from \"@duckdb/node-api\";\n"
68 .to_string();
69 if self.row_type == TsRowType::Zod {
70 header.push_str("import { z } from \"zod\";\n");
71 }
72 header
73 }
74
75 fn generate_row_struct(
76 &self,
77 query_name: &str,
78 columns: &[ResolvedColumn],
79 ) -> Result<String, ScytheError> {
80 let struct_name = row_struct_name(query_name, &self.manifest.naming);
81 if self.row_type == TsRowType::Zod {
82 return Ok(generate_zod_row_struct(&struct_name, query_name, columns));
83 }
84 let mut out = String::new();
85 let _ = writeln!(out, "/** Row type for {} queries. */", query_name);
86 let _ = writeln!(out, "export interface {} {{", struct_name);
87 for col in columns {
88 let _ = writeln!(out, "\t{}: {};", col.field_name, col.full_type);
89 }
90 let _ = write!(out, "}}");
91 Ok(out)
92 }
93
94 fn generate_model_struct(
95 &self,
96 table_name: &str,
97 columns: &[ResolvedColumn],
98 ) -> Result<String, ScytheError> {
99 let singular = singularize(table_name);
100 let name = to_pascal_case(&singular);
101 self.generate_row_struct(&name, columns)
102 }
103
104 fn generate_query_fn(
105 &self,
106 analyzed: &AnalyzedQuery,
107 struct_name: &str,
108 _columns: &[ResolvedColumn],
109 params: &[ResolvedParam],
110 ) -> Result<String, ScytheError> {
111 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
112 let mut out = String::new();
113
114 let param_list = params
115 .iter()
116 .map(|p| format!("{}: {}", p.field_name, p.full_type))
117 .collect::<Vec<_>>()
118 .join(", ");
119
120 let sql = super::clean_sql_with_optional(
121 &analyzed.sql,
122 &analyzed.optional_params,
123 &analyzed.params,
124 );
125
126 let inline_params = if params.is_empty() {
127 "conn: Connection".to_string()
128 } else {
129 format!("conn: Connection, {}", param_list)
130 };
131
132 let write_fn_sig = |out: &mut String, name: &str, params_inline: &str, ret: &str| {
133 let oneliner = format!(
134 "export async function {}({}): Promise<{}> {{",
135 name, params_inline, ret
136 );
137 if oneliner.len() <= 80 {
138 let _ = writeln!(out, "{}", oneliner);
139 } else {
140 let mut parts = vec!["\tconn: Connection".to_string()];
141 for p in params {
142 parts.push(format!("\t{}: {}", p.field_name, p.full_type));
143 }
144 let _ = writeln!(out, "export async function {}(", name);
145 for part in &parts {
146 let _ = writeln!(out, "{},", part);
147 }
148 let _ = writeln!(out, "): Promise<{}> {{", ret);
149 }
150 };
151
152 let param_args = if params.is_empty() {
153 String::new()
154 } else {
155 let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
156 args.join(", ")
157 };
158
159 match &analyzed.command {
160 QueryCommand::One => {
161 let _ = writeln!(out, "/** Fetch a single {} or null. */", struct_name);
162 let ret = format!("{} | null", struct_name);
163 write_fn_sig(&mut out, &func_name, &inline_params, &ret);
164 let _ = writeln!(out, "\tconst stmt = await conn.prepare(`{}`);", sql);
165 if params.is_empty() {
166 let _ = writeln!(out, "\tconst result = await stmt.run();");
167 } else {
168 let _ = writeln!(out, "\tconst result = await stmt.run({});", param_args);
169 }
170 let _ = writeln!(out, "\tconst rows = await result.getRows();");
171 let _ = writeln!(
175 out,
176 "\tconst row = rows.length > 0 ? rows[0] as unknown as {} : null;",
177 struct_name
178 );
179 let _ = writeln!(out, "\treturn row;");
180 let _ = write!(out, "}}");
181 }
182 QueryCommand::Batch => {
183 let batch_fn_name = format!("{}Batch", func_name);
184 if params.len() > 1 {
185 let params_type_name = format!("{}BatchParams", struct_name);
186 let _ = writeln!(out, "/** Params for {} batch operation. */", struct_name);
187 let _ = writeln!(out, "export interface {} {{", params_type_name);
188 for p in params {
189 let _ = writeln!(out, "\t{}: {};", p.field_name, p.full_type);
190 }
191 let _ = writeln!(out, "}}");
192 let _ = writeln!(out);
193 let _ = writeln!(
194 out,
195 "/** Execute {} for each item in the batch. */",
196 analyzed.name
197 );
198 let batch_params = format!("conn: Connection, items: {}[]", params_type_name);
199 write_fn_sig(&mut out, &batch_fn_name, &batch_params, "void");
200 let _ = writeln!(out, "\tconst stmt = await conn.prepare(`{}`);", sql);
201 let _ = writeln!(out, "\tfor (const item of items) {{");
202 let args: Vec<String> = params
203 .iter()
204 .map(|p| format!("item.{}", p.field_name))
205 .collect();
206 let _ = writeln!(out, "\t\tawait stmt.run({});", args.join(", "));
207 let _ = writeln!(out, "\t}}");
208 let _ = write!(out, "}}");
209 } else if params.len() == 1 {
210 let _ = writeln!(
211 out,
212 "/** Execute {} for each item in the batch. */",
213 analyzed.name
214 );
215 let batch_params =
216 format!("conn: Connection, items: {}[]", params[0].full_type);
217 write_fn_sig(&mut out, &batch_fn_name, &batch_params, "void");
218 let _ = writeln!(out, "\tconst stmt = await conn.prepare(`{}`);", sql);
219 let _ = writeln!(out, "\tfor (const item of items) {{");
220 let _ = writeln!(out, "\t\tawait stmt.run(item);");
221 let _ = writeln!(out, "\t}}");
222 let _ = write!(out, "}}");
223 } else {
224 let _ = writeln!(
225 out,
226 "/** Execute {} for each item in the batch. */",
227 analyzed.name
228 );
229 write_fn_sig(
230 &mut out,
231 &batch_fn_name,
232 "conn: Connection, count: number",
233 "void",
234 );
235 let _ = writeln!(out, "\tconst stmt = await conn.prepare(`{}`);", sql);
236 let _ = writeln!(out, "\tfor (let i = 0; i < count; i++) {{");
237 let _ = writeln!(out, "\t\tawait stmt.run();");
238 let _ = writeln!(out, "\t}}");
239 let _ = write!(out, "}}");
240 }
241 }
242 QueryCommand::Many => {
243 let _ = writeln!(out, "/** Fetch all {} rows. */", struct_name);
244 let ret = format!("{}[]", struct_name);
245 write_fn_sig(&mut out, &func_name, &inline_params, &ret);
246 let _ = writeln!(out, "\tconst stmt = await conn.prepare(`{}`);", sql);
247 if params.is_empty() {
248 let _ = writeln!(out, "\tconst result = await stmt.run();");
249 } else {
250 let _ = writeln!(out, "\tconst result = await stmt.run({});", param_args);
251 }
252 let _ = writeln!(
253 out,
254 "\treturn await result.getRows() as unknown as {}[];",
255 struct_name
256 );
257 let _ = write!(out, "}}");
258 }
259 QueryCommand::Exec => {
260 let _ = writeln!(out, "/** Execute a query returning no rows. */");
261 write_fn_sig(&mut out, &func_name, &inline_params, "void");
262 let _ = writeln!(out, "\tconst stmt = await conn.prepare(`{}`);", sql);
263 if params.is_empty() {
264 let _ = writeln!(out, "\tawait stmt.run();");
265 } else {
266 let _ = writeln!(out, "\tawait stmt.run({});", param_args);
267 }
268 let _ = write!(out, "}}");
269 }
270 QueryCommand::Grouped => unreachable!("Grouped is rewritten to Many before codegen"),
271 QueryCommand::ExecResult | QueryCommand::ExecRows => {
272 let _ = writeln!(
273 out,
274 "/** Execute a query and return the number of affected rows. */"
275 );
276 write_fn_sig(&mut out, &func_name, &inline_params, "number");
277 let _ = writeln!(out, "\tconst stmt = await conn.prepare(`{}`);", sql);
278 if params.is_empty() {
279 let _ = writeln!(out, "\tconst result = await stmt.run();");
280 } else {
281 let _ = writeln!(out, "\tconst result = await stmt.run({});", param_args);
282 }
283 let _ = writeln!(out, "\treturn result.rowsChanged;");
284 let _ = write!(out, "}}");
285 }
286 }
287
288 Ok(out)
289 }
290
291 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
292 let type_name = to_pascal_case(&enum_info.sql_name);
293 if self.row_type == TsRowType::Zod {
294 return Ok(super::typescript_common::generate_zod_enum(
295 &type_name,
296 &enum_info.values,
297 ));
298 }
299 let mut out = String::new();
300 let variants: Vec<String> = enum_info
301 .values
302 .iter()
303 .map(|v| format!("\"{}\"", v))
304 .collect();
305 let _ = write!(out, "export type {} = {};", type_name, variants.join(" | "));
306 Ok(out)
307 }
308
309 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
310 let name = to_pascal_case(&composite.sql_name);
311 let mut out = String::new();
312 let _ = writeln!(out, "/** Composite type {}. */", composite.sql_name);
313 let _ = writeln!(out, "export interface {} {{", name);
314 for field in &composite.fields {
315 let ts_type = resolve_type(&field.neutral_type, &self.manifest, false)
316 .map(|t| t.into_owned())
317 .map_err(|e| {
318 ScytheError::new(
319 ErrorCode::InternalError,
320 format!("composite field type error: {}", e),
321 )
322 })?;
323 let _ = writeln!(out, "\t{}: {};", to_camel_case(&field.name), ts_type);
324 }
325 let _ = write!(out, "}}");
326 Ok(out)
327 }
328
329 fn apply_options(
330 &mut self,
331 options: &std::collections::HashMap<String, String>,
332 ) -> Result<(), ScytheError> {
333 if let Some(value) = options.get("row_type") {
334 self.row_type = TsRowType::from_option(value)?;
335 }
336 Ok(())
337 }
338}