Skip to main content

scythe_codegen/backends/
python_psycopg3.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
17const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/python-psycopg3.toml");
18
19pub struct PythonPsycopg3Backend {
20    manifest: BackendManifest,
21}
22
23impl PythonPsycopg3Backend {
24    pub fn new() -> Result<Self, ScytheError> {
25        let manifest_path = Path::new("backends/python-psycopg3/manifest.toml");
26        let manifest = if manifest_path.exists() {
27            load_manifest(manifest_path)
28                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
29        } else {
30            toml::from_str(DEFAULT_MANIFEST_TOML)
31                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
32        };
33        Ok(Self { manifest })
34    }
35
36    pub fn manifest(&self) -> &BackendManifest {
37        &self.manifest
38    }
39}
40
41/// Rewrite `$1`, `$2`, ... positional params to psycopg3 named params `%(name)s`.
42fn rewrite_params_named(sql: &str, analyzed: &AnalyzedQuery) -> String {
43    let mut result = sql.to_string();
44    // Replace in reverse order so positions don't shift
45    let mut params_sorted: Vec<_> = analyzed.params.iter().collect();
46    params_sorted.sort_by(|a, b| b.position.cmp(&a.position));
47    for param in params_sorted {
48        let placeholder = format!("${}", param.position);
49        let named = format!("%({})s", to_snake_case(&param.name));
50        result = result.replace(&placeholder, &named);
51    }
52    result
53}
54
55impl CodegenBackend for PythonPsycopg3Backend {
56    fn name(&self) -> &str {
57        "python-psycopg3"
58    }
59
60    fn file_header(&self) -> String {
61        "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
62         \n\
63         import datetime  # noqa: F401\n\
64         from dataclasses import dataclass\n\
65         from enum import Enum  # noqa: F401\n\
66         \n\
67         from psycopg import AsyncConnection  # noqa: F401\n\
68         \n"
69        .to_string()
70    }
71
72    fn generate_row_struct(
73        &self,
74        query_name: &str,
75        columns: &[ResolvedColumn],
76    ) -> Result<String, ScytheError> {
77        let struct_name = row_struct_name(query_name, &self.manifest.naming);
78        let mut out = String::new();
79        let _ = writeln!(out, "@dataclass");
80        let _ = writeln!(out, "class {}:", struct_name);
81        let _ = writeln!(out, "    \"\"\"Row type for {} query.\"\"\"", query_name);
82        if columns.is_empty() {
83            let _ = writeln!(out, "    pass");
84        } else {
85            let _ = writeln!(out);
86            for col in columns {
87                let _ = writeln!(out, "    {}: {}", col.field_name, col.full_type);
88            }
89        }
90        Ok(out)
91    }
92
93    fn generate_model_struct(
94        &self,
95        table_name: &str,
96        columns: &[ResolvedColumn],
97    ) -> Result<String, ScytheError> {
98        let singular = singularize(table_name);
99        let name = to_pascal_case(&singular);
100        self.generate_row_struct(&name, columns)
101    }
102
103    fn generate_query_fn(
104        &self,
105        analyzed: &AnalyzedQuery,
106        struct_name: &str,
107        columns: &[ResolvedColumn],
108        params: &[ResolvedParam],
109    ) -> Result<String, ScytheError> {
110        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
111        let mut out = String::new();
112
113        // Build parameter list (keyword-only after conn)
114        let param_list = params
115            .iter()
116            .map(|p| format!("{}: {}", p.field_name, p.full_type))
117            .collect::<Vec<_>>()
118            .join(", ");
119        let kw_sep = if param_list.is_empty() { "" } else { ", *, " };
120
121        // Clean and rewrite SQL
122        let sql_clean = super::clean_sql(&analyzed.sql);
123        let sql = rewrite_params_named(&sql_clean, analyzed);
124
125        match &analyzed.command {
126            QueryCommand::One => {
127                let _ = writeln!(
128                    out,
129                    "async def {}(conn: AsyncConnection{}{}) -> {} | None:",
130                    func_name, kw_sep, param_list, struct_name
131                );
132                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
133                // Build params dict
134                if params.is_empty() {
135                    let _ = writeln!(out, "    cur = await conn.execute(");
136                    let _ = writeln!(out, "        \"{}\",", sql);
137                    let _ = writeln!(out, "    )");
138                } else {
139                    let dict_entries: Vec<String> = params
140                        .iter()
141                        .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
142                        .collect();
143                    let _ = writeln!(out, "    cur = await conn.execute(");
144                    let _ = writeln!(out, "        \"{}\",", sql);
145                    let _ = writeln!(out, "        {{{}}},", dict_entries.join(", "));
146                    let _ = writeln!(out, "    )");
147                }
148                let _ = writeln!(out, "    row = await cur.fetchone()");
149                let _ = writeln!(out, "    if row is None:");
150                let _ = writeln!(out, "        return None");
151                // Construct dataclass from positional row
152                let field_assignments: Vec<String> = columns
153                    .iter()
154                    .enumerate()
155                    .map(|(i, col)| format!("{}=row[{}]", col.field_name, i))
156                    .collect();
157                let oneliner = format!(
158                    "    return {}({})",
159                    struct_name,
160                    field_assignments.join(", ")
161                );
162                if oneliner.len() <= 88 {
163                    let _ = writeln!(out, "{}", oneliner);
164                } else {
165                    let _ = writeln!(out, "    return {}(", struct_name);
166                    for fa in &field_assignments {
167                        let _ = writeln!(out, "        {},", fa);
168                    }
169                    let _ = writeln!(out, "    )");
170                }
171            }
172            QueryCommand::Many | QueryCommand::Batch => {
173                let _ = writeln!(
174                    out,
175                    "async def {}(conn: AsyncConnection{}{}) -> list[{}]:",
176                    func_name, kw_sep, param_list, struct_name
177                );
178                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
179                if params.is_empty() {
180                    let _ = writeln!(out, "    cur = await conn.execute(");
181                    let _ = writeln!(out, "        \"{}\",", sql);
182                    let _ = writeln!(out, "    )");
183                } else {
184                    let dict_entries: Vec<String> = params
185                        .iter()
186                        .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
187                        .collect();
188                    let _ = writeln!(out, "    cur = await conn.execute(");
189                    let _ = writeln!(out, "        \"{}\",", sql);
190                    let _ = writeln!(out, "        {{{}}},", dict_entries.join(", "));
191                    let _ = writeln!(out, "    )");
192                }
193                let _ = writeln!(out, "    rows = await cur.fetchall()");
194                let field_assignments: Vec<String> = columns
195                    .iter()
196                    .enumerate()
197                    .map(|(i, col)| format!("{}=r[{}]", col.field_name, i))
198                    .collect();
199                let oneliner = format!(
200                    "    return [{}({}) for r in rows]",
201                    struct_name,
202                    field_assignments.join(", ")
203                );
204                if oneliner.len() <= 88 {
205                    let _ = writeln!(out, "{}", oneliner);
206                } else {
207                    let _ = writeln!(out, "    return [");
208                    let _ = writeln!(out, "        {}(", struct_name);
209                    for fa in &field_assignments {
210                        let _ = writeln!(out, "            {},", fa);
211                    }
212                    let _ = writeln!(out, "        )");
213                    let _ = writeln!(out, "        for r in rows");
214                    let _ = writeln!(out, "    ]");
215                }
216            }
217            QueryCommand::Exec | QueryCommand::ExecResult | QueryCommand::ExecRows => {
218                let _ = writeln!(
219                    out,
220                    "async def {}(conn: AsyncConnection{}{}) -> None:",
221                    func_name, kw_sep, param_list
222                );
223                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
224                if params.is_empty() {
225                    let _ = writeln!(out, "    await conn.execute(");
226                    let _ = writeln!(out, "        \"{}\",", sql);
227                    let _ = writeln!(out, "    )");
228                } else {
229                    let dict_entries: Vec<String> = params
230                        .iter()
231                        .map(|p| format!("\"{}\": {}", p.field_name, p.field_name))
232                        .collect();
233                    let _ = writeln!(out, "    await conn.execute(");
234                    let _ = writeln!(out, "        \"{}\",", sql);
235                    let _ = writeln!(out, "        {{{}}},", dict_entries.join(", "));
236                    let _ = writeln!(out, "    )");
237                }
238            }
239        }
240
241        Ok(out)
242    }
243
244    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
245        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
246        let mut out = String::new();
247        let _ = writeln!(out, "class {}(str, Enum):", type_name);
248        let _ = writeln!(
249            out,
250            "    \"\"\"Database enum type {}.\"\"\"",
251            enum_info.sql_name
252        );
253        if enum_info.values.is_empty() {
254            let _ = writeln!(out, "    pass");
255        } else {
256            let _ = writeln!(out);
257            for value in &enum_info.values {
258                let variant = enum_variant_name(value, &self.manifest.naming);
259                let _ = writeln!(out, "    {} = \"{}\"", variant, value);
260            }
261        }
262        Ok(out)
263    }
264
265    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
266        let name = to_pascal_case(&composite.sql_name);
267        let mut out = String::new();
268        let _ = writeln!(out, "@dataclass");
269        let _ = writeln!(out, "class {}:", name);
270        let _ = writeln!(
271            out,
272            "    \"\"\"Composite type {}.\"\"\"",
273            composite.sql_name
274        );
275        if composite.fields.is_empty() {
276            let _ = writeln!(out, "    pass");
277        } else {
278            let _ = writeln!(out);
279            for field in &composite.fields {
280                let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
281                    .map(|t| t.into_owned())
282                    .map_err(|e| {
283                        ScytheError::new(
284                            ErrorCode::InternalError,
285                            format!("composite field type error: {}", e),
286                        )
287                    })?;
288                let _ = writeln!(out, "    {}: {}", to_snake_case(&field.name), py_type);
289            }
290        }
291        Ok(out)
292    }
293}