Skip to main content

scythe_codegen/backends/
python_asyncpg.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
17const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/python-asyncpg.toml");
18
19pub struct PythonAsyncpgBackend {
20    manifest: BackendManifest,
21}
22
23impl PythonAsyncpgBackend {
24    pub fn new(engine: &str) -> Result<Self, ScytheError> {
25        match engine {
26            "postgresql" | "postgres" | "pg" => {}
27            _ => {
28                return Err(ScytheError::new(
29                    ErrorCode::InternalError,
30                    format!(
31                        "python-asyncpg only supports PostgreSQL, got engine '{}'",
32                        engine
33                    ),
34                ));
35            }
36        }
37        let manifest_path = Path::new("backends/python-asyncpg/manifest.toml");
38        let manifest = if manifest_path.exists() {
39            load_manifest(manifest_path)
40                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
41        } else {
42            toml::from_str(DEFAULT_MANIFEST_TOML)
43                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
44        };
45        Ok(Self { manifest })
46    }
47}
48
49impl CodegenBackend for PythonAsyncpgBackend {
50    fn name(&self) -> &str {
51        "python-asyncpg"
52    }
53
54    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
55        &self.manifest
56    }
57
58    fn file_header(&self) -> String {
59        "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
60         \n\
61         import datetime  # noqa: F401\n\
62         import decimal  # noqa: F401\n\
63         from dataclasses import dataclass\n\
64         from enum import Enum  # noqa: F401\n\
65         \n\
66         from asyncpg import Connection  # noqa: F401\n\
67         \n"
68        .to_string()
69    }
70
71    fn generate_row_struct(
72        &self,
73        query_name: &str,
74        columns: &[ResolvedColumn],
75    ) -> Result<String, ScytheError> {
76        let struct_name = row_struct_name(query_name, &self.manifest.naming);
77        let mut out = String::new();
78        let _ = writeln!(out, "@dataclass");
79        let _ = writeln!(out, "class {}:", struct_name);
80        let _ = writeln!(out, "    \"\"\"Row type for {} query.\"\"\"", query_name);
81        if columns.is_empty() {
82            let _ = writeln!(out, "    pass");
83        } else {
84            let _ = writeln!(out);
85            for col in columns {
86                let _ = writeln!(out, "    {}: {}", col.field_name, col.full_type);
87            }
88        }
89        Ok(out)
90    }
91
92    fn generate_model_struct(
93        &self,
94        table_name: &str,
95        columns: &[ResolvedColumn],
96    ) -> Result<String, ScytheError> {
97        let singular = singularize(table_name);
98        let name = to_pascal_case(&singular);
99        self.generate_row_struct(&name, columns)
100    }
101
102    fn generate_query_fn(
103        &self,
104        analyzed: &AnalyzedQuery,
105        struct_name: &str,
106        columns: &[ResolvedColumn],
107        params: &[ResolvedParam],
108    ) -> Result<String, ScytheError> {
109        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
110        let mut out = String::new();
111
112        // Build parameter list (keyword-only after conn)
113        let param_list = params
114            .iter()
115            .map(|p| format!("{}: {}", p.field_name, p.full_type))
116            .collect::<Vec<_>>()
117            .join(", ");
118        let kw_sep = if param_list.is_empty() { "" } else { ", *, " };
119
120        // Clean SQL — asyncpg uses $1, $2 positional params natively
121        let sql = super::clean_sql(&analyzed.sql);
122
123        match &analyzed.command {
124            QueryCommand::One => {
125                let _ = writeln!(
126                    out,
127                    "async def {}(conn: Connection{}{}) -> {} | None:",
128                    func_name, kw_sep, param_list, struct_name
129                );
130                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
131                let _ = writeln!(out, "    row = await conn.fetchrow(");
132                let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
133                if !params.is_empty() {
134                    let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
135                    let _ = writeln!(out, "        {},", args.join(", "));
136                }
137                let _ = writeln!(out, "    )");
138                let _ = writeln!(out, "    if row is None:");
139                let _ = writeln!(out, "        return None");
140                let field_assignments: Vec<String> = columns
141                    .iter()
142                    .map(|col| format!("{}=row[\"{}\"]", col.field_name, col.name))
143                    .collect();
144                let oneliner = format!(
145                    "    return {}({})",
146                    struct_name,
147                    field_assignments.join(", ")
148                );
149                if oneliner.len() <= 88 {
150                    let _ = writeln!(out, "{}", oneliner);
151                } else {
152                    let _ = writeln!(out, "    return {}(", struct_name);
153                    for fa in &field_assignments {
154                        let _ = writeln!(out, "        {},", fa);
155                    }
156                    let _ = writeln!(out, "    )");
157                }
158            }
159            QueryCommand::Many | QueryCommand::Batch => {
160                let _ = writeln!(
161                    out,
162                    "async def {}(conn: Connection{}{}) -> list[{}]:",
163                    func_name, kw_sep, param_list, struct_name
164                );
165                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
166                let _ = writeln!(out, "    rows = await conn.fetch(");
167                let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
168                if !params.is_empty() {
169                    let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
170                    let _ = writeln!(out, "        {},", args.join(", "));
171                }
172                let _ = writeln!(out, "    )");
173                let field_assignments: Vec<String> = columns
174                    .iter()
175                    .map(|col| format!("{}=r[\"{}\"]", col.field_name, col.name))
176                    .collect();
177                let oneliner = format!(
178                    "    return [{}({}) for r in rows]",
179                    struct_name,
180                    field_assignments.join(", ")
181                );
182                if oneliner.len() <= 88 {
183                    let _ = writeln!(out, "{}", oneliner);
184                } else {
185                    let _ = writeln!(out, "    return [");
186                    let _ = writeln!(out, "        {}(", struct_name);
187                    for fa in &field_assignments {
188                        let _ = writeln!(out, "            {},", fa);
189                    }
190                    let _ = writeln!(out, "        )");
191                    let _ = writeln!(out, "        for r in rows");
192                    let _ = writeln!(out, "    ]");
193                }
194            }
195            QueryCommand::Exec | QueryCommand::ExecResult | QueryCommand::ExecRows => {
196                let _ = writeln!(
197                    out,
198                    "async def {}(conn: Connection{}{}) -> None:",
199                    func_name, kw_sep, param_list
200                );
201                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
202                let _ = writeln!(out, "    await conn.execute(");
203                let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
204                if !params.is_empty() {
205                    let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
206                    let _ = writeln!(out, "        {},", args.join(", "));
207                }
208                let _ = writeln!(out, "    )");
209            }
210        }
211
212        Ok(out)
213    }
214
215    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
216        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
217        let mut out = String::new();
218        let _ = writeln!(out, "class {}(str, Enum):", type_name);
219        let _ = writeln!(
220            out,
221            "    \"\"\"Database enum type {}.\"\"\"",
222            enum_info.sql_name
223        );
224        if enum_info.values.is_empty() {
225            let _ = writeln!(out, "    pass");
226        } else {
227            let _ = writeln!(out);
228            for value in &enum_info.values {
229                let variant = enum_variant_name(value, &self.manifest.naming);
230                let _ = writeln!(out, "    {} = \"{}\"", variant, value);
231            }
232        }
233        Ok(out)
234    }
235
236    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
237        let name = to_pascal_case(&composite.sql_name);
238        let mut out = String::new();
239        let _ = writeln!(out, "@dataclass");
240        let _ = writeln!(out, "class {}:", name);
241        let _ = writeln!(
242            out,
243            "    \"\"\"Composite type {}.\"\"\"",
244            composite.sql_name
245        );
246        if composite.fields.is_empty() {
247            let _ = writeln!(out, "    pass");
248        } else {
249            let _ = writeln!(out);
250            for field in &composite.fields {
251                let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
252                    .map(|t| t.into_owned())
253                    .map_err(|e| {
254                        ScytheError::new(
255                            ErrorCode::InternalError,
256                            format!("composite field type error: {}", e),
257                        )
258                    })?;
259                let _ = writeln!(out, "    {}: {}", to_snake_case(&field.name), py_type);
260            }
261        }
262        Ok(out)
263    }
264}