Skip to main content

scythe_codegen/backends/
python_asyncpg.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
17const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/python-asyncpg.toml");
18
19pub struct PythonAsyncpgBackend {
20    manifest: BackendManifest,
21}
22
23impl PythonAsyncpgBackend {
24    pub fn new() -> Result<Self, ScytheError> {
25        let manifest_path = Path::new("backends/python-asyncpg/manifest.toml");
26        let manifest = if manifest_path.exists() {
27            load_manifest(manifest_path)
28                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
29        } else {
30            toml::from_str(DEFAULT_MANIFEST_TOML)
31                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
32        };
33        Ok(Self { manifest })
34    }
35
36    pub fn manifest(&self) -> &BackendManifest {
37        &self.manifest
38    }
39}
40
41impl CodegenBackend for PythonAsyncpgBackend {
42    fn name(&self) -> &str {
43        "python-asyncpg"
44    }
45
46    fn file_header(&self) -> String {
47        "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
48         \n\
49         import datetime  # noqa: F401\n\
50         from dataclasses import dataclass\n\
51         from enum import Enum  # noqa: F401\n\
52         \n\
53         from asyncpg import Connection  # noqa: F401\n\
54         \n"
55        .to_string()
56    }
57
58    fn generate_row_struct(
59        &self,
60        query_name: &str,
61        columns: &[ResolvedColumn],
62    ) -> Result<String, ScytheError> {
63        let struct_name = row_struct_name(query_name, &self.manifest.naming);
64        let mut out = String::new();
65        let _ = writeln!(out, "@dataclass");
66        let _ = writeln!(out, "class {}:", struct_name);
67        let _ = writeln!(out, "    \"\"\"Row type for {} query.\"\"\"", query_name);
68        if columns.is_empty() {
69            let _ = writeln!(out, "    pass");
70        } else {
71            let _ = writeln!(out);
72            for col in columns {
73                let _ = writeln!(out, "    {}: {}", col.field_name, col.full_type);
74            }
75        }
76        Ok(out)
77    }
78
79    fn generate_model_struct(
80        &self,
81        table_name: &str,
82        columns: &[ResolvedColumn],
83    ) -> Result<String, ScytheError> {
84        let singular = singularize(table_name);
85        let name = to_pascal_case(&singular);
86        self.generate_row_struct(&name, columns)
87    }
88
89    fn generate_query_fn(
90        &self,
91        analyzed: &AnalyzedQuery,
92        struct_name: &str,
93        columns: &[ResolvedColumn],
94        params: &[ResolvedParam],
95    ) -> Result<String, ScytheError> {
96        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
97        let mut out = String::new();
98
99        // Build parameter list (keyword-only after conn)
100        let param_list = params
101            .iter()
102            .map(|p| format!("{}: {}", p.field_name, p.full_type))
103            .collect::<Vec<_>>()
104            .join(", ");
105        let kw_sep = if param_list.is_empty() { "" } else { ", *, " };
106
107        // Clean SQL — asyncpg uses $1, $2 positional params natively
108        let sql = super::clean_sql(&analyzed.sql);
109
110        match &analyzed.command {
111            QueryCommand::One => {
112                let _ = writeln!(
113                    out,
114                    "async def {}(conn: Connection{}{}) -> {} | None:",
115                    func_name, kw_sep, param_list, struct_name
116                );
117                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
118                let _ = writeln!(out, "    row = await conn.fetchrow(");
119                let _ = writeln!(out, "        \"{}\",", sql);
120                if !params.is_empty() {
121                    let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
122                    let _ = writeln!(out, "        {},", args.join(", "));
123                }
124                let _ = writeln!(out, "    )");
125                let _ = writeln!(out, "    if row is None:");
126                let _ = writeln!(out, "        return None");
127                let field_assignments: Vec<String> = columns
128                    .iter()
129                    .map(|col| format!("{}=row[\"{}\"]", col.field_name, col.name))
130                    .collect();
131                let oneliner = format!(
132                    "    return {}({})",
133                    struct_name,
134                    field_assignments.join(", ")
135                );
136                if oneliner.len() <= 88 {
137                    let _ = writeln!(out, "{}", oneliner);
138                } else {
139                    let _ = writeln!(out, "    return {}(", struct_name);
140                    for fa in &field_assignments {
141                        let _ = writeln!(out, "        {},", fa);
142                    }
143                    let _ = writeln!(out, "    )");
144                }
145            }
146            QueryCommand::Many | QueryCommand::Batch => {
147                let _ = writeln!(
148                    out,
149                    "async def {}(conn: Connection{}{}) -> list[{}]:",
150                    func_name, kw_sep, param_list, struct_name
151                );
152                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
153                let _ = writeln!(out, "    rows = await conn.fetch(");
154                let _ = writeln!(out, "        \"{}\",", sql);
155                if !params.is_empty() {
156                    let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
157                    let _ = writeln!(out, "        {},", args.join(", "));
158                }
159                let _ = writeln!(out, "    )");
160                let field_assignments: Vec<String> = columns
161                    .iter()
162                    .map(|col| format!("{}=r[\"{}\"]", col.field_name, col.name))
163                    .collect();
164                let oneliner = format!(
165                    "    return [{}({}) for r in rows]",
166                    struct_name,
167                    field_assignments.join(", ")
168                );
169                if oneliner.len() <= 88 {
170                    let _ = writeln!(out, "{}", oneliner);
171                } else {
172                    let _ = writeln!(out, "    return [");
173                    let _ = writeln!(out, "        {}(", struct_name);
174                    for fa in &field_assignments {
175                        let _ = writeln!(out, "            {},", fa);
176                    }
177                    let _ = writeln!(out, "        )");
178                    let _ = writeln!(out, "        for r in rows");
179                    let _ = writeln!(out, "    ]");
180                }
181            }
182            QueryCommand::Exec | QueryCommand::ExecResult | QueryCommand::ExecRows => {
183                let _ = writeln!(
184                    out,
185                    "async def {}(conn: Connection{}{}) -> None:",
186                    func_name, kw_sep, param_list
187                );
188                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
189                let _ = writeln!(out, "    await conn.execute(");
190                let _ = writeln!(out, "        \"{}\",", sql);
191                if !params.is_empty() {
192                    let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
193                    let _ = writeln!(out, "        {},", args.join(", "));
194                }
195                let _ = writeln!(out, "    )");
196            }
197        }
198
199        Ok(out)
200    }
201
202    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
203        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
204        let mut out = String::new();
205        let _ = writeln!(out, "class {}(str, Enum):", type_name);
206        let _ = writeln!(
207            out,
208            "    \"\"\"Database enum type {}.\"\"\"",
209            enum_info.sql_name
210        );
211        if enum_info.values.is_empty() {
212            let _ = writeln!(out, "    pass");
213        } else {
214            let _ = writeln!(out);
215            for value in &enum_info.values {
216                let variant = enum_variant_name(value, &self.manifest.naming);
217                let _ = writeln!(out, "    {} = \"{}\"", variant, value);
218            }
219        }
220        Ok(out)
221    }
222
223    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
224        let name = to_pascal_case(&composite.sql_name);
225        let mut out = String::new();
226        let _ = writeln!(out, "@dataclass");
227        let _ = writeln!(out, "class {}:", name);
228        let _ = writeln!(
229            out,
230            "    \"\"\"Composite type {}.\"\"\"",
231            composite.sql_name
232        );
233        if composite.fields.is_empty() {
234            let _ = writeln!(out, "    pass");
235        } else {
236            let _ = writeln!(out);
237            for field in &composite.fields {
238                let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
239                    .map(|t| t.into_owned())
240                    .map_err(|e| {
241                        ScytheError::new(
242                            ErrorCode::InternalError,
243                            format!("composite field type error: {}", e),
244                        )
245                    })?;
246                let _ = writeln!(out, "    {}: {}", to_snake_case(&field.name), py_type);
247            }
248        }
249        Ok(out)
250    }
251}