Skip to main content

scythe_codegen/backends/
python_psycopg3.rs

1use std::collections::HashMap;
2use std::fmt::Write;
3use std::path::Path;
4
5use scythe_backend::manifest::{BackendManifest, load_manifest};
6use scythe_backend::naming::{
7    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
8};
9use scythe_backend::types::resolve_type;
10
11use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
12use scythe_core::errors::{ErrorCode, ScytheError};
13use scythe_core::parser::QueryCommand;
14
15use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
16use crate::singularize;
17
18use super::python_common::PythonRowType;
19
20const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/python-psycopg3.toml");
21
22pub struct PythonPsycopg3Backend {
23    manifest: BackendManifest,
24    row_type: PythonRowType,
25}
26
27impl PythonPsycopg3Backend {
28    pub fn new(engine: &str) -> Result<Self, ScytheError> {
29        match engine {
30            "postgresql" | "postgres" | "pg" => {}
31            _ => {
32                return Err(ScytheError::new(
33                    ErrorCode::InternalError,
34                    format!(
35                        "python-psycopg3 only supports PostgreSQL, got engine '{}'",
36                        engine
37                    ),
38                ));
39            }
40        }
41        let manifest_path = Path::new("backends/python-psycopg3/manifest.toml");
42        let manifest = if manifest_path.exists() {
43            load_manifest(manifest_path)
44                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
45        } else {
46            toml::from_str(DEFAULT_MANIFEST_TOML)
47                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
48        };
49        Ok(Self {
50            manifest,
51            row_type: PythonRowType::default(),
52        })
53    }
54}
55
56/// Rewrite `$1`, `$2`, ... positional params to psycopg3 named params `%(name)s`.
57fn rewrite_params_named(sql: &str, analyzed: &AnalyzedQuery) -> String {
58    let mut result = sql.to_string();
59    // Replace in reverse order so positions don't shift
60    let mut params_sorted: Vec<_> = analyzed.params.iter().collect();
61    params_sorted.sort_by(|a, b| b.position.cmp(&a.position));
62    for param in params_sorted {
63        let placeholder = format!("${}", param.position);
64        let named = format!("%({})s", to_snake_case(&param.name));
65        result = result.replace(&placeholder, &named);
66    }
67    result
68}
69
70impl CodegenBackend for PythonPsycopg3Backend {
71    fn name(&self) -> &str {
72        "python-psycopg3"
73    }
74
75    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
76        &self.manifest
77    }
78
79    fn apply_options(&mut self, options: &HashMap<String, String>) -> Result<(), ScytheError> {
80        if let Some(rt) = options.get("row_type") {
81            self.row_type = PythonRowType::from_option(rt)?;
82        }
83        Ok(())
84    }
85
86    fn file_header(&self) -> String {
87        let import_line = self.row_type.import_line();
88        if self.row_type.is_stdlib_import() {
89            format!(
90                "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
91                 \n\
92                 import datetime  # noqa: F401\n\
93                 import decimal  # noqa: F401\n\
94                 {import_line}\n\
95                 from enum import Enum  # noqa: F401\n\
96                 \n\
97                 from psycopg import AsyncConnection  # noqa: F401\n\
98                 \n",
99            )
100        } else {
101            // Third-party imports: `import` before `from`, sorted by module.
102            // msgspec uses `import msgspec` (bare import, comes first).
103            // pydantic uses `from pydantic import BaseModel` (from import,
104            //   sorted after `from psycopg`).
105            let third_party = self
106                .row_type
107                .sorted_third_party_imports("from psycopg import AsyncConnection  # noqa: F401");
108            format!(
109                "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
110                 \n\
111                 import datetime  # noqa: F401\n\
112                 import decimal  # noqa: F401\n\
113                 from enum import Enum  # noqa: F401\n\
114                 \n\
115                 {third_party}\n\
116                 \n",
117            )
118        }
119    }
120
121    fn generate_row_struct(
122        &self,
123        query_name: &str,
124        columns: &[ResolvedColumn],
125    ) -> Result<String, ScytheError> {
126        let struct_name = row_struct_name(query_name, &self.manifest.naming);
127        let mut out = String::new();
128        let _ = write!(out, "{}", self.row_type.decorator());
129        let _ = writeln!(out, "{}", self.row_type.class_def(&struct_name));
130        let _ = writeln!(out, "    \"\"\"Row type for {} query.\"\"\"", query_name);
131        if columns.is_empty() {
132            let _ = writeln!(out, "    pass");
133        } else {
134            let _ = writeln!(out);
135            for col in columns {
136                let _ = writeln!(out, "    {}: {}", col.field_name, col.full_type);
137            }
138        }
139        Ok(out)
140    }
141
142    fn generate_model_struct(
143        &self,
144        table_name: &str,
145        columns: &[ResolvedColumn],
146    ) -> Result<String, ScytheError> {
147        let singular = singularize(table_name);
148        let name = to_pascal_case(&singular);
149        self.generate_row_struct(&name, columns)
150    }
151
152    fn generate_query_fn(
153        &self,
154        analyzed: &AnalyzedQuery,
155        struct_name: &str,
156        columns: &[ResolvedColumn],
157        params: &[ResolvedParam],
158    ) -> Result<String, ScytheError> {
159        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
160        let mut out = String::new();
161
162        // Build parameter list (keyword-only after conn)
163        let param_list = params
164            .iter()
165            .map(|p| format!("{}: {}", p.field_name, p.full_type))
166            .collect::<Vec<_>>()
167            .join(", ");
168        let kw_sep = if param_list.is_empty() { "" } else { ", *, " };
169
170        // Clean and rewrite SQL
171        let sql_clean = super::clean_sql_with_optional(
172            &analyzed.sql,
173            &analyzed.optional_params,
174            &analyzed.params,
175        );
176        let sql = rewrite_params_named(&sql_clean, analyzed);
177
178        match &analyzed.command {
179            QueryCommand::One => {
180                let _ = writeln!(
181                    out,
182                    "async def {}(conn: AsyncConnection{}{}) -> {} | None:",
183                    func_name, kw_sep, param_list, struct_name
184                );
185                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
186                // Build params dict
187                if params.is_empty() {
188                    let _ = writeln!(out, "    cur = await conn.execute(");
189                    let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
190                    let _ = writeln!(out, "    )");
191                } else {
192                    let dict_entries: Vec<String> = params
193                        .iter()
194                        .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
195                        .collect();
196                    let _ = writeln!(out, "    cur = await conn.execute(");
197                    let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
198                    let _ = writeln!(out, "        {{{}}},", dict_entries.join(", "));
199                    let _ = writeln!(out, "    )");
200                }
201                let _ = writeln!(out, "    row = await cur.fetchone()");
202                let _ = writeln!(out, "    if row is None:");
203                let _ = writeln!(out, "        return None");
204                // Construct dataclass from positional row
205                let field_assignments: Vec<String> = columns
206                    .iter()
207                    .enumerate()
208                    .map(|(i, col)| format!("{}=row[{}]", col.field_name, i))
209                    .collect();
210                let oneliner = format!(
211                    "    return {}({})",
212                    struct_name,
213                    field_assignments.join(", ")
214                );
215                if oneliner.len() <= 88 {
216                    let _ = writeln!(out, "{}", oneliner);
217                } else {
218                    let _ = writeln!(out, "    return {}(", struct_name);
219                    for fa in &field_assignments {
220                        let _ = writeln!(out, "        {},", fa);
221                    }
222                    let _ = writeln!(out, "    )");
223                }
224            }
225            QueryCommand::Many => {
226                let _ = writeln!(
227                    out,
228                    "async def {}(conn: AsyncConnection{}{}) -> list[{}]:",
229                    func_name, kw_sep, param_list, struct_name
230                );
231                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
232                if params.is_empty() {
233                    let _ = writeln!(out, "    cur = await conn.execute(");
234                    let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
235                    let _ = writeln!(out, "    )");
236                } else {
237                    let dict_entries: Vec<String> = params
238                        .iter()
239                        .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
240                        .collect();
241                    let _ = writeln!(out, "    cur = await conn.execute(");
242                    let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
243                    let _ = writeln!(out, "        {{{}}},", dict_entries.join(", "));
244                    let _ = writeln!(out, "    )");
245                }
246                let _ = writeln!(out, "    rows = await cur.fetchall()");
247                let field_assignments: Vec<String> = columns
248                    .iter()
249                    .enumerate()
250                    .map(|(i, col)| format!("{}=r[{}]", col.field_name, i))
251                    .collect();
252                let oneliner = format!(
253                    "    return [{}({}) for r in rows]",
254                    struct_name,
255                    field_assignments.join(", ")
256                );
257                if oneliner.len() <= 88 {
258                    let _ = writeln!(out, "{}", oneliner);
259                } else {
260                    let _ = writeln!(out, "    return [");
261                    let _ = writeln!(out, "        {}(", struct_name);
262                    for fa in &field_assignments {
263                        let _ = writeln!(out, "            {},", fa);
264                    }
265                    let _ = writeln!(out, "        )");
266                    let _ = writeln!(out, "        for r in rows");
267                    let _ = writeln!(out, "    ]");
268                }
269            }
270            QueryCommand::Batch => {
271                let batch_fn_name = format!("{}_batch", func_name);
272                // Build the items type annotation
273                let items_type = if params.len() > 1 {
274                    let tuple_types: Vec<String> =
275                        params.iter().map(|p| p.full_type.clone()).collect();
276                    format!("list[tuple[{}]]", tuple_types.join(", "))
277                } else if params.len() == 1 {
278                    format!("list[{}]", params[0].full_type)
279                } else {
280                    "int".to_string()
281                };
282                let _ = writeln!(
283                    out,
284                    "async def {}(conn: AsyncConnection, *, items: {}) -> None:",
285                    batch_fn_name, items_type
286                );
287                let _ = writeln!(
288                    out,
289                    "    \"\"\"Execute {} query for each item in the batch.\"\"\"",
290                    analyzed.name
291                );
292                if params.is_empty() {
293                    let _ = writeln!(out, "    for _ in range(items):");
294                    let _ = writeln!(out, "        await conn.execute(");
295                    let _ = writeln!(out, "            \"\"\"{}\"\"\",", sql);
296                    let _ = writeln!(out, "        )");
297                } else {
298                    // Use executemany with named params dict list
299                    let dict_entries: Vec<String> = params
300                        .iter()
301                        .enumerate()
302                        .map(|(i, p)| {
303                            if params.len() == 1 {
304                                format!("\"{}\": item", p.field_name)
305                            } else {
306                                format!("\"{}\": item[{}]", p.field_name, i)
307                            }
308                        })
309                        .collect();
310                    let _ = writeln!(
311                        out,
312                        "    params_list = [{{{dict}}} for item in items]",
313                        dict = dict_entries.join(", ")
314                    );
315                    let _ = writeln!(out, "    cur = conn.cursor()");
316                    let _ = writeln!(out, "    await cur.executemany(");
317                    let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
318                    let _ = writeln!(out, "        params_list,");
319                    let _ = writeln!(out, "    )");
320                }
321            }
322            QueryCommand::Exec => {
323                let _ = writeln!(
324                    out,
325                    "async def {}(conn: AsyncConnection{}{}) -> None:",
326                    func_name, kw_sep, param_list
327                );
328                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
329                if params.is_empty() {
330                    let _ = writeln!(out, "    await conn.execute(");
331                    let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
332                    let _ = writeln!(out, "    )");
333                } else {
334                    let dict_entries: Vec<String> = params
335                        .iter()
336                        .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
337                        .collect();
338                    let _ = writeln!(out, "    await conn.execute(");
339                    let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
340                    let _ = writeln!(out, "        {{{}}},", dict_entries.join(", "));
341                    let _ = writeln!(out, "    )");
342                }
343            }
344            QueryCommand::ExecResult | QueryCommand::ExecRows => {
345                let _ = writeln!(
346                    out,
347                    "async def {}(conn: AsyncConnection{}{}) -> int:",
348                    func_name, kw_sep, param_list
349                );
350                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
351                if params.is_empty() {
352                    let _ = writeln!(out, "    cur = await conn.execute(");
353                    let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
354                    let _ = writeln!(out, "    )");
355                } else {
356                    let dict_entries: Vec<String> = params
357                        .iter()
358                        .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
359                        .collect();
360                    let _ = writeln!(out, "    cur = await conn.execute(");
361                    let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
362                    let _ = writeln!(out, "        {{{}}},", dict_entries.join(", "));
363                    let _ = writeln!(out, "    )");
364                }
365                let _ = writeln!(out, "    return cur.rowcount");
366            }
367            QueryCommand::Grouped => {
368                unreachable!("Grouped is rewritten to Many before codegen")
369            }
370        }
371
372        Ok(out)
373    }
374
375    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
376        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
377        let mut out = String::new();
378        let _ = writeln!(out, "class {}(str, Enum):", type_name);
379        let _ = writeln!(
380            out,
381            "    \"\"\"Database enum type {}.\"\"\"",
382            enum_info.sql_name
383        );
384        if enum_info.values.is_empty() {
385            let _ = writeln!(out, "    pass");
386        } else {
387            let _ = writeln!(out);
388            for value in &enum_info.values {
389                let variant = enum_variant_name(value, &self.manifest.naming);
390                let _ = writeln!(out, "    {} = \"{}\"", variant, value);
391            }
392        }
393        Ok(out)
394    }
395
396    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
397        let name = to_pascal_case(&composite.sql_name);
398        let mut out = String::new();
399        let _ = write!(out, "{}", self.row_type.decorator());
400        let _ = writeln!(out, "{}", self.row_type.class_def(&name));
401        let _ = writeln!(
402            out,
403            "    \"\"\"Composite type {}.\"\"\"",
404            composite.sql_name
405        );
406        if composite.fields.is_empty() {
407            let _ = writeln!(out, "    pass");
408        } else {
409            let _ = writeln!(out);
410            for field in &composite.fields {
411                let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
412                    .map(|t| t.into_owned())
413                    .map_err(|e| {
414                        ScytheError::new(
415                            ErrorCode::InternalError,
416                            format!("composite field type error: {}", e),
417                        )
418                    })?;
419                let _ = writeln!(out, "    {}: {}", to_snake_case(&field.name), py_type);
420            }
421        }
422        Ok(out)
423    }
424}