Skip to main content

scythe_codegen/backends/
python_aiomysql.rs

1use std::collections::HashMap;
2use std::fmt::Write;
3use std::path::Path;
4
5use scythe_backend::manifest::{BackendManifest, load_manifest};
6use scythe_backend::naming::{
7    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
8};
9use scythe_backend::types::resolve_type;
10
11use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
12use scythe_core::errors::{ErrorCode, ScytheError};
13use scythe_core::parser::QueryCommand;
14
15use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
16use crate::singularize;
17
18use super::python_common::PythonRowType;
19
/// Manifest TOML embedded at compile time; `PythonAiomysqlBackend::new` parses
/// it when no manifest file exists at the on-disk override path.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/python-aiomysql.toml");
21
/// Code generation backend named "python-aiomysql"; it emits Python source
/// that talks to the database through `aiomysql` connections and cursors.
pub struct PythonAiomysqlBackend {
    /// Manifest handed to the naming and type-resolution helpers.
    manifest: BackendManifest,
    /// Row class flavor; supplies the decorator, class header, and imports.
    row_type: PythonRowType,
}
26
27impl PythonAiomysqlBackend {
28    pub fn new(engine: &str) -> Result<Self, ScytheError> {
29        match engine {
30            "mysql" | "mariadb" => {}
31            _ => {
32                return Err(ScytheError::new(
33                    ErrorCode::InternalError,
34                    format!(
35                        "python-aiomysql only supports MySQL/MariaDB, got engine '{}'",
36                        engine
37                    ),
38                ));
39            }
40        }
41        let manifest_path = Path::new("backends/python-aiomysql/manifest.toml");
42        let manifest = if manifest_path.exists() {
43            load_manifest(manifest_path)
44                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
45        } else {
46            toml::from_str(DEFAULT_MANIFEST_TOML)
47                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
48        };
49        Ok(Self {
50            manifest,
51            row_type: PythonRowType::default(),
52        })
53    }
54}
55
/// Rewrite `$N` positional placeholders (`$1`, `$2`, ...) to the `%s` markers
/// used by aiomysql.
///
/// The input is scanned once: a `$` immediately followed by a digit `1`-`9`
/// starts a placeholder, and every ASCII digit after it belongs to the same
/// placeholder. Any parameter number is therefore handled, whereas replacing
/// `$99` down to `$1` textually left stray digits behind for `$100` and above.
/// `$0`, a lone `$`, and a `$` followed by a non-digit are copied unchanged.
///
/// The scan does not interpret SQL syntax, so a `$N` sequence inside a string
/// literal or comment is rewritten as well.
fn rewrite_params_to_percent_s(sql: &str) -> String {
    let mut result = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();
    while let Some(c) = chars.next() {
        let starts_placeholder = c == '$' && matches!(chars.peek().copied(), Some('1'..='9'));
        if !starts_placeholder {
            result.push(c);
            continue;
        }
        // Swallow the whole parameter number so `$12` yields one marker.
        while matches!(chars.peek(), Some(d) if d.is_ascii_digit()) {
            chars.next();
        }
        result.push_str("%s");
    }
    result
}
65
66impl CodegenBackend for PythonAiomysqlBackend {
67    fn name(&self) -> &str {
68        "python-aiomysql"
69    }
70
71    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
72        &self.manifest
73    }
74
75    fn supported_engines(&self) -> &[&str] {
76        &["mysql"]
77    }
78
79    fn apply_options(&mut self, options: &HashMap<String, String>) -> Result<(), ScytheError> {
80        if let Some(rt) = options.get("row_type") {
81            self.row_type = PythonRowType::from_option(rt)?;
82        }
83        Ok(())
84    }
85
86    fn file_header(&self) -> String {
87        let import_line = self.row_type.import_line();
88        if self.row_type.is_stdlib_import() {
89            format!(
90                "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
91                 \n\
92                 import datetime  # noqa: F401\n\
93                 import decimal  # noqa: F401\n\
94                 {import_line}\n\
95                 from enum import Enum  # noqa: F401\n\
96                 \n\
97                 import aiomysql  # noqa: F401\n\
98                 \n",
99            )
100        } else {
101            let third_party = self
102                .row_type
103                .sorted_third_party_imports("import aiomysql  # noqa: F401");
104            format!(
105                "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
106                 \n\
107                 import datetime  # noqa: F401\n\
108                 import decimal  # noqa: F401\n\
109                 from enum import Enum  # noqa: F401\n\
110                 \n\
111                 {third_party}\n\
112                 \n",
113            )
114        }
115    }
116
117    fn generate_row_struct(
118        &self,
119        query_name: &str,
120        columns: &[ResolvedColumn],
121    ) -> Result<String, ScytheError> {
122        let struct_name = row_struct_name(query_name, &self.manifest.naming);
123        let mut out = String::new();
124        let _ = write!(out, "{}", self.row_type.decorator());
125        let _ = writeln!(out, "{}", self.row_type.class_def(&struct_name));
126        let _ = writeln!(out, "    \"\"\"Row type for {} query.\"\"\"", query_name);
127        if columns.is_empty() {
128            let _ = writeln!(out, "    pass");
129        } else {
130            let _ = writeln!(out);
131            for col in columns {
132                let _ = writeln!(out, "    {}: {}", col.field_name, col.full_type);
133            }
134        }
135        Ok(out)
136    }
137
138    fn generate_model_struct(
139        &self,
140        table_name: &str,
141        columns: &[ResolvedColumn],
142    ) -> Result<String, ScytheError> {
143        let singular = singularize(table_name);
144        let name = to_pascal_case(&singular);
145        self.generate_row_struct(&name, columns)
146    }
147
148    fn generate_query_fn(
149        &self,
150        analyzed: &AnalyzedQuery,
151        struct_name: &str,
152        columns: &[ResolvedColumn],
153        params: &[ResolvedParam],
154    ) -> Result<String, ScytheError> {
155        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
156        let mut out = String::new();
157
158        let param_list = params
159            .iter()
160            .map(|p| format!("{}: {}", p.field_name, p.full_type))
161            .collect::<Vec<_>>()
162            .join(", ");
163        let kw_sep = if param_list.is_empty() { "" } else { ", *, " };
164
165        let sql = rewrite_params_to_percent_s(&super::clean_sql_with_optional(
166            &analyzed.sql,
167            &analyzed.optional_params,
168            &analyzed.params,
169        ));
170
171        let args_tuple = if params.is_empty() {
172            String::new()
173        } else {
174            let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
175            if args.len() == 1 {
176                format!("({},)", args[0])
177            } else {
178                format!("({})", args.join(", "))
179            }
180        };
181
182        match &analyzed.command {
183            QueryCommand::One => {
184                let _ = writeln!(
185                    out,
186                    "async def {}(conn: aiomysql.Connection{}{}) -> {} | None:",
187                    func_name, kw_sep, param_list, struct_name
188                );
189                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
190                let _ = writeln!(out, "    async with conn.cursor() as cur:");
191                if params.is_empty() {
192                    let _ = writeln!(out, "        await cur.execute(\"\"\"{}\"\"\")", sql);
193                } else {
194                    let _ = writeln!(
195                        out,
196                        "        await cur.execute(\"\"\"{}\"\"\", {})",
197                        sql, args_tuple
198                    );
199                }
200                let _ = writeln!(out, "        row = await cur.fetchone()");
201                let _ = writeln!(out, "    if row is None:");
202                let _ = writeln!(out, "        return None");
203                let field_assignments: Vec<String> = columns
204                    .iter()
205                    .enumerate()
206                    .map(|(i, col)| format!("{}=row[{}]", col.field_name, i))
207                    .collect();
208                let oneliner = format!(
209                    "    return {}({})",
210                    struct_name,
211                    field_assignments.join(", ")
212                );
213                if oneliner.len() <= 88 {
214                    let _ = writeln!(out, "{}", oneliner);
215                } else {
216                    let _ = writeln!(out, "    return {}(", struct_name);
217                    for fa in &field_assignments {
218                        let _ = writeln!(out, "        {},", fa);
219                    }
220                    let _ = writeln!(out, "    )");
221                }
222            }
223            QueryCommand::Batch => {
224                let batch_fn_name = format!("{}_batch", func_name);
225                let items_type = if params.len() > 1 {
226                    let tuple_types: Vec<String> =
227                        params.iter().map(|p| p.full_type.clone()).collect();
228                    format!("list[tuple[{}]]", tuple_types.join(", "))
229                } else if params.len() == 1 {
230                    format!("list[{}]", params[0].full_type)
231                } else {
232                    "int".to_string()
233                };
234                let _ = writeln!(
235                    out,
236                    "async def {}(conn: aiomysql.Connection, *, items: {}) -> None:",
237                    batch_fn_name, items_type
238                );
239                let _ = writeln!(
240                    out,
241                    "    \"\"\"Execute {} query for each item in the batch.\"\"\"",
242                    analyzed.name
243                );
244                let _ = writeln!(out, "    async with conn.cursor() as cur:");
245                if params.is_empty() {
246                    let _ = writeln!(out, "        for _ in range(items):");
247                    let _ = writeln!(out, "            await cur.execute(\"\"\"{}\"\"\")", sql);
248                } else if params.len() == 1 {
249                    let _ = writeln!(
250                        out,
251                        "        await cur.executemany(\"\"\"{}\"\"\", [(item,) for item in items])",
252                        sql
253                    );
254                } else {
255                    let _ = writeln!(
256                        out,
257                        "        await cur.executemany(\"\"\"{}\"\"\", items)",
258                        sql
259                    );
260                }
261                let _ = writeln!(out, "        await conn.commit()");
262            }
263            QueryCommand::Many => {
264                let _ = writeln!(
265                    out,
266                    "async def {}(conn: aiomysql.Connection{}{}) -> list[{}]:",
267                    func_name, kw_sep, param_list, struct_name
268                );
269                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
270                let _ = writeln!(out, "    async with conn.cursor() as cur:");
271                if params.is_empty() {
272                    let _ = writeln!(out, "        await cur.execute(\"\"\"{}\"\"\")", sql);
273                } else {
274                    let _ = writeln!(
275                        out,
276                        "        await cur.execute(\"\"\"{}\"\"\", {})",
277                        sql, args_tuple
278                    );
279                }
280                let _ = writeln!(out, "        rows = await cur.fetchall()");
281                let field_assignments: Vec<String> = columns
282                    .iter()
283                    .enumerate()
284                    .map(|(i, col)| format!("{}=r[{}]", col.field_name, i))
285                    .collect();
286                let oneliner = format!(
287                    "    return [{}({}) for r in rows]",
288                    struct_name,
289                    field_assignments.join(", ")
290                );
291                if oneliner.len() <= 88 {
292                    let _ = writeln!(out, "{}", oneliner);
293                } else {
294                    let _ = writeln!(out, "    return [");
295                    let _ = writeln!(out, "        {}(", struct_name);
296                    for fa in &field_assignments {
297                        let _ = writeln!(out, "            {},", fa);
298                    }
299                    let _ = writeln!(out, "        )");
300                    let _ = writeln!(out, "        for r in rows");
301                    let _ = writeln!(out, "    ]");
302                }
303            }
304            QueryCommand::Exec => {
305                let _ = writeln!(
306                    out,
307                    "async def {}(conn: aiomysql.Connection{}{}) -> None:",
308                    func_name, kw_sep, param_list
309                );
310                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
311                let _ = writeln!(out, "    async with conn.cursor() as cur:");
312                if params.is_empty() {
313                    let _ = writeln!(out, "        await cur.execute(\"\"\"{}\"\"\")", sql);
314                } else {
315                    let _ = writeln!(
316                        out,
317                        "        await cur.execute(\"\"\"{}\"\"\", {})",
318                        sql, args_tuple
319                    );
320                }
321            }
322            QueryCommand::ExecResult | QueryCommand::ExecRows => {
323                let _ = writeln!(
324                    out,
325                    "async def {}(conn: aiomysql.Connection{}{}) -> int:",
326                    func_name, kw_sep, param_list
327                );
328                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
329                let _ = writeln!(out, "    async with conn.cursor() as cur:");
330                if params.is_empty() {
331                    let _ = writeln!(out, "        await cur.execute(\"\"\"{}\"\"\")", sql);
332                } else {
333                    let _ = writeln!(
334                        out,
335                        "        await cur.execute(\"\"\"{}\"\"\", {})",
336                        sql, args_tuple
337                    );
338                }
339                let _ = writeln!(out, "        return cur.rowcount");
340            }
341            QueryCommand::Grouped => {
342                unreachable!("Grouped is rewritten to Many before codegen")
343            }
344        }
345
346        Ok(out)
347    }
348
349    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
350        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
351        let mut out = String::new();
352        let _ = writeln!(out, "class {}(str, Enum):", type_name);
353        let _ = writeln!(
354            out,
355            "    \"\"\"Database enum type {}.\"\"\"",
356            enum_info.sql_name
357        );
358        if enum_info.values.is_empty() {
359            let _ = writeln!(out, "    pass");
360        } else {
361            let _ = writeln!(out);
362            for value in &enum_info.values {
363                let variant = enum_variant_name(value, &self.manifest.naming);
364                let _ = writeln!(out, "    {} = \"{}\"", variant, value);
365            }
366        }
367        Ok(out)
368    }
369
370    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
371        let name = to_pascal_case(&composite.sql_name);
372        let mut out = String::new();
373        let _ = write!(out, "{}", self.row_type.decorator());
374        let _ = writeln!(out, "{}", self.row_type.class_def(&name));
375        let _ = writeln!(
376            out,
377            "    \"\"\"Composite type {}.\"\"\"",
378            composite.sql_name
379        );
380        if composite.fields.is_empty() {
381            let _ = writeln!(out, "    pass");
382        } else {
383            let _ = writeln!(out);
384            for field in &composite.fields {
385                let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
386                    .map(|t| t.into_owned())
387                    .map_err(|e| {
388                        ScytheError::new(
389                            ErrorCode::InternalError,
390                            format!("composite field type error: {}", e),
391                        )
392                    })?;
393                let _ = writeln!(out, "    {}: {}", to_snake_case(&field.name), py_type);
394            }
395        }
396        Ok(out)
397    }
398}