1use std::collections::HashMap;
2use std::fmt::Write;
3use std::path::Path;
4
5use scythe_backend::manifest::{BackendManifest, load_manifest};
6use scythe_backend::naming::{
7 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
8};
9use scythe_backend::types::resolve_type;
10
11use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
12use scythe_core::errors::{ErrorCode, ScytheError};
13use scythe_core::parser::QueryCommand;
14
15use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
16use crate::singularize;
17
18use super::python_common::PythonRowType;
19
20const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/python-duckdb.toml");
21
22pub struct PythonDuckdbBackend {
23 manifest: BackendManifest,
24 row_type: PythonRowType,
25}
26
27impl PythonDuckdbBackend {
28 pub fn new(engine: &str) -> Result<Self, ScytheError> {
29 match engine {
30 "duckdb" => {}
31 _ => {
32 return Err(ScytheError::new(
33 ErrorCode::InternalError,
34 format!(
35 "python-duckdb only supports DuckDB, got engine '{}'",
36 engine
37 ),
38 ));
39 }
40 }
41 let manifest_path = Path::new("backends/python-duckdb/manifest.toml");
42 let manifest = if manifest_path.exists() {
43 load_manifest(manifest_path)
44 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
45 } else {
46 toml::from_str(DEFAULT_MANIFEST_TOML)
47 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
48 };
49 Ok(Self {
50 manifest,
51 row_type: PythonRowType::default(),
52 })
53 }
54}
55
/// Rewrites numbered placeholders (`$1`, `$2`, ...) to `?` placeholders.
///
/// Quoted regions are copied through verbatim so that a `$N` inside them is
/// never mistaken for a parameter:
///
/// * single-quoted string literals (`'...'`, with `''` as the escaped quote),
/// * double-quoted identifiers (`"..."`, with `""` as the escaped quote).
///
/// A `$` that is not immediately followed by an ASCII digit is kept as-is.
/// An unterminated quoted region is copied through to the end of the input.
///
/// Placeholders are replaced in order of appearance and the numeric index is
/// discarded, so the caller's argument list must match the textual order of
/// the placeholders.
fn rewrite_params_to_qmark(sql: &str) -> String {
    let mut result = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();

    while let Some(ch) = chars.next() {
        match ch {
            // String literal or quoted identifier: copy until the matching
            // closing quote, treating a doubled quote as an escaped quote.
            '\'' | '"' => {
                result.push(ch);
                while let Some(inner) = chars.next() {
                    result.push(inner);
                    if inner != ch {
                        continue;
                    }
                    match chars.next_if_eq(&ch) {
                        Some(escaped) => result.push(escaped),
                        None => break,
                    }
                }
            }
            // `$` followed by digits: swallow the digits, emit one `?`.
            '$' if chars.peek().is_some_and(|c| c.is_ascii_digit()) => {
                while chars.next_if(|c| c.is_ascii_digit()).is_some() {}
                result.push('?');
            }
            _ => result.push(ch),
        }
    }

    result
}
95
96impl CodegenBackend for PythonDuckdbBackend {
97 fn name(&self) -> &str {
98 "python-duckdb"
99 }
100
101 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
102 &self.manifest
103 }
104
105 fn supported_engines(&self) -> &[&str] {
106 &["duckdb"]
107 }
108
109 fn apply_options(&mut self, options: &HashMap<String, String>) -> Result<(), ScytheError> {
110 if let Some(rt) = options.get("row_type") {
111 self.row_type = PythonRowType::from_option(rt)?;
112 }
113 Ok(())
114 }
115
116 fn file_header(&self) -> String {
117 let import_line = self.row_type.import_line();
118 if self.row_type.is_stdlib_import() {
119 format!(
120 "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
121 \n\
122 import datetime # noqa: F401\n\
123 import decimal # noqa: F401\n\
124 {import_line}\n\
125 from enum import Enum # noqa: F401\n\
126 \n\
127 import duckdb # noqa: F401\n\
128 \n",
129 )
130 } else {
131 let third_party = self
132 .row_type
133 .sorted_third_party_imports("import duckdb # noqa: F401");
134 format!(
135 "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
136 \n\
137 import datetime # noqa: F401\n\
138 import decimal # noqa: F401\n\
139 from enum import Enum # noqa: F401\n\
140 \n\
141 {third_party}\n\
142 \n",
143 )
144 }
145 }
146
147 fn generate_row_struct(
148 &self,
149 query_name: &str,
150 columns: &[ResolvedColumn],
151 ) -> Result<String, ScytheError> {
152 let struct_name = row_struct_name(query_name, &self.manifest.naming);
153 let mut out = String::new();
154 let _ = write!(out, "{}", self.row_type.decorator());
155 let _ = writeln!(out, "{}", self.row_type.class_def(&struct_name));
156 let _ = writeln!(out, " \"\"\"Row type for {} query.\"\"\"", query_name);
157 if columns.is_empty() {
158 let _ = writeln!(out, " pass");
159 } else {
160 let _ = writeln!(out);
161 for col in columns {
162 let _ = writeln!(out, " {}: {}", col.field_name, col.full_type);
163 }
164 }
165 Ok(out)
166 }
167
168 fn generate_model_struct(
169 &self,
170 table_name: &str,
171 columns: &[ResolvedColumn],
172 ) -> Result<String, ScytheError> {
173 let singular = singularize(table_name);
174 let name = to_pascal_case(&singular);
175 self.generate_row_struct(&name, columns)
176 }
177
178 fn generate_query_fn(
179 &self,
180 analyzed: &AnalyzedQuery,
181 struct_name: &str,
182 columns: &[ResolvedColumn],
183 params: &[ResolvedParam],
184 ) -> Result<String, ScytheError> {
185 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
186 let mut out = String::new();
187
188 let param_list = params
189 .iter()
190 .map(|p| format!("{}: {}", p.field_name, p.full_type))
191 .collect::<Vec<_>>()
192 .join(", ");
193 let kw_sep = if param_list.is_empty() { "" } else { ", *, " };
194
195 let sql = rewrite_params_to_qmark(&super::clean_sql_with_optional(
196 &analyzed.sql,
197 &analyzed.optional_params,
198 &analyzed.params,
199 ));
200
201 let args_list = if params.is_empty() {
202 String::new()
203 } else {
204 let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
205 format!("[{}]", args.join(", "))
206 };
207
208 match &analyzed.command {
209 QueryCommand::One => {
210 let _ = writeln!(
211 out,
212 "def {}(conn: duckdb.DuckDBPyConnection{}{}) -> {} | None:",
213 func_name, kw_sep, param_list, struct_name
214 );
215 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
216 if params.is_empty() {
217 let _ = writeln!(
218 out,
219 " row = conn.execute(\"\"\"{}\"\"\").fetchone()",
220 sql
221 );
222 } else {
223 let _ = writeln!(
224 out,
225 " row = conn.execute(\"\"\"{}\"\"\", {}).fetchone()",
226 sql, args_list
227 );
228 }
229 let _ = writeln!(out, " if row is None:");
230 let _ = writeln!(out, " return None");
231 let field_assignments: Vec<String> = columns
232 .iter()
233 .enumerate()
234 .map(|(i, col)| format!("{}=row[{}]", col.field_name, i))
235 .collect();
236 let oneliner = format!(
237 " return {}({})",
238 struct_name,
239 field_assignments.join(", ")
240 );
241 if oneliner.len() <= 88 {
242 let _ = writeln!(out, "{}", oneliner);
243 } else {
244 let _ = writeln!(out, " return {}(", struct_name);
245 for fa in &field_assignments {
246 let _ = writeln!(out, " {},", fa);
247 }
248 let _ = writeln!(out, " )");
249 }
250 }
251 QueryCommand::Batch => {
252 let batch_fn_name = format!("{}_batch", func_name);
253 let items_type = if params.len() > 1 {
254 let tuple_types: Vec<String> =
255 params.iter().map(|p| p.full_type.clone()).collect();
256 format!("list[tuple[{}]]", tuple_types.join(", "))
257 } else if params.len() == 1 {
258 format!("list[{}]", params[0].full_type)
259 } else {
260 "int".to_string()
261 };
262 let _ = writeln!(
263 out,
264 "def {}(conn: duckdb.DuckDBPyConnection, *, items: {}) -> None:",
265 batch_fn_name, items_type
266 );
267 let _ = writeln!(
268 out,
269 " \"\"\"Execute {} query for each item in the batch.\"\"\"",
270 analyzed.name
271 );
272 if params.is_empty() {
273 let _ = writeln!(out, " for _ in range(items):");
274 let _ = writeln!(out, " conn.execute(\"\"\"{}\"\"\")", sql);
275 } else if params.len() == 1 {
276 let _ = writeln!(
277 out,
278 " conn.executemany(\"\"\"{}\"\"\", [[item] for item in items])",
279 sql
280 );
281 } else {
282 let _ = writeln!(
283 out,
284 " conn.executemany(\"\"\"{}\"\"\", [list(item) for item in items])",
285 sql
286 );
287 }
288 }
289 QueryCommand::Many => {
290 let _ = writeln!(
291 out,
292 "def {}(conn: duckdb.DuckDBPyConnection{}{}) -> list[{}]:",
293 func_name, kw_sep, param_list, struct_name
294 );
295 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
296 if params.is_empty() {
297 let _ = writeln!(
298 out,
299 " rows = conn.execute(\"\"\"{}\"\"\").fetchall()",
300 sql
301 );
302 } else {
303 let _ = writeln!(
304 out,
305 " rows = conn.execute(\"\"\"{}\"\"\", {}).fetchall()",
306 sql, args_list
307 );
308 }
309 let field_assignments: Vec<String> = columns
310 .iter()
311 .enumerate()
312 .map(|(i, col)| format!("{}=r[{}]", col.field_name, i))
313 .collect();
314 let oneliner = format!(
315 " return [{}({}) for r in rows]",
316 struct_name,
317 field_assignments.join(", ")
318 );
319 if oneliner.len() <= 88 {
320 let _ = writeln!(out, "{}", oneliner);
321 } else {
322 let _ = writeln!(out, " return [");
323 let _ = writeln!(out, " {}(", struct_name);
324 for fa in &field_assignments {
325 let _ = writeln!(out, " {},", fa);
326 }
327 let _ = writeln!(out, " )");
328 let _ = writeln!(out, " for r in rows");
329 let _ = writeln!(out, " ]");
330 }
331 }
332 QueryCommand::Exec => {
333 let _ = writeln!(
334 out,
335 "def {}(conn: duckdb.DuckDBPyConnection{}{}) -> None:",
336 func_name, kw_sep, param_list
337 );
338 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
339 if params.is_empty() {
340 let _ = writeln!(out, " conn.execute(\"\"\"{}\"\"\")", sql);
341 } else {
342 let _ = writeln!(out, " conn.execute(\"\"\"{}\"\"\", {})", sql, args_list);
343 }
344 }
345 QueryCommand::Grouped => unreachable!("Grouped is rewritten to Many before codegen"),
346 QueryCommand::ExecResult | QueryCommand::ExecRows => {
347 let _ = writeln!(
348 out,
349 "def {}(conn: duckdb.DuckDBPyConnection{}{}) -> int:",
350 func_name, kw_sep, param_list
351 );
352 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
353 if params.is_empty() {
354 let _ = writeln!(out, " result = conn.execute(\"\"\"{}\"\"\")", sql);
355 } else {
356 let _ = writeln!(
357 out,
358 " result = conn.execute(\"\"\"{}\"\"\", {})",
359 sql, args_list
360 );
361 }
362 let _ = writeln!(
363 out,
364 " row = result.fetchone()\n return row[0] if row else 0"
365 );
366 }
367 }
368
369 Ok(out)
370 }
371
372 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
373 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
374 let mut out = String::new();
375 let _ = writeln!(out, "class {}(str, Enum):", type_name);
376 let _ = writeln!(
377 out,
378 " \"\"\"Database enum type {}.\"\"\"",
379 enum_info.sql_name
380 );
381 if enum_info.values.is_empty() {
382 let _ = writeln!(out, " pass");
383 } else {
384 let _ = writeln!(out);
385 for value in &enum_info.values {
386 let variant = enum_variant_name(value, &self.manifest.naming);
387 let _ = writeln!(out, " {} = \"{}\"", variant, value);
388 }
389 }
390 Ok(out)
391 }
392
393 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
394 let name = to_pascal_case(&composite.sql_name);
395 let mut out = String::new();
396 let _ = write!(out, "{}", self.row_type.decorator());
397 let _ = writeln!(out, "{}", self.row_type.class_def(&name));
398 let _ = writeln!(
399 out,
400 " \"\"\"Composite type {}.\"\"\"",
401 composite.sql_name
402 );
403 if composite.fields.is_empty() {
404 let _ = writeln!(out, " pass");
405 } else {
406 let _ = writeln!(out);
407 for field in &composite.fields {
408 let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
409 .map(|t| t.into_owned())
410 .map_err(|e| {
411 ScytheError::new(
412 ErrorCode::InternalError,
413 format!("composite field type error: {}", e),
414 )
415 })?;
416 let _ = writeln!(out, " {}: {}", to_snake_case(&field.name), py_type);
417 }
418 }
419 Ok(out)
420 }
421}