Skip to main content

scythe_codegen/backends/python_psycopg3.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
17const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/python-psycopg3.toml");
18
19pub struct PythonPsycopg3Backend {
20    manifest: BackendManifest,
21}
22
23impl PythonPsycopg3Backend {
24    pub fn new(engine: &str) -> Result<Self, ScytheError> {
25        match engine {
26            "postgresql" | "postgres" | "pg" => {}
27            _ => {
28                return Err(ScytheError::new(
29                    ErrorCode::InternalError,
30                    format!(
31                        "python-psycopg3 only supports PostgreSQL, got engine '{}'",
32                        engine
33                    ),
34                ));
35            }
36        }
37        let manifest_path = Path::new("backends/python-psycopg3/manifest.toml");
38        let manifest = if manifest_path.exists() {
39            load_manifest(manifest_path)
40                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
41        } else {
42            toml::from_str(DEFAULT_MANIFEST_TOML)
43                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
44        };
45        Ok(Self { manifest })
46    }
47}
48
49/// Rewrite `$1`, `$2`, ... positional params to psycopg3 named params `%(name)s`.
50fn rewrite_params_named(sql: &str, analyzed: &AnalyzedQuery) -> String {
51    let mut result = sql.to_string();
52    // Replace in reverse order so positions don't shift
53    let mut params_sorted: Vec<_> = analyzed.params.iter().collect();
54    params_sorted.sort_by(|a, b| b.position.cmp(&a.position));
55    for param in params_sorted {
56        let placeholder = format!("${}", param.position);
57        let named = format!("%({})s", to_snake_case(&param.name));
58        result = result.replace(&placeholder, &named);
59    }
60    result
61}
62
63impl CodegenBackend for PythonPsycopg3Backend {
64    fn name(&self) -> &str {
65        "python-psycopg3"
66    }
67
68    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
69        &self.manifest
70    }
71
72    fn file_header(&self) -> String {
73        "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
74         \n\
75         import datetime  # noqa: F401\n\
76         import decimal  # noqa: F401\n\
77         from dataclasses import dataclass\n\
78         from enum import Enum  # noqa: F401\n\
79         \n\
80         from psycopg import AsyncConnection  # noqa: F401\n\
81         \n"
82        .to_string()
83    }
84
85    fn generate_row_struct(
86        &self,
87        query_name: &str,
88        columns: &[ResolvedColumn],
89    ) -> Result<String, ScytheError> {
90        let struct_name = row_struct_name(query_name, &self.manifest.naming);
91        let mut out = String::new();
92        let _ = writeln!(out, "@dataclass");
93        let _ = writeln!(out, "class {}:", struct_name);
94        let _ = writeln!(out, "    \"\"\"Row type for {} query.\"\"\"", query_name);
95        if columns.is_empty() {
96            let _ = writeln!(out, "    pass");
97        } else {
98            let _ = writeln!(out);
99            for col in columns {
100                let _ = writeln!(out, "    {}: {}", col.field_name, col.full_type);
101            }
102        }
103        Ok(out)
104    }
105
106    fn generate_model_struct(
107        &self,
108        table_name: &str,
109        columns: &[ResolvedColumn],
110    ) -> Result<String, ScytheError> {
111        let singular = singularize(table_name);
112        let name = to_pascal_case(&singular);
113        self.generate_row_struct(&name, columns)
114    }
115
116    fn generate_query_fn(
117        &self,
118        analyzed: &AnalyzedQuery,
119        struct_name: &str,
120        columns: &[ResolvedColumn],
121        params: &[ResolvedParam],
122    ) -> Result<String, ScytheError> {
123        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
124        let mut out = String::new();
125
126        // Build parameter list (keyword-only after conn)
127        let param_list = params
128            .iter()
129            .map(|p| format!("{}: {}", p.field_name, p.full_type))
130            .collect::<Vec<_>>()
131            .join(", ");
132        let kw_sep = if param_list.is_empty() { "" } else { ", *, " };
133
134        // Clean and rewrite SQL
135        let sql_clean = super::clean_sql(&analyzed.sql);
136        let sql = rewrite_params_named(&sql_clean, analyzed);
137
138        match &analyzed.command {
139            QueryCommand::One => {
140                let _ = writeln!(
141                    out,
142                    "async def {}(conn: AsyncConnection{}{}) -> {} | None:",
143                    func_name, kw_sep, param_list, struct_name
144                );
145                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
146                // Build params dict
147                if params.is_empty() {
148                    let _ = writeln!(out, "    cur = await conn.execute(");
149                    let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
150                    let _ = writeln!(out, "    )");
151                } else {
152                    let dict_entries: Vec<String> = params
153                        .iter()
154                        .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
155                        .collect();
156                    let _ = writeln!(out, "    cur = await conn.execute(");
157                    let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
158                    let _ = writeln!(out, "        {{{}}},", dict_entries.join(", "));
159                    let _ = writeln!(out, "    )");
160                }
161                let _ = writeln!(out, "    row = await cur.fetchone()");
162                let _ = writeln!(out, "    if row is None:");
163                let _ = writeln!(out, "        return None");
164                // Construct dataclass from positional row
165                let field_assignments: Vec<String> = columns
166                    .iter()
167                    .enumerate()
168                    .map(|(i, col)| format!("{}=row[{}]", col.field_name, i))
169                    .collect();
170                let oneliner = format!(
171                    "    return {}({})",
172                    struct_name,
173                    field_assignments.join(", ")
174                );
175                if oneliner.len() <= 88 {
176                    let _ = writeln!(out, "{}", oneliner);
177                } else {
178                    let _ = writeln!(out, "    return {}(", struct_name);
179                    for fa in &field_assignments {
180                        let _ = writeln!(out, "        {},", fa);
181                    }
182                    let _ = writeln!(out, "    )");
183                }
184            }
185            QueryCommand::Many | QueryCommand::Batch => {
186                let _ = writeln!(
187                    out,
188                    "async def {}(conn: AsyncConnection{}{}) -> list[{}]:",
189                    func_name, kw_sep, param_list, struct_name
190                );
191                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
192                if params.is_empty() {
193                    let _ = writeln!(out, "    cur = await conn.execute(");
194                    let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
195                    let _ = writeln!(out, "    )");
196                } else {
197                    let dict_entries: Vec<String> = params
198                        .iter()
199                        .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
200                        .collect();
201                    let _ = writeln!(out, "    cur = await conn.execute(");
202                    let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
203                    let _ = writeln!(out, "        {{{}}},", dict_entries.join(", "));
204                    let _ = writeln!(out, "    )");
205                }
206                let _ = writeln!(out, "    rows = await cur.fetchall()");
207                let field_assignments: Vec<String> = columns
208                    .iter()
209                    .enumerate()
210                    .map(|(i, col)| format!("{}=r[{}]", col.field_name, i))
211                    .collect();
212                let oneliner = format!(
213                    "    return [{}({}) for r in rows]",
214                    struct_name,
215                    field_assignments.join(", ")
216                );
217                if oneliner.len() <= 88 {
218                    let _ = writeln!(out, "{}", oneliner);
219                } else {
220                    let _ = writeln!(out, "    return [");
221                    let _ = writeln!(out, "        {}(", struct_name);
222                    for fa in &field_assignments {
223                        let _ = writeln!(out, "            {},", fa);
224                    }
225                    let _ = writeln!(out, "        )");
226                    let _ = writeln!(out, "        for r in rows");
227                    let _ = writeln!(out, "    ]");
228                }
229            }
230            QueryCommand::Exec | QueryCommand::ExecResult | QueryCommand::ExecRows => {
231                let _ = writeln!(
232                    out,
233                    "async def {}(conn: AsyncConnection{}{}) -> None:",
234                    func_name, kw_sep, param_list
235                );
236                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
237                if params.is_empty() {
238                    let _ = writeln!(out, "    await conn.execute(");
239                    let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
240                    let _ = writeln!(out, "    )");
241                } else {
242                    let dict_entries: Vec<String> = params
243                        .iter()
244                        .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
245                        .collect();
246                    let _ = writeln!(out, "    await conn.execute(");
247                    let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
248                    let _ = writeln!(out, "        {{{}}},", dict_entries.join(", "));
249                    let _ = writeln!(out, "    )");
250                }
251            }
252        }
253
254        Ok(out)
255    }
256
257    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
258        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
259        let mut out = String::new();
260        let _ = writeln!(out, "class {}(str, Enum):", type_name);
261        let _ = writeln!(
262            out,
263            "    \"\"\"Database enum type {}.\"\"\"",
264            enum_info.sql_name
265        );
266        if enum_info.values.is_empty() {
267            let _ = writeln!(out, "    pass");
268        } else {
269            let _ = writeln!(out);
270            for value in &enum_info.values {
271                let variant = enum_variant_name(value, &self.manifest.naming);
272                let _ = writeln!(out, "    {} = \"{}\"", variant, value);
273            }
274        }
275        Ok(out)
276    }
277
278    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
279        let name = to_pascal_case(&composite.sql_name);
280        let mut out = String::new();
281        let _ = writeln!(out, "@dataclass");
282        let _ = writeln!(out, "class {}:", name);
283        let _ = writeln!(
284            out,
285            "    \"\"\"Composite type {}.\"\"\"",
286            composite.sql_name
287        );
288        if composite.fields.is_empty() {
289            let _ = writeln!(out, "    pass");
290        } else {
291            let _ = writeln!(out);
292            for field in &composite.fields {
293                let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
294                    .map(|t| t.into_owned())
295                    .map_err(|e| {
296                        ScytheError::new(
297                            ErrorCode::InternalError,
298                            format!("composite field type error: {}", e),
299                        )
300                    })?;
301                let _ = writeln!(out, "    {}: {}", to_snake_case(&field.name), py_type);
302            }
303        }
304        Ok(out)
305    }
306}