Skip to main content

scythe_codegen/backends/
python_asyncpg.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
17const DEFAULT_MANIFEST_TOML: &str =
18    include_str!("../../manifests/python-asyncpg.toml");
19
20pub struct PythonAsyncpgBackend {
21    manifest: BackendManifest,
22}
23
24impl PythonAsyncpgBackend {
25    pub fn new() -> Result<Self, ScytheError> {
26        let manifest_path = Path::new("backends/python-asyncpg/manifest.toml");
27        let manifest = if manifest_path.exists() {
28            load_manifest(manifest_path)
29                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
30        } else {
31            toml::from_str(DEFAULT_MANIFEST_TOML)
32                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
33        };
34        Ok(Self { manifest })
35    }
36
37    pub fn manifest(&self) -> &BackendManifest {
38        &self.manifest
39    }
40}
41
42impl CodegenBackend for PythonAsyncpgBackend {
43    fn name(&self) -> &str {
44        "python-asyncpg"
45    }
46
47    fn file_header(&self) -> String {
48        "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
49         \n\
50         import datetime  # noqa: F401\n\
51         from dataclasses import dataclass\n\
52         from enum import Enum  # noqa: F401\n\
53         \n\
54         from asyncpg import Connection  # noqa: F401\n\
55         \n"
56        .to_string()
57    }
58
59    fn generate_row_struct(
60        &self,
61        query_name: &str,
62        columns: &[ResolvedColumn],
63    ) -> Result<String, ScytheError> {
64        let struct_name = row_struct_name(query_name, &self.manifest.naming);
65        let mut out = String::new();
66        let _ = writeln!(out, "@dataclass");
67        let _ = writeln!(out, "class {}:", struct_name);
68        let _ = writeln!(out, "    \"\"\"Row type for {} query.\"\"\"", query_name);
69        if columns.is_empty() {
70            let _ = writeln!(out, "    pass");
71        } else {
72            let _ = writeln!(out);
73            for col in columns {
74                let _ = writeln!(out, "    {}: {}", col.field_name, col.full_type);
75            }
76        }
77        Ok(out)
78    }
79
80    fn generate_model_struct(
81        &self,
82        table_name: &str,
83        columns: &[ResolvedColumn],
84    ) -> Result<String, ScytheError> {
85        let singular = singularize(table_name);
86        let name = to_pascal_case(&singular);
87        self.generate_row_struct(&name, columns)
88    }
89
90    fn generate_query_fn(
91        &self,
92        analyzed: &AnalyzedQuery,
93        struct_name: &str,
94        columns: &[ResolvedColumn],
95        params: &[ResolvedParam],
96    ) -> Result<String, ScytheError> {
97        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
98        let mut out = String::new();
99
100        // Build parameter list (keyword-only after conn)
101        let param_list = params
102            .iter()
103            .map(|p| format!("{}: {}", p.field_name, p.full_type))
104            .collect::<Vec<_>>()
105            .join(", ");
106        let kw_sep = if param_list.is_empty() { "" } else { ", *, " };
107
108        // Clean SQL — asyncpg uses $1, $2 positional params natively
109        let sql = super::clean_sql(&analyzed.sql);
110
111        match &analyzed.command {
112            QueryCommand::One => {
113                let _ = writeln!(
114                    out,
115                    "async def {}(conn: Connection{}{}) -> {} | None:",
116                    func_name, kw_sep, param_list, struct_name
117                );
118                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
119                let _ = writeln!(out, "    row = await conn.fetchrow(");
120                let _ = writeln!(out, "        \"{}\",", sql);
121                if !params.is_empty() {
122                    let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
123                    let _ = writeln!(out, "        {},", args.join(", "));
124                }
125                let _ = writeln!(out, "    )");
126                let _ = writeln!(out, "    if row is None:");
127                let _ = writeln!(out, "        return None");
128                let field_assignments: Vec<String> = columns
129                    .iter()
130                    .map(|col| format!("{}=row[\"{}\"]", col.field_name, col.name))
131                    .collect();
132                let oneliner = format!(
133                    "    return {}({})",
134                    struct_name,
135                    field_assignments.join(", ")
136                );
137                if oneliner.len() <= 88 {
138                    let _ = writeln!(out, "{}", oneliner);
139                } else {
140                    let _ = writeln!(out, "    return {}(", struct_name);
141                    for fa in &field_assignments {
142                        let _ = writeln!(out, "        {},", fa);
143                    }
144                    let _ = writeln!(out, "    )");
145                }
146            }
147            QueryCommand::Many | QueryCommand::Batch => {
148                let _ = writeln!(
149                    out,
150                    "async def {}(conn: Connection{}{}) -> list[{}]:",
151                    func_name, kw_sep, param_list, struct_name
152                );
153                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
154                let _ = writeln!(out, "    rows = await conn.fetch(");
155                let _ = writeln!(out, "        \"{}\",", sql);
156                if !params.is_empty() {
157                    let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
158                    let _ = writeln!(out, "        {},", args.join(", "));
159                }
160                let _ = writeln!(out, "    )");
161                let field_assignments: Vec<String> = columns
162                    .iter()
163                    .map(|col| format!("{}=r[\"{}\"]", col.field_name, col.name))
164                    .collect();
165                let oneliner = format!(
166                    "    return [{}({}) for r in rows]",
167                    struct_name,
168                    field_assignments.join(", ")
169                );
170                if oneliner.len() <= 88 {
171                    let _ = writeln!(out, "{}", oneliner);
172                } else {
173                    let _ = writeln!(out, "    return [");
174                    let _ = writeln!(out, "        {}(", struct_name);
175                    for fa in &field_assignments {
176                        let _ = writeln!(out, "            {},", fa);
177                    }
178                    let _ = writeln!(out, "        )");
179                    let _ = writeln!(out, "        for r in rows");
180                    let _ = writeln!(out, "    ]");
181                }
182            }
183            QueryCommand::Exec | QueryCommand::ExecResult | QueryCommand::ExecRows => {
184                let _ = writeln!(
185                    out,
186                    "async def {}(conn: Connection{}{}) -> None:",
187                    func_name, kw_sep, param_list
188                );
189                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
190                let _ = writeln!(out, "    await conn.execute(");
191                let _ = writeln!(out, "        \"{}\",", sql);
192                if !params.is_empty() {
193                    let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
194                    let _ = writeln!(out, "        {},", args.join(", "));
195                }
196                let _ = writeln!(out, "    )");
197            }
198        }
199
200        Ok(out)
201    }
202
203    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
204        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
205        let mut out = String::new();
206        let _ = writeln!(out, "class {}(str, Enum):", type_name);
207        let _ = writeln!(
208            out,
209            "    \"\"\"Database enum type {}.\"\"\"",
210            enum_info.sql_name
211        );
212        if enum_info.values.is_empty() {
213            let _ = writeln!(out, "    pass");
214        } else {
215            let _ = writeln!(out);
216            for value in &enum_info.values {
217                let variant = enum_variant_name(value, &self.manifest.naming);
218                let _ = writeln!(out, "    {} = \"{}\"", variant, value);
219            }
220        }
221        Ok(out)
222    }
223
224    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
225        let name = to_pascal_case(&composite.sql_name);
226        let mut out = String::new();
227        let _ = writeln!(out, "@dataclass");
228        let _ = writeln!(out, "class {}:", name);
229        let _ = writeln!(
230            out,
231            "    \"\"\"Composite type {}.\"\"\"",
232            composite.sql_name
233        );
234        if composite.fields.is_empty() {
235            let _ = writeln!(out, "    pass");
236        } else {
237            let _ = writeln!(out);
238            for field in &composite.fields {
239                let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
240                    .map(|t| t.into_owned())
241                    .map_err(|e| {
242                        ScytheError::new(
243                            ErrorCode::InternalError,
244                            format!("composite field type error: {}", e),
245                        )
246                    })?;
247                let _ = writeln!(out, "    {}: {}", to_snake_case(&field.name), py_type);
248            }
249        }
250        Ok(out)
251    }
252}