Skip to main content

scythe_codegen/backends/
python_duckdb.rs

1use std::collections::HashMap;
2use std::fmt::Write;
3use std::path::Path;
4
5use scythe_backend::manifest::{BackendManifest, load_manifest};
6use scythe_backend::naming::{
7    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
8};
9use scythe_backend::types::resolve_type;
10
11use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
12use scythe_core::errors::{ErrorCode, ScytheError};
13use scythe_core::parser::QueryCommand;
14
15use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
16use crate::singularize;
17
18use super::python_common::PythonRowType;
19
20const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/python-duckdb.toml");
21
22pub struct PythonDuckdbBackend {
23    manifest: BackendManifest,
24    row_type: PythonRowType,
25}
26
27impl PythonDuckdbBackend {
28    pub fn new(engine: &str) -> Result<Self, ScytheError> {
29        match engine {
30            "duckdb" => {}
31            _ => {
32                return Err(ScytheError::new(
33                    ErrorCode::InternalError,
34                    format!(
35                        "python-duckdb only supports DuckDB, got engine '{}'",
36                        engine
37                    ),
38                ));
39            }
40        }
41        let manifest_path = Path::new("backends/python-duckdb/manifest.toml");
42        let manifest = if manifest_path.exists() {
43            load_manifest(manifest_path)
44                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
45        } else {
46            toml::from_str(DEFAULT_MANIFEST_TOML)
47                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
48        };
49        Ok(Self {
50            manifest,
51            row_type: PythonRowType::default(),
52        })
53    }
54}
55
/// Rewrite `$1`, `$2`, ... positional params to `?` for DuckDB.
///
/// Uses a char-by-char scan (like `pg_to_jdbc_params` in kotlin_exposed) so
/// that any number of parameters is handled and quoted text is left alone:
///
/// * single-quoted string literals (`'...'`, with `''` as an escaped quote)
/// * double-quoted identifiers (`"..."`, with `""` as an escaped quote)
///
/// Both are copied through verbatim, so a `$N` inside them is never rewritten
/// and an apostrophe inside a quoted identifier is not mistaken for the start
/// of a string literal. An unterminated quote runs to the end of the input.
///
/// A `$` that is not immediately followed by an ASCII digit is kept as-is.
///
/// Note: the digits after `$` are discarded, so every placeholder becomes an
/// anonymous `?` in order of appearance; the original parameter numbers are
/// not preserved.
fn rewrite_params_to_qmark(sql: &str) -> String {
    let mut result = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();
    while let Some(ch) = chars.next() {
        match ch {
            quote @ ('\'' | '"') => {
                // Copy the whole quoted region, including its delimiters.
                result.push(quote);
                while let Some(inner) = chars.next() {
                    result.push(inner);
                    if inner == quote {
                        // A doubled delimiter is an escaped quote, not the end.
                        if chars.next_if_eq(&quote).is_some() {
                            result.push(quote);
                        } else {
                            break;
                        }
                    }
                }
            }
            '$' if chars.peek().is_some_and(|c| c.is_ascii_digit()) => {
                // Swallow the parameter number and emit a single placeholder.
                while chars.next_if(|c| c.is_ascii_digit()).is_some() {}
                result.push('?');
            }
            other => result.push(other),
        }
    }
    result
}
95
96impl CodegenBackend for PythonDuckdbBackend {
97    fn name(&self) -> &str {
98        "python-duckdb"
99    }
100
101    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
102        &self.manifest
103    }
104
105    fn supported_engines(&self) -> &[&str] {
106        &["duckdb"]
107    }
108
109    fn apply_options(&mut self, options: &HashMap<String, String>) -> Result<(), ScytheError> {
110        if let Some(rt) = options.get("row_type") {
111            self.row_type = PythonRowType::from_option(rt)?;
112        }
113        Ok(())
114    }
115
116    fn file_header(&self) -> String {
117        let import_line = self.row_type.import_line();
118        if self.row_type.is_stdlib_import() {
119            format!(
120                "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
121                 \n\
122                 import datetime  # noqa: F401\n\
123                 import decimal  # noqa: F401\n\
124                 {import_line}\n\
125                 from enum import Enum  # noqa: F401\n\
126                 \n\
127                 import duckdb  # noqa: F401\n\
128                 \n",
129            )
130        } else {
131            let third_party = self
132                .row_type
133                .sorted_third_party_imports("import duckdb  # noqa: F401");
134            format!(
135                "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
136                 \n\
137                 import datetime  # noqa: F401\n\
138                 import decimal  # noqa: F401\n\
139                 from enum import Enum  # noqa: F401\n\
140                 \n\
141                 {third_party}\n\
142                 \n",
143            )
144        }
145    }
146
147    fn generate_row_struct(
148        &self,
149        query_name: &str,
150        columns: &[ResolvedColumn],
151    ) -> Result<String, ScytheError> {
152        let struct_name = row_struct_name(query_name, &self.manifest.naming);
153        let mut out = String::new();
154        let _ = write!(out, "{}", self.row_type.decorator());
155        let _ = writeln!(out, "{}", self.row_type.class_def(&struct_name));
156        let _ = writeln!(out, "    \"\"\"Row type for {} query.\"\"\"", query_name);
157        if columns.is_empty() {
158            let _ = writeln!(out, "    pass");
159        } else {
160            let _ = writeln!(out);
161            for col in columns {
162                let _ = writeln!(out, "    {}: {}", col.field_name, col.full_type);
163            }
164        }
165        Ok(out)
166    }
167
168    fn generate_model_struct(
169        &self,
170        table_name: &str,
171        columns: &[ResolvedColumn],
172    ) -> Result<String, ScytheError> {
173        let singular = singularize(table_name);
174        let name = to_pascal_case(&singular);
175        self.generate_row_struct(&name, columns)
176    }
177
178    fn generate_query_fn(
179        &self,
180        analyzed: &AnalyzedQuery,
181        struct_name: &str,
182        columns: &[ResolvedColumn],
183        params: &[ResolvedParam],
184    ) -> Result<String, ScytheError> {
185        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
186        let mut out = String::new();
187
188        let param_list = params
189            .iter()
190            .map(|p| format!("{}: {}", p.field_name, p.full_type))
191            .collect::<Vec<_>>()
192            .join(", ");
193        let kw_sep = if param_list.is_empty() { "" } else { ", *, " };
194
195        let sql = rewrite_params_to_qmark(&super::clean_sql_with_optional(
196            &analyzed.sql,
197            &analyzed.optional_params,
198            &analyzed.params,
199        ));
200
201        let args_list = if params.is_empty() {
202            String::new()
203        } else {
204            let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
205            format!("[{}]", args.join(", "))
206        };
207
208        match &analyzed.command {
209            QueryCommand::One => {
210                let _ = writeln!(
211                    out,
212                    "def {}(conn: duckdb.DuckDBPyConnection{}{}) -> {} | None:",
213                    func_name, kw_sep, param_list, struct_name
214                );
215                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
216                if params.is_empty() {
217                    let _ = writeln!(
218                        out,
219                        "    row = conn.execute(\"\"\"{}\"\"\").fetchone()",
220                        sql
221                    );
222                } else {
223                    let _ = writeln!(
224                        out,
225                        "    row = conn.execute(\"\"\"{}\"\"\", {}).fetchone()",
226                        sql, args_list
227                    );
228                }
229                let _ = writeln!(out, "    if row is None:");
230                let _ = writeln!(out, "        return None");
231                let field_assignments: Vec<String> = columns
232                    .iter()
233                    .enumerate()
234                    .map(|(i, col)| format!("{}=row[{}]", col.field_name, i))
235                    .collect();
236                let oneliner = format!(
237                    "    return {}({})",
238                    struct_name,
239                    field_assignments.join(", ")
240                );
241                if oneliner.len() <= 88 {
242                    let _ = writeln!(out, "{}", oneliner);
243                } else {
244                    let _ = writeln!(out, "    return {}(", struct_name);
245                    for fa in &field_assignments {
246                        let _ = writeln!(out, "        {},", fa);
247                    }
248                    let _ = writeln!(out, "    )");
249                }
250            }
251            QueryCommand::Batch => {
252                let batch_fn_name = format!("{}_batch", func_name);
253                let items_type = if params.len() > 1 {
254                    let tuple_types: Vec<String> =
255                        params.iter().map(|p| p.full_type.clone()).collect();
256                    format!("list[tuple[{}]]", tuple_types.join(", "))
257                } else if params.len() == 1 {
258                    format!("list[{}]", params[0].full_type)
259                } else {
260                    "int".to_string()
261                };
262                let _ = writeln!(
263                    out,
264                    "def {}(conn: duckdb.DuckDBPyConnection, *, items: {}) -> None:",
265                    batch_fn_name, items_type
266                );
267                let _ = writeln!(
268                    out,
269                    "    \"\"\"Execute {} query for each item in the batch.\"\"\"",
270                    analyzed.name
271                );
272                if params.is_empty() {
273                    let _ = writeln!(out, "    for _ in range(items):");
274                    let _ = writeln!(out, "        conn.execute(\"\"\"{}\"\"\")", sql);
275                } else if params.len() == 1 {
276                    let _ = writeln!(
277                        out,
278                        "    conn.executemany(\"\"\"{}\"\"\", [[item] for item in items])",
279                        sql
280                    );
281                } else {
282                    let _ = writeln!(
283                        out,
284                        "    conn.executemany(\"\"\"{}\"\"\", [list(item) for item in items])",
285                        sql
286                    );
287                }
288            }
289            QueryCommand::Many => {
290                let _ = writeln!(
291                    out,
292                    "def {}(conn: duckdb.DuckDBPyConnection{}{}) -> list[{}]:",
293                    func_name, kw_sep, param_list, struct_name
294                );
295                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
296                if params.is_empty() {
297                    let _ = writeln!(
298                        out,
299                        "    rows = conn.execute(\"\"\"{}\"\"\").fetchall()",
300                        sql
301                    );
302                } else {
303                    let _ = writeln!(
304                        out,
305                        "    rows = conn.execute(\"\"\"{}\"\"\", {}).fetchall()",
306                        sql, args_list
307                    );
308                }
309                let field_assignments: Vec<String> = columns
310                    .iter()
311                    .enumerate()
312                    .map(|(i, col)| format!("{}=r[{}]", col.field_name, i))
313                    .collect();
314                let oneliner = format!(
315                    "    return [{}({}) for r in rows]",
316                    struct_name,
317                    field_assignments.join(", ")
318                );
319                if oneliner.len() <= 88 {
320                    let _ = writeln!(out, "{}", oneliner);
321                } else {
322                    let _ = writeln!(out, "    return [");
323                    let _ = writeln!(out, "        {}(", struct_name);
324                    for fa in &field_assignments {
325                        let _ = writeln!(out, "            {},", fa);
326                    }
327                    let _ = writeln!(out, "        )");
328                    let _ = writeln!(out, "        for r in rows");
329                    let _ = writeln!(out, "    ]");
330                }
331            }
332            QueryCommand::Exec => {
333                let _ = writeln!(
334                    out,
335                    "def {}(conn: duckdb.DuckDBPyConnection{}{}) -> None:",
336                    func_name, kw_sep, param_list
337                );
338                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
339                if params.is_empty() {
340                    let _ = writeln!(out, "    conn.execute(\"\"\"{}\"\"\")", sql);
341                } else {
342                    let _ = writeln!(out, "    conn.execute(\"\"\"{}\"\"\", {})", sql, args_list);
343                }
344            }
345            QueryCommand::Grouped => unreachable!("Grouped is rewritten to Many before codegen"),
346            QueryCommand::ExecResult | QueryCommand::ExecRows => {
347                let _ = writeln!(
348                    out,
349                    "def {}(conn: duckdb.DuckDBPyConnection{}{}) -> int:",
350                    func_name, kw_sep, param_list
351                );
352                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
353                if params.is_empty() {
354                    let _ = writeln!(out, "    result = conn.execute(\"\"\"{}\"\"\")", sql);
355                } else {
356                    let _ = writeln!(
357                        out,
358                        "    result = conn.execute(\"\"\"{}\"\"\", {})",
359                        sql, args_list
360                    );
361                }
362                let _ = writeln!(
363                    out,
364                    "    row = result.fetchone()\n    return row[0] if row else 0"
365                );
366            }
367        }
368
369        Ok(out)
370    }
371
372    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
373        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
374        let mut out = String::new();
375        let _ = writeln!(out, "class {}(str, Enum):", type_name);
376        let _ = writeln!(
377            out,
378            "    \"\"\"Database enum type {}.\"\"\"",
379            enum_info.sql_name
380        );
381        if enum_info.values.is_empty() {
382            let _ = writeln!(out, "    pass");
383        } else {
384            let _ = writeln!(out);
385            for value in &enum_info.values {
386                let variant = enum_variant_name(value, &self.manifest.naming);
387                let _ = writeln!(out, "    {} = \"{}\"", variant, value);
388            }
389        }
390        Ok(out)
391    }
392
393    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
394        let name = to_pascal_case(&composite.sql_name);
395        let mut out = String::new();
396        let _ = write!(out, "{}", self.row_type.decorator());
397        let _ = writeln!(out, "{}", self.row_type.class_def(&name));
398        let _ = writeln!(
399            out,
400            "    \"\"\"Composite type {}.\"\"\"",
401            composite.sql_name
402        );
403        if composite.fields.is_empty() {
404            let _ = writeln!(out, "    pass");
405        } else {
406            let _ = writeln!(out);
407            for field in &composite.fields {
408                let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
409                    .map(|t| t.into_owned())
410                    .map_err(|e| {
411                        ScytheError::new(
412                            ErrorCode::InternalError,
413                            format!("composite field type error: {}", e),
414                        )
415                    })?;
416                let _ = writeln!(out, "    {}: {}", to_snake_case(&field.name), py_type);
417            }
418        }
419        Ok(out)
420    }
421}