Skip to main content

scythe_codegen/backends/
python_aiomysql.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
17const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/python-aiomysql.toml");
18
19pub struct PythonAiomysqlBackend {
20    manifest: BackendManifest,
21}
22
23impl PythonAiomysqlBackend {
24    pub fn new(engine: &str) -> Result<Self, ScytheError> {
25        match engine {
26            "mysql" | "mariadb" => {}
27            _ => {
28                return Err(ScytheError::new(
29                    ErrorCode::InternalError,
30                    format!(
31                        "python-aiomysql only supports MySQL/MariaDB, got engine '{}'",
32                        engine
33                    ),
34                ));
35            }
36        }
37        let manifest_path = Path::new("backends/python-aiomysql/manifest.toml");
38        let manifest = if manifest_path.exists() {
39            load_manifest(manifest_path)
40                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
41        } else {
42            toml::from_str(DEFAULT_MANIFEST_TOML)
43                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
44        };
45        Ok(Self { manifest })
46    }
47}
48
/// Rewrites `$1`, `$2`, ... positional placeholders into aiomysql's `%s` paramstyle.
///
/// A placeholder is `$` followed by a run of ASCII digits whose first digit is `1`-`9`.
/// The whole digit run is consumed, so `$10` and `$123` each become a single `%s`. A fixed
/// `$1..=$99` replace loop would instead turn `$123` into `%s3`.
///
/// When the SQL contains at least one placeholder, the generated code passes an argument
/// tuple to `cursor.execute`. The driver then applies Python `%`-formatting to the whole
/// statement, so literal `%` characters (e.g. in `LIKE 'abc%'`) are doubled to `%%`.
/// Placeholder-free SQL is executed without arguments, so its `%` characters are left as-is.
fn rewrite_params_to_percent_s(sql: &str) -> String {
    let has_placeholder = sql
        .as_bytes()
        .windows(2)
        .any(|pair| pair[0] == b'$' && matches!(pair[1], b'1'..=b'9'));

    let mut out = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();
    while let Some(c) = chars.next() {
        match c {
            '$' if matches!(chars.peek().copied(), Some('1'..='9')) => {
                // Swallow the entire placeholder number, however many digits it has.
                while chars.next_if(char::is_ascii_digit).is_some() {}
                out.push_str("%s");
            }
            '%' if has_placeholder => out.push_str("%%"),
            _ => out.push(c),
        }
    }
    out
}
58
59impl CodegenBackend for PythonAiomysqlBackend {
60    fn name(&self) -> &str {
61        "python-aiomysql"
62    }
63
64    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
65        &self.manifest
66    }
67
68    fn supported_engines(&self) -> &[&str] {
69        &["mysql"]
70    }
71
72    fn file_header(&self) -> String {
73        "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
74         \n\
75         import datetime  # noqa: F401\n\
76         import decimal  # noqa: F401\n\
77         from dataclasses import dataclass\n\
78         from enum import Enum  # noqa: F401\n\
79         \n\
80         import aiomysql  # noqa: F401\n\
81         \n"
82        .to_string()
83    }
84
85    fn generate_row_struct(
86        &self,
87        query_name: &str,
88        columns: &[ResolvedColumn],
89    ) -> Result<String, ScytheError> {
90        let struct_name = row_struct_name(query_name, &self.manifest.naming);
91        let mut out = String::new();
92        let _ = writeln!(out, "@dataclass");
93        let _ = writeln!(out, "class {}:", struct_name);
94        let _ = writeln!(out, "    \"\"\"Row type for {} query.\"\"\"", query_name);
95        if columns.is_empty() {
96            let _ = writeln!(out, "    pass");
97        } else {
98            let _ = writeln!(out);
99            for col in columns {
100                let _ = writeln!(out, "    {}: {}", col.field_name, col.full_type);
101            }
102        }
103        Ok(out)
104    }
105
106    fn generate_model_struct(
107        &self,
108        table_name: &str,
109        columns: &[ResolvedColumn],
110    ) -> Result<String, ScytheError> {
111        let singular = singularize(table_name);
112        let name = to_pascal_case(&singular);
113        self.generate_row_struct(&name, columns)
114    }
115
116    fn generate_query_fn(
117        &self,
118        analyzed: &AnalyzedQuery,
119        struct_name: &str,
120        columns: &[ResolvedColumn],
121        params: &[ResolvedParam],
122    ) -> Result<String, ScytheError> {
123        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
124        let mut out = String::new();
125
126        let param_list = params
127            .iter()
128            .map(|p| format!("{}: {}", p.field_name, p.full_type))
129            .collect::<Vec<_>>()
130            .join(", ");
131        let kw_sep = if param_list.is_empty() { "" } else { ", *, " };
132
133        let sql = rewrite_params_to_percent_s(&super::clean_sql(&analyzed.sql));
134
135        let args_tuple = if params.is_empty() {
136            String::new()
137        } else {
138            let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
139            if args.len() == 1 {
140                format!("({},)", args[0])
141            } else {
142                format!("({})", args.join(", "))
143            }
144        };
145
146        match &analyzed.command {
147            QueryCommand::One => {
148                let _ = writeln!(
149                    out,
150                    "async def {}(conn: aiomysql.Connection{}{}) -> {} | None:",
151                    func_name, kw_sep, param_list, struct_name
152                );
153                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
154                let _ = writeln!(out, "    async with conn.cursor() as cur:");
155                if params.is_empty() {
156                    let _ = writeln!(out, "        await cur.execute(\"\"\"{}\"\"\")", sql);
157                } else {
158                    let _ = writeln!(
159                        out,
160                        "        await cur.execute(\"\"\"{}\"\"\", {})",
161                        sql, args_tuple
162                    );
163                }
164                let _ = writeln!(out, "        row = await cur.fetchone()");
165                let _ = writeln!(out, "    if row is None:");
166                let _ = writeln!(out, "        return None");
167                let field_assignments: Vec<String> = columns
168                    .iter()
169                    .enumerate()
170                    .map(|(i, col)| format!("{}=row[{}]", col.field_name, i))
171                    .collect();
172                let oneliner = format!(
173                    "    return {}({})",
174                    struct_name,
175                    field_assignments.join(", ")
176                );
177                if oneliner.len() <= 88 {
178                    let _ = writeln!(out, "{}", oneliner);
179                } else {
180                    let _ = writeln!(out, "    return {}(", struct_name);
181                    for fa in &field_assignments {
182                        let _ = writeln!(out, "        {},", fa);
183                    }
184                    let _ = writeln!(out, "    )");
185                }
186            }
187            QueryCommand::Many | QueryCommand::Batch => {
188                let _ = writeln!(
189                    out,
190                    "async def {}(conn: aiomysql.Connection{}{}) -> list[{}]:",
191                    func_name, kw_sep, param_list, struct_name
192                );
193                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
194                let _ = writeln!(out, "    async with conn.cursor() as cur:");
195                if params.is_empty() {
196                    let _ = writeln!(out, "        await cur.execute(\"\"\"{}\"\"\")", sql);
197                } else {
198                    let _ = writeln!(
199                        out,
200                        "        await cur.execute(\"\"\"{}\"\"\", {})",
201                        sql, args_tuple
202                    );
203                }
204                let _ = writeln!(out, "        rows = await cur.fetchall()");
205                let field_assignments: Vec<String> = columns
206                    .iter()
207                    .enumerate()
208                    .map(|(i, col)| format!("{}=r[{}]", col.field_name, i))
209                    .collect();
210                let oneliner = format!(
211                    "    return [{}({}) for r in rows]",
212                    struct_name,
213                    field_assignments.join(", ")
214                );
215                if oneliner.len() <= 88 {
216                    let _ = writeln!(out, "{}", oneliner);
217                } else {
218                    let _ = writeln!(out, "    return [");
219                    let _ = writeln!(out, "        {}(", struct_name);
220                    for fa in &field_assignments {
221                        let _ = writeln!(out, "            {},", fa);
222                    }
223                    let _ = writeln!(out, "        )");
224                    let _ = writeln!(out, "        for r in rows");
225                    let _ = writeln!(out, "    ]");
226                }
227            }
228            QueryCommand::Exec | QueryCommand::ExecResult | QueryCommand::ExecRows => {
229                let _ = writeln!(
230                    out,
231                    "async def {}(conn: aiomysql.Connection{}{}) -> None:",
232                    func_name, kw_sep, param_list
233                );
234                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
235                let _ = writeln!(out, "    async with conn.cursor() as cur:");
236                if params.is_empty() {
237                    let _ = writeln!(out, "        await cur.execute(\"\"\"{}\"\"\")", sql);
238                } else {
239                    let _ = writeln!(
240                        out,
241                        "        await cur.execute(\"\"\"{}\"\"\", {})",
242                        sql, args_tuple
243                    );
244                }
245            }
246        }
247
248        Ok(out)
249    }
250
251    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
252        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
253        let mut out = String::new();
254        let _ = writeln!(out, "class {}(str, Enum):", type_name);
255        let _ = writeln!(
256            out,
257            "    \"\"\"Database enum type {}.\"\"\"",
258            enum_info.sql_name
259        );
260        if enum_info.values.is_empty() {
261            let _ = writeln!(out, "    pass");
262        } else {
263            let _ = writeln!(out);
264            for value in &enum_info.values {
265                let variant = enum_variant_name(value, &self.manifest.naming);
266                let _ = writeln!(out, "    {} = \"{}\"", variant, value);
267            }
268        }
269        Ok(out)
270    }
271
272    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
273        let name = to_pascal_case(&composite.sql_name);
274        let mut out = String::new();
275        let _ = writeln!(out, "@dataclass");
276        let _ = writeln!(out, "class {}:", name);
277        let _ = writeln!(
278            out,
279            "    \"\"\"Composite type {}.\"\"\"",
280            composite.sql_name
281        );
282        if composite.fields.is_empty() {
283            let _ = writeln!(out, "    pass");
284        } else {
285            let _ = writeln!(out);
286            for field in &composite.fields {
287                let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
288                    .map(|t| t.into_owned())
289                    .map_err(|e| {
290                        ScytheError::new(
291                            ErrorCode::InternalError,
292                            format!("composite field type error: {}", e),
293                        )
294                    })?;
295                let _ = writeln!(out, "    {}: {}", to_snake_case(&field.name), py_type);
296            }
297        }
298        Ok(out)
299    }
300}