Skip to main content

scythe_codegen/backends/
python_asyncpg.rs

1use std::collections::HashMap;
2use std::fmt::Write;
3use std::path::Path;
4
5use scythe_backend::manifest::{BackendManifest, load_manifest};
6use scythe_backend::naming::{
7    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
8};
9use scythe_backend::types::resolve_type;
10
11use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
12use scythe_core::errors::{ErrorCode, ScytheError};
13use scythe_core::parser::QueryCommand;
14
15use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
16use crate::singularize;
17
18use super::python_common::PythonRowType;
19
20const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/python-asyncpg.toml");
21
22pub struct PythonAsyncpgBackend {
23    manifest: BackendManifest,
24    row_type: PythonRowType,
25}
26
27impl PythonAsyncpgBackend {
28    pub fn new(engine: &str) -> Result<Self, ScytheError> {
29        match engine {
30            "postgresql" | "postgres" | "pg" => {}
31            _ => {
32                return Err(ScytheError::new(
33                    ErrorCode::InternalError,
34                    format!(
35                        "python-asyncpg only supports PostgreSQL, got engine '{}'",
36                        engine
37                    ),
38                ));
39            }
40        }
41        let manifest_path = Path::new("backends/python-asyncpg/manifest.toml");
42        let manifest = if manifest_path.exists() {
43            load_manifest(manifest_path)
44                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
45        } else {
46            toml::from_str(DEFAULT_MANIFEST_TOML)
47                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
48        };
49        Ok(Self {
50            manifest,
51            row_type: PythonRowType::default(),
52        })
53    }
54}
55
56impl CodegenBackend for PythonAsyncpgBackend {
57    fn name(&self) -> &str {
58        "python-asyncpg"
59    }
60
61    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
62        &self.manifest
63    }
64
65    fn apply_options(&mut self, options: &HashMap<String, String>) -> Result<(), ScytheError> {
66        if let Some(rt) = options.get("row_type") {
67            self.row_type = PythonRowType::from_option(rt)?;
68        }
69        Ok(())
70    }
71
72    fn file_header(&self) -> String {
73        let import_line = self.row_type.import_line();
74        if self.row_type.is_stdlib_import() {
75            format!(
76                "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
77                 \n\
78                 import datetime  # noqa: F401\n\
79                 import decimal  # noqa: F401\n\
80                 {import_line}\n\
81                 from enum import Enum  # noqa: F401\n\
82                 \n\
83                 from asyncpg import Connection  # noqa: F401\n\
84                 \n",
85            )
86        } else {
87            let third_party = self
88                .row_type
89                .sorted_third_party_imports("from asyncpg import Connection  # noqa: F401");
90            format!(
91                "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
92                 \n\
93                 import datetime  # noqa: F401\n\
94                 import decimal  # noqa: F401\n\
95                 from enum import Enum  # noqa: F401\n\
96                 \n\
97                 {third_party}\n\
98                 \n",
99            )
100        }
101    }
102
103    fn generate_row_struct(
104        &self,
105        query_name: &str,
106        columns: &[ResolvedColumn],
107    ) -> Result<String, ScytheError> {
108        let struct_name = row_struct_name(query_name, &self.manifest.naming);
109        let mut out = String::new();
110        let _ = write!(out, "{}", self.row_type.decorator());
111        let _ = writeln!(out, "{}", self.row_type.class_def(&struct_name));
112        let _ = writeln!(out, "    \"\"\"Row type for {} query.\"\"\"", query_name);
113        if columns.is_empty() {
114            let _ = writeln!(out, "    pass");
115        } else {
116            let _ = writeln!(out);
117            for col in columns {
118                let _ = writeln!(out, "    {}: {}", col.field_name, col.full_type);
119            }
120        }
121        Ok(out)
122    }
123
124    fn generate_model_struct(
125        &self,
126        table_name: &str,
127        columns: &[ResolvedColumn],
128    ) -> Result<String, ScytheError> {
129        let singular = singularize(table_name);
130        let name = to_pascal_case(&singular);
131        self.generate_row_struct(&name, columns)
132    }
133
134    fn generate_query_fn(
135        &self,
136        analyzed: &AnalyzedQuery,
137        struct_name: &str,
138        columns: &[ResolvedColumn],
139        params: &[ResolvedParam],
140    ) -> Result<String, ScytheError> {
141        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
142        let mut out = String::new();
143
144        // Build parameter list (keyword-only after conn)
145        let param_list = params
146            .iter()
147            .map(|p| format!("{}: {}", p.field_name, p.full_type))
148            .collect::<Vec<_>>()
149            .join(", ");
150        let kw_sep = if param_list.is_empty() { "" } else { ", *, " };
151
152        // Clean SQL — asyncpg uses $1, $2 positional params natively
153        let sql = super::clean_sql_with_optional(
154            &analyzed.sql,
155            &analyzed.optional_params,
156            &analyzed.params,
157        );
158
159        match &analyzed.command {
160            QueryCommand::One => {
161                let _ = writeln!(
162                    out,
163                    "async def {}(conn: Connection{}{}) -> {} | None:",
164                    func_name, kw_sep, param_list, struct_name
165                );
166                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
167                let _ = writeln!(out, "    row = await conn.fetchrow(");
168                let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
169                if !params.is_empty() {
170                    let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
171                    let _ = writeln!(out, "        {},", args.join(", "));
172                }
173                let _ = writeln!(out, "    )");
174                let _ = writeln!(out, "    if row is None:");
175                let _ = writeln!(out, "        return None");
176                let field_assignments: Vec<String> = columns
177                    .iter()
178                    .map(|col| format!("{}=row[\"{}\"]", col.field_name, col.name))
179                    .collect();
180                let oneliner = format!(
181                    "    return {}({})",
182                    struct_name,
183                    field_assignments.join(", ")
184                );
185                if oneliner.len() <= 88 {
186                    let _ = writeln!(out, "{}", oneliner);
187                } else {
188                    let _ = writeln!(out, "    return {}(", struct_name);
189                    for fa in &field_assignments {
190                        let _ = writeln!(out, "        {},", fa);
191                    }
192                    let _ = writeln!(out, "    )");
193                }
194            }
195            QueryCommand::Many => {
196                let _ = writeln!(
197                    out,
198                    "async def {}(conn: Connection{}{}) -> list[{}]:",
199                    func_name, kw_sep, param_list, struct_name
200                );
201                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
202                let _ = writeln!(out, "    rows = await conn.fetch(");
203                let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
204                if !params.is_empty() {
205                    let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
206                    let _ = writeln!(out, "        {},", args.join(", "));
207                }
208                let _ = writeln!(out, "    )");
209                let field_assignments: Vec<String> = columns
210                    .iter()
211                    .map(|col| format!("{}=r[\"{}\"]", col.field_name, col.name))
212                    .collect();
213                let oneliner = format!(
214                    "    return [{}({}) for r in rows]",
215                    struct_name,
216                    field_assignments.join(", ")
217                );
218                if oneliner.len() <= 88 {
219                    let _ = writeln!(out, "{}", oneliner);
220                } else {
221                    let _ = writeln!(out, "    return [");
222                    let _ = writeln!(out, "        {}(", struct_name);
223                    for fa in &field_assignments {
224                        let _ = writeln!(out, "            {},", fa);
225                    }
226                    let _ = writeln!(out, "        )");
227                    let _ = writeln!(out, "        for r in rows");
228                    let _ = writeln!(out, "    ]");
229                }
230            }
231            QueryCommand::Batch => {
232                let batch_fn_name = format!("{}_batch", func_name);
233                let items_type = if params.len() > 1 {
234                    let tuple_types: Vec<String> =
235                        params.iter().map(|p| p.full_type.clone()).collect();
236                    format!("list[tuple[{}]]", tuple_types.join(", "))
237                } else if params.len() == 1 {
238                    format!("list[{}]", params[0].full_type)
239                } else {
240                    "int".to_string()
241                };
242                let _ = writeln!(
243                    out,
244                    "async def {}(conn: Connection, *, items: {}) -> None:",
245                    batch_fn_name, items_type
246                );
247                let _ = writeln!(
248                    out,
249                    "    \"\"\"Execute {} query for each item in the batch.\"\"\"",
250                    analyzed.name
251                );
252                if params.is_empty() {
253                    let _ = writeln!(out, "    for _ in range(items):");
254                    let _ = writeln!(out, "        await conn.execute(");
255                    let _ = writeln!(out, "            \"\"\"{}\"\"\",", sql);
256                    let _ = writeln!(out, "        )");
257                } else {
258                    // asyncpg executemany takes list of tuples
259                    if params.len() == 1 {
260                        let _ = writeln!(out, "    args = [(item,) for item in items]");
261                    } else {
262                        let _ = writeln!(out, "    args = items");
263                    }
264                    let _ = writeln!(out, "    await conn.executemany(");
265                    let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
266                    let _ = writeln!(out, "        args,");
267                    let _ = writeln!(out, "    )");
268                }
269            }
270            QueryCommand::Exec => {
271                let _ = writeln!(
272                    out,
273                    "async def {}(conn: Connection{}{}) -> None:",
274                    func_name, kw_sep, param_list
275                );
276                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
277                let _ = writeln!(out, "    await conn.execute(");
278                let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
279                if !params.is_empty() {
280                    let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
281                    let _ = writeln!(out, "        {},", args.join(", "));
282                }
283                let _ = writeln!(out, "    )");
284            }
285            QueryCommand::ExecResult | QueryCommand::ExecRows => {
286                let _ = writeln!(
287                    out,
288                    "async def {}(conn: Connection{}{}) -> int:",
289                    func_name, kw_sep, param_list
290                );
291                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
292                let _ = writeln!(out, "    result = await conn.execute(");
293                let _ = writeln!(out, "        \"\"\"{}\"\"\",", sql);
294                if !params.is_empty() {
295                    let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
296                    let _ = writeln!(out, "        {},", args.join(", "));
297                }
298                let _ = writeln!(out, "    )");
299                let _ = writeln!(out, "    return int(result.split()[-1])");
300            }
301            QueryCommand::Grouped => {
302                unreachable!("Grouped is rewritten to Many before codegen")
303            }
304        }
305
306        Ok(out)
307    }
308
309    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
310        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
311        let mut out = String::new();
312        let _ = writeln!(out, "class {}(str, Enum):", type_name);
313        let _ = writeln!(
314            out,
315            "    \"\"\"Database enum type {}.\"\"\"",
316            enum_info.sql_name
317        );
318        if enum_info.values.is_empty() {
319            let _ = writeln!(out, "    pass");
320        } else {
321            let _ = writeln!(out);
322            for value in &enum_info.values {
323                let variant = enum_variant_name(value, &self.manifest.naming);
324                let _ = writeln!(out, "    {} = \"{}\"", variant, value);
325            }
326        }
327        Ok(out)
328    }
329
330    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
331        let name = to_pascal_case(&composite.sql_name);
332        let mut out = String::new();
333        let _ = write!(out, "{}", self.row_type.decorator());
334        let _ = writeln!(out, "{}", self.row_type.class_def(&name));
335        let _ = writeln!(
336            out,
337            "    \"\"\"Composite type {}.\"\"\"",
338            composite.sql_name
339        );
340        if composite.fields.is_empty() {
341            let _ = writeln!(out, "    pass");
342        } else {
343            let _ = writeln!(out);
344            for field in &composite.fields {
345                let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
346                    .map(|t| t.into_owned())
347                    .map_err(|e| {
348                        ScytheError::new(
349                            ErrorCode::InternalError,
350                            format!("composite field type error: {}", e),
351                        )
352                    })?;
353                let _ = writeln!(out, "    {}: {}", to_snake_case(&field.name), py_type);
354            }
355        }
356        Ok(out)
357    }
358}