// scythe_codegen/backends/typescript_duckdb.rs — TypeScript (DuckDB) code generation backend.

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{fn_name, row_struct_name, to_camel_case, to_pascal_case};
6use scythe_backend::types::resolve_type;
7
8use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
9use scythe_core::errors::{ErrorCode, ScytheError};
10use scythe_core::parser::QueryCommand;
11
12use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
13use crate::backends::typescript_common::{TsRowType, generate_zod_row_struct};
14use crate::singularize;
15
16const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/typescript-duckdb.toml");
17
18pub struct TypescriptDuckdbBackend {
19    manifest: BackendManifest,
20    row_type: TsRowType,
21}
22
23impl TypescriptDuckdbBackend {
24    pub fn new(engine: &str) -> Result<Self, ScytheError> {
25        match engine {
26            "duckdb" => {}
27            _ => {
28                return Err(ScytheError::new(
29                    ErrorCode::InternalError,
30                    format!(
31                        "typescript-duckdb only supports DuckDB, got engine '{}'",
32                        engine
33                    ),
34                ));
35            }
36        }
37        let manifest_path = Path::new("backends/typescript-duckdb/manifest.toml");
38        let manifest = if manifest_path.exists() {
39            load_manifest(manifest_path)
40                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
41        } else {
42            toml::from_str(DEFAULT_MANIFEST_TOML)
43                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
44        };
45        Ok(Self {
46            manifest,
47            row_type: TsRowType::default(),
48        })
49    }
50}
51
52impl CodegenBackend for TypescriptDuckdbBackend {
53    fn name(&self) -> &str {
54        "typescript-duckdb"
55    }
56
57    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
58        &self.manifest
59    }
60
61    fn supported_engines(&self) -> &[&str] {
62        &["duckdb"]
63    }
64
65    fn file_header(&self) -> String {
66        let mut header =
67            "/** Auto-generated by scythe. Do not edit. */\n\nimport type { Connection } from \"@duckdb/node-api\";\n"
68                .to_string();
69        if self.row_type == TsRowType::Zod {
70            header.push_str("import { z } from \"zod\";\n");
71        }
72        header
73    }
74
75    fn generate_row_struct(
76        &self,
77        query_name: &str,
78        columns: &[ResolvedColumn],
79    ) -> Result<String, ScytheError> {
80        let struct_name = row_struct_name(query_name, &self.manifest.naming);
81        if self.row_type == TsRowType::Zod {
82            return Ok(generate_zod_row_struct(&struct_name, query_name, columns));
83        }
84        let mut out = String::new();
85        let _ = writeln!(out, "/** Row type for {} queries. */", query_name);
86        let _ = writeln!(out, "export interface {} {{", struct_name);
87        for col in columns {
88            let _ = writeln!(out, "\t{}: {};", col.field_name, col.full_type);
89        }
90        let _ = write!(out, "}}");
91        Ok(out)
92    }
93
94    fn generate_model_struct(
95        &self,
96        table_name: &str,
97        columns: &[ResolvedColumn],
98    ) -> Result<String, ScytheError> {
99        let singular = singularize(table_name);
100        let name = to_pascal_case(&singular);
101        self.generate_row_struct(&name, columns)
102    }
103
104    fn generate_query_fn(
105        &self,
106        analyzed: &AnalyzedQuery,
107        struct_name: &str,
108        _columns: &[ResolvedColumn],
109        params: &[ResolvedParam],
110    ) -> Result<String, ScytheError> {
111        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
112        let mut out = String::new();
113
114        let param_list = params
115            .iter()
116            .map(|p| format!("{}: {}", p.field_name, p.full_type))
117            .collect::<Vec<_>>()
118            .join(", ");
119
120        let sql = super::clean_sql_with_optional(
121            &analyzed.sql,
122            &analyzed.optional_params,
123            &analyzed.params,
124        );
125
126        let inline_params = if params.is_empty() {
127            "conn: Connection".to_string()
128        } else {
129            format!("conn: Connection, {}", param_list)
130        };
131
132        let write_fn_sig = |out: &mut String, name: &str, params_inline: &str, ret: &str| {
133            let oneliner = format!(
134                "export async function {}({}): Promise<{}> {{",
135                name, params_inline, ret
136            );
137            if oneliner.len() <= 80 {
138                let _ = writeln!(out, "{}", oneliner);
139            } else {
140                let mut parts = vec!["\tconn: Connection".to_string()];
141                for p in params {
142                    parts.push(format!("\t{}: {}", p.field_name, p.full_type));
143                }
144                let _ = writeln!(out, "export async function {}(", name);
145                for part in &parts {
146                    let _ = writeln!(out, "{},", part);
147                }
148                let _ = writeln!(out, "): Promise<{}> {{", ret);
149            }
150        };
151
152        let param_args = if params.is_empty() {
153            String::new()
154        } else {
155            let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
156            args.join(", ")
157        };
158
159        match &analyzed.command {
160            QueryCommand::One => {
161                let _ = writeln!(out, "/** Fetch a single {} or null. */", struct_name);
162                let ret = format!("{} | null", struct_name);
163                write_fn_sig(&mut out, &func_name, &inline_params, &ret);
164                let _ = writeln!(out, "\tconst stmt = await conn.prepare(`{}`);", sql);
165                if params.is_empty() {
166                    let _ = writeln!(out, "\tconst result = await stmt.run();");
167                } else {
168                    let _ = writeln!(out, "\tconst result = await stmt.run({});", param_args);
169                }
170                let _ = writeln!(out, "\tconst rows = await result.getRows();");
171                // TODO: `as unknown as T` is the standard DuckDB Node.js pattern but lacks
172                // per-column mapping/validation. A future improvement could map columns by
173                // name to struct fields for type safety.
174                let _ = writeln!(
175                    out,
176                    "\tconst row = rows.length > 0 ? rows[0] as unknown as {} : null;",
177                    struct_name
178                );
179                let _ = writeln!(out, "\treturn row;");
180                let _ = write!(out, "}}");
181            }
182            QueryCommand::Batch => {
183                let batch_fn_name = format!("{}Batch", func_name);
184                if params.len() > 1 {
185                    let params_type_name = format!("{}BatchParams", struct_name);
186                    let _ = writeln!(out, "/** Params for {} batch operation. */", struct_name);
187                    let _ = writeln!(out, "export interface {} {{", params_type_name);
188                    for p in params {
189                        let _ = writeln!(out, "\t{}: {};", p.field_name, p.full_type);
190                    }
191                    let _ = writeln!(out, "}}");
192                    let _ = writeln!(out);
193                    let _ = writeln!(
194                        out,
195                        "/** Execute {} for each item in the batch. */",
196                        analyzed.name
197                    );
198                    let batch_params = format!("conn: Connection, items: {}[]", params_type_name);
199                    write_fn_sig(&mut out, &batch_fn_name, &batch_params, "void");
200                    let _ = writeln!(out, "\tconst stmt = await conn.prepare(`{}`);", sql);
201                    let _ = writeln!(out, "\tfor (const item of items) {{");
202                    let args: Vec<String> = params
203                        .iter()
204                        .map(|p| format!("item.{}", p.field_name))
205                        .collect();
206                    let _ = writeln!(out, "\t\tawait stmt.run({});", args.join(", "));
207                    let _ = writeln!(out, "\t}}");
208                    let _ = write!(out, "}}");
209                } else if params.len() == 1 {
210                    let _ = writeln!(
211                        out,
212                        "/** Execute {} for each item in the batch. */",
213                        analyzed.name
214                    );
215                    let batch_params =
216                        format!("conn: Connection, items: {}[]", params[0].full_type);
217                    write_fn_sig(&mut out, &batch_fn_name, &batch_params, "void");
218                    let _ = writeln!(out, "\tconst stmt = await conn.prepare(`{}`);", sql);
219                    let _ = writeln!(out, "\tfor (const item of items) {{");
220                    let _ = writeln!(out, "\t\tawait stmt.run(item);");
221                    let _ = writeln!(out, "\t}}");
222                    let _ = write!(out, "}}");
223                } else {
224                    let _ = writeln!(
225                        out,
226                        "/** Execute {} for each item in the batch. */",
227                        analyzed.name
228                    );
229                    write_fn_sig(
230                        &mut out,
231                        &batch_fn_name,
232                        "conn: Connection, count: number",
233                        "void",
234                    );
235                    let _ = writeln!(out, "\tconst stmt = await conn.prepare(`{}`);", sql);
236                    let _ = writeln!(out, "\tfor (let i = 0; i < count; i++) {{");
237                    let _ = writeln!(out, "\t\tawait stmt.run();");
238                    let _ = writeln!(out, "\t}}");
239                    let _ = write!(out, "}}");
240                }
241            }
242            QueryCommand::Many => {
243                let _ = writeln!(out, "/** Fetch all {} rows. */", struct_name);
244                let ret = format!("{}[]", struct_name);
245                write_fn_sig(&mut out, &func_name, &inline_params, &ret);
246                let _ = writeln!(out, "\tconst stmt = await conn.prepare(`{}`);", sql);
247                if params.is_empty() {
248                    let _ = writeln!(out, "\tconst result = await stmt.run();");
249                } else {
250                    let _ = writeln!(out, "\tconst result = await stmt.run({});", param_args);
251                }
252                let _ = writeln!(
253                    out,
254                    "\treturn await result.getRows() as unknown as {}[];",
255                    struct_name
256                );
257                let _ = write!(out, "}}");
258            }
259            QueryCommand::Exec => {
260                let _ = writeln!(out, "/** Execute a query returning no rows. */");
261                write_fn_sig(&mut out, &func_name, &inline_params, "void");
262                let _ = writeln!(out, "\tconst stmt = await conn.prepare(`{}`);", sql);
263                if params.is_empty() {
264                    let _ = writeln!(out, "\tawait stmt.run();");
265                } else {
266                    let _ = writeln!(out, "\tawait stmt.run({});", param_args);
267                }
268                let _ = write!(out, "}}");
269            }
270            QueryCommand::Grouped => unreachable!("Grouped is rewritten to Many before codegen"),
271            QueryCommand::ExecResult | QueryCommand::ExecRows => {
272                let _ = writeln!(
273                    out,
274                    "/** Execute a query and return the number of affected rows. */"
275                );
276                write_fn_sig(&mut out, &func_name, &inline_params, "number");
277                let _ = writeln!(out, "\tconst stmt = await conn.prepare(`{}`);", sql);
278                if params.is_empty() {
279                    let _ = writeln!(out, "\tconst result = await stmt.run();");
280                } else {
281                    let _ = writeln!(out, "\tconst result = await stmt.run({});", param_args);
282                }
283                let _ = writeln!(out, "\treturn result.rowsChanged;");
284                let _ = write!(out, "}}");
285            }
286        }
287
288        Ok(out)
289    }
290
291    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
292        let type_name = to_pascal_case(&enum_info.sql_name);
293        if self.row_type == TsRowType::Zod {
294            return Ok(super::typescript_common::generate_zod_enum(
295                &type_name,
296                &enum_info.values,
297            ));
298        }
299        let mut out = String::new();
300        let variants: Vec<String> = enum_info
301            .values
302            .iter()
303            .map(|v| format!("\"{}\"", v))
304            .collect();
305        let _ = write!(out, "export type {} = {};", type_name, variants.join(" | "));
306        Ok(out)
307    }
308
309    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
310        let name = to_pascal_case(&composite.sql_name);
311        let mut out = String::new();
312        let _ = writeln!(out, "/** Composite type {}. */", composite.sql_name);
313        let _ = writeln!(out, "export interface {} {{", name);
314        for field in &composite.fields {
315            let ts_type = resolve_type(&field.neutral_type, &self.manifest, false)
316                .map(|t| t.into_owned())
317                .map_err(|e| {
318                    ScytheError::new(
319                        ErrorCode::InternalError,
320                        format!("composite field type error: {}", e),
321                    )
322                })?;
323            let _ = writeln!(out, "\t{}: {};", to_camel_case(&field.name), ts_type);
324        }
325        let _ = write!(out, "}}");
326        Ok(out)
327    }
328
329    fn apply_options(
330        &mut self,
331        options: &std::collections::HashMap<String, String>,
332    ) -> Result<(), ScytheError> {
333        if let Some(value) = options.get("row_type") {
334            self.row_type = TsRowType::from_option(value)?;
335        }
336        Ok(())
337    }
338}