Skip to main content

scythe_codegen/backends/
python_psycopg3.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
17const DEFAULT_MANIFEST_TOML: &str =
18    include_str!("../../manifests/python-psycopg3.toml");
19
20pub struct PythonPsycopg3Backend {
21    manifest: BackendManifest,
22}
23
24impl PythonPsycopg3Backend {
25    pub fn new() -> Result<Self, ScytheError> {
26        let manifest_path = Path::new("backends/python-psycopg3/manifest.toml");
27        let manifest = if manifest_path.exists() {
28            load_manifest(manifest_path)
29                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
30        } else {
31            toml::from_str(DEFAULT_MANIFEST_TOML)
32                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
33        };
34        Ok(Self { manifest })
35    }
36
37    pub fn manifest(&self) -> &BackendManifest {
38        &self.manifest
39    }
40}
41
42/// Rewrite `$1`, `$2`, ... positional params to psycopg3 named params `%(name)s`.
43fn rewrite_params_named(sql: &str, analyzed: &AnalyzedQuery) -> String {
44    let mut result = sql.to_string();
45    // Replace in reverse order so positions don't shift
46    let mut params_sorted: Vec<_> = analyzed.params.iter().collect();
47    params_sorted.sort_by(|a, b| b.position.cmp(&a.position));
48    for param in params_sorted {
49        let placeholder = format!("${}", param.position);
50        let named = format!("%({})s", to_snake_case(&param.name));
51        result = result.replace(&placeholder, &named);
52    }
53    result
54}
55
56impl CodegenBackend for PythonPsycopg3Backend {
57    fn name(&self) -> &str {
58        "python-psycopg3"
59    }
60
61    fn file_header(&self) -> String {
62        "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
63         \n\
64         import datetime  # noqa: F401\n\
65         from dataclasses import dataclass\n\
66         from enum import Enum  # noqa: F401\n\
67         \n\
68         from psycopg import AsyncConnection  # noqa: F401\n\
69         \n"
70        .to_string()
71    }
72
73    fn generate_row_struct(
74        &self,
75        query_name: &str,
76        columns: &[ResolvedColumn],
77    ) -> Result<String, ScytheError> {
78        let struct_name = row_struct_name(query_name, &self.manifest.naming);
79        let mut out = String::new();
80        let _ = writeln!(out, "@dataclass");
81        let _ = writeln!(out, "class {}:", struct_name);
82        let _ = writeln!(out, "    \"\"\"Row type for {} query.\"\"\"", query_name);
83        if columns.is_empty() {
84            let _ = writeln!(out, "    pass");
85        } else {
86            let _ = writeln!(out);
87            for col in columns {
88                let _ = writeln!(out, "    {}: {}", col.field_name, col.full_type);
89            }
90        }
91        Ok(out)
92    }
93
94    fn generate_model_struct(
95        &self,
96        table_name: &str,
97        columns: &[ResolvedColumn],
98    ) -> Result<String, ScytheError> {
99        let singular = singularize(table_name);
100        let name = to_pascal_case(&singular);
101        self.generate_row_struct(&name, columns)
102    }
103
104    fn generate_query_fn(
105        &self,
106        analyzed: &AnalyzedQuery,
107        struct_name: &str,
108        columns: &[ResolvedColumn],
109        params: &[ResolvedParam],
110    ) -> Result<String, ScytheError> {
111        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
112        let mut out = String::new();
113
114        // Build parameter list (keyword-only after conn)
115        let param_list = params
116            .iter()
117            .map(|p| format!("{}: {}", p.field_name, p.full_type))
118            .collect::<Vec<_>>()
119            .join(", ");
120        let kw_sep = if param_list.is_empty() { "" } else { ", *, " };
121
122        // Clean and rewrite SQL
123        let sql_clean = super::clean_sql(&analyzed.sql);
124        let sql = rewrite_params_named(&sql_clean, analyzed);
125
126        match &analyzed.command {
127            QueryCommand::One => {
128                let _ = writeln!(
129                    out,
130                    "async def {}(conn: AsyncConnection{}{}) -> {} | None:",
131                    func_name, kw_sep, param_list, struct_name
132                );
133                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
134                // Build params dict
135                if params.is_empty() {
136                    let _ = writeln!(out, "    cur = await conn.execute(");
137                    let _ = writeln!(out, "        \"{}\",", sql);
138                    let _ = writeln!(out, "    )");
139                } else {
140                    let dict_entries: Vec<String> = params
141                        .iter()
142                        .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
143                        .collect();
144                    let _ = writeln!(out, "    cur = await conn.execute(");
145                    let _ = writeln!(out, "        \"{}\",", sql);
146                    let _ = writeln!(out, "        {{{}}},", dict_entries.join(", "));
147                    let _ = writeln!(out, "    )");
148                }
149                let _ = writeln!(out, "    row = await cur.fetchone()");
150                let _ = writeln!(out, "    if row is None:");
151                let _ = writeln!(out, "        return None");
152                // Construct dataclass from positional row
153                let field_assignments: Vec<String> = columns
154                    .iter()
155                    .enumerate()
156                    .map(|(i, col)| format!("{}=row[{}]", col.field_name, i))
157                    .collect();
158                let oneliner = format!(
159                    "    return {}({})",
160                    struct_name,
161                    field_assignments.join(", ")
162                );
163                if oneliner.len() <= 88 {
164                    let _ = writeln!(out, "{}", oneliner);
165                } else {
166                    let _ = writeln!(out, "    return {}(", struct_name);
167                    for fa in &field_assignments {
168                        let _ = writeln!(out, "        {},", fa);
169                    }
170                    let _ = writeln!(out, "    )");
171                }
172            }
173            QueryCommand::Many | QueryCommand::Batch => {
174                let _ = writeln!(
175                    out,
176                    "async def {}(conn: AsyncConnection{}{}) -> list[{}]:",
177                    func_name, kw_sep, param_list, struct_name
178                );
179                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
180                if params.is_empty() {
181                    let _ = writeln!(out, "    cur = await conn.execute(");
182                    let _ = writeln!(out, "        \"{}\",", sql);
183                    let _ = writeln!(out, "    )");
184                } else {
185                    let dict_entries: Vec<String> = params
186                        .iter()
187                        .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
188                        .collect();
189                    let _ = writeln!(out, "    cur = await conn.execute(");
190                    let _ = writeln!(out, "        \"{}\",", sql);
191                    let _ = writeln!(out, "        {{{}}},", dict_entries.join(", "));
192                    let _ = writeln!(out, "    )");
193                }
194                let _ = writeln!(out, "    rows = await cur.fetchall()");
195                let field_assignments: Vec<String> = columns
196                    .iter()
197                    .enumerate()
198                    .map(|(i, col)| format!("{}=r[{}]", col.field_name, i))
199                    .collect();
200                let oneliner = format!(
201                    "    return [{}({}) for r in rows]",
202                    struct_name,
203                    field_assignments.join(", ")
204                );
205                if oneliner.len() <= 88 {
206                    let _ = writeln!(out, "{}", oneliner);
207                } else {
208                    let _ = writeln!(out, "    return [");
209                    let _ = writeln!(out, "        {}(", struct_name);
210                    for fa in &field_assignments {
211                        let _ = writeln!(out, "            {},", fa);
212                    }
213                    let _ = writeln!(out, "        )");
214                    let _ = writeln!(out, "        for r in rows");
215                    let _ = writeln!(out, "    ]");
216                }
217            }
218            QueryCommand::Exec | QueryCommand::ExecResult | QueryCommand::ExecRows => {
219                let _ = writeln!(
220                    out,
221                    "async def {}(conn: AsyncConnection{}{}) -> None:",
222                    func_name, kw_sep, param_list
223                );
224                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
225                if params.is_empty() {
226                    let _ = writeln!(out, "    await conn.execute(");
227                    let _ = writeln!(out, "        \"{}\",", sql);
228                    let _ = writeln!(out, "    )");
229                } else {
230                    let dict_entries: Vec<String> = params
231                        .iter()
232                        .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
233                        .collect();
234                    let _ = writeln!(out, "    await conn.execute(");
235                    let _ = writeln!(out, "        \"{}\",", sql);
236                    let _ = writeln!(out, "        {{{}}},", dict_entries.join(", "));
237                    let _ = writeln!(out, "    )");
238                }
239            }
240        }
241
242        Ok(out)
243    }
244
245    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
246        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
247        let mut out = String::new();
248        let _ = writeln!(out, "class {}(str, Enum):", type_name);
249        let _ = writeln!(
250            out,
251            "    \"\"\"Database enum type {}.\"\"\"",
252            enum_info.sql_name
253        );
254        if enum_info.values.is_empty() {
255            let _ = writeln!(out, "    pass");
256        } else {
257            let _ = writeln!(out);
258            for value in &enum_info.values {
259                let variant = enum_variant_name(value, &self.manifest.naming);
260                let _ = writeln!(out, "    {} = \"{}\"", variant, value);
261            }
262        }
263        Ok(out)
264    }
265
266    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
267        let name = to_pascal_case(&composite.sql_name);
268        let mut out = String::new();
269        let _ = writeln!(out, "@dataclass");
270        let _ = writeln!(out, "class {}:", name);
271        let _ = writeln!(
272            out,
273            "    \"\"\"Composite type {}.\"\"\"",
274            composite.sql_name
275        );
276        if composite.fields.is_empty() {
277            let _ = writeln!(out, "    pass");
278        } else {
279            let _ = writeln!(out);
280            for field in &composite.fields {
281                let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
282                    .map(|t| t.into_owned())
283                    .map_err(|e| {
284                        ScytheError::new(
285                            ErrorCode::InternalError,
286                            format!("composite field type error: {}", e),
287                        )
288                    })?;
289                let _ = writeln!(out, "    {}: {}", to_snake_case(&field.name), py_type);
290            }
291        }
292        Ok(out)
293    }
294}