1use std::collections::HashMap;
2use std::fmt::Write;
3use std::path::Path;
4
5use scythe_backend::manifest::{BackendManifest, load_manifest};
6use scythe_backend::naming::{
7 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
8};
9use scythe_backend::types::resolve_type;
10
11use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
12use scythe_core::errors::{ErrorCode, ScytheError};
13use scythe_core::parser::QueryCommand;
14
15use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
16use crate::singularize;
17
18use super::python_common::PythonRowType;
19
20const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/python-aiomysql.toml");
21
22pub struct PythonAiomysqlBackend {
23 manifest: BackendManifest,
24 row_type: PythonRowType,
25}
26
27impl PythonAiomysqlBackend {
28 pub fn new(engine: &str) -> Result<Self, ScytheError> {
29 match engine {
30 "mysql" | "mariadb" => {}
31 _ => {
32 return Err(ScytheError::new(
33 ErrorCode::InternalError,
34 format!(
35 "python-aiomysql only supports MySQL/MariaDB, got engine '{}'",
36 engine
37 ),
38 ));
39 }
40 }
41 let manifest_path = Path::new("backends/python-aiomysql/manifest.toml");
42 let manifest = if manifest_path.exists() {
43 load_manifest(manifest_path)
44 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
45 } else {
46 toml::from_str(DEFAULT_MANIFEST_TOML)
47 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
48 };
49 Ok(Self {
50 manifest,
51 row_type: PythonRowType::default(),
52 })
53 }
54}
55
/// Rewrites positional `$N` placeholders into the `%s` placeholders used by
/// aiomysql.
///
/// A placeholder is a `$` immediately followed by a decimal number whose first
/// digit is non-zero (`$1`, `$42`, `$100`, ...). The whole number is consumed,
/// so placeholders with any number of digits are handled. Every other
/// character, including a lone `$` or `$0`, is copied through unchanged.
///
/// The scan is purely textual: a `$N` sequence inside an SQL string literal or
/// comment is rewritten as well.
fn rewrite_params_to_percent_s(sql: &str) -> String {
    let mut result = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();

    while let Some(current) = chars.next() {
        let starts_placeholder =
            current == '$' && matches!(chars.peek().copied(), Some('1'..='9'));
        if !starts_placeholder {
            result.push(current);
            continue;
        }
        // Drop every digit of the placeholder index, then emit one `%s`.
        while chars.next_if(char::is_ascii_digit).is_some() {}
        result.push_str("%s");
    }

    result
}
65
66impl CodegenBackend for PythonAiomysqlBackend {
67 fn name(&self) -> &str {
68 "python-aiomysql"
69 }
70
71 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
72 &self.manifest
73 }
74
75 fn supported_engines(&self) -> &[&str] {
76 &["mysql"]
77 }
78
79 fn apply_options(&mut self, options: &HashMap<String, String>) -> Result<(), ScytheError> {
80 if let Some(rt) = options.get("row_type") {
81 self.row_type = PythonRowType::from_option(rt)?;
82 }
83 Ok(())
84 }
85
86 fn file_header(&self) -> String {
87 let import_line = self.row_type.import_line();
88 if self.row_type.is_stdlib_import() {
89 format!(
90 "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
91 \n\
92 import datetime # noqa: F401\n\
93 import decimal # noqa: F401\n\
94 {import_line}\n\
95 from enum import Enum # noqa: F401\n\
96 \n\
97 import aiomysql # noqa: F401\n\
98 \n",
99 )
100 } else {
101 let third_party = self
102 .row_type
103 .sorted_third_party_imports("import aiomysql # noqa: F401");
104 format!(
105 "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
106 \n\
107 import datetime # noqa: F401\n\
108 import decimal # noqa: F401\n\
109 from enum import Enum # noqa: F401\n\
110 \n\
111 {third_party}\n\
112 \n",
113 )
114 }
115 }
116
117 fn generate_row_struct(
118 &self,
119 query_name: &str,
120 columns: &[ResolvedColumn],
121 ) -> Result<String, ScytheError> {
122 let struct_name = row_struct_name(query_name, &self.manifest.naming);
123 let mut out = String::new();
124 let _ = write!(out, "{}", self.row_type.decorator());
125 let _ = writeln!(out, "{}", self.row_type.class_def(&struct_name));
126 let _ = writeln!(out, " \"\"\"Row type for {} query.\"\"\"", query_name);
127 if columns.is_empty() {
128 let _ = writeln!(out, " pass");
129 } else {
130 let _ = writeln!(out);
131 for col in columns {
132 let _ = writeln!(out, " {}: {}", col.field_name, col.full_type);
133 }
134 }
135 Ok(out)
136 }
137
138 fn generate_model_struct(
139 &self,
140 table_name: &str,
141 columns: &[ResolvedColumn],
142 ) -> Result<String, ScytheError> {
143 let singular = singularize(table_name);
144 let name = to_pascal_case(&singular);
145 self.generate_row_struct(&name, columns)
146 }
147
148 fn generate_query_fn(
149 &self,
150 analyzed: &AnalyzedQuery,
151 struct_name: &str,
152 columns: &[ResolvedColumn],
153 params: &[ResolvedParam],
154 ) -> Result<String, ScytheError> {
155 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
156 let mut out = String::new();
157
158 let param_list = params
159 .iter()
160 .map(|p| format!("{}: {}", p.field_name, p.full_type))
161 .collect::<Vec<_>>()
162 .join(", ");
163 let kw_sep = if param_list.is_empty() { "" } else { ", *, " };
164
165 let sql = rewrite_params_to_percent_s(&super::clean_sql_with_optional(
166 &analyzed.sql,
167 &analyzed.optional_params,
168 &analyzed.params,
169 ));
170
171 let args_tuple = if params.is_empty() {
172 String::new()
173 } else {
174 let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
175 if args.len() == 1 {
176 format!("({},)", args[0])
177 } else {
178 format!("({})", args.join(", "))
179 }
180 };
181
182 match &analyzed.command {
183 QueryCommand::One => {
184 let _ = writeln!(
185 out,
186 "async def {}(conn: aiomysql.Connection{}{}) -> {} | None:",
187 func_name, kw_sep, param_list, struct_name
188 );
189 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
190 let _ = writeln!(out, " async with conn.cursor() as cur:");
191 if params.is_empty() {
192 let _ = writeln!(out, " await cur.execute(\"\"\"{}\"\"\")", sql);
193 } else {
194 let _ = writeln!(
195 out,
196 " await cur.execute(\"\"\"{}\"\"\", {})",
197 sql, args_tuple
198 );
199 }
200 let _ = writeln!(out, " row = await cur.fetchone()");
201 let _ = writeln!(out, " if row is None:");
202 let _ = writeln!(out, " return None");
203 let field_assignments: Vec<String> = columns
204 .iter()
205 .enumerate()
206 .map(|(i, col)| format!("{}=row[{}]", col.field_name, i))
207 .collect();
208 let oneliner = format!(
209 " return {}({})",
210 struct_name,
211 field_assignments.join(", ")
212 );
213 if oneliner.len() <= 88 {
214 let _ = writeln!(out, "{}", oneliner);
215 } else {
216 let _ = writeln!(out, " return {}(", struct_name);
217 for fa in &field_assignments {
218 let _ = writeln!(out, " {},", fa);
219 }
220 let _ = writeln!(out, " )");
221 }
222 }
223 QueryCommand::Batch => {
224 let batch_fn_name = format!("{}_batch", func_name);
225 let items_type = if params.len() > 1 {
226 let tuple_types: Vec<String> =
227 params.iter().map(|p| p.full_type.clone()).collect();
228 format!("list[tuple[{}]]", tuple_types.join(", "))
229 } else if params.len() == 1 {
230 format!("list[{}]", params[0].full_type)
231 } else {
232 "int".to_string()
233 };
234 let _ = writeln!(
235 out,
236 "async def {}(conn: aiomysql.Connection, *, items: {}) -> None:",
237 batch_fn_name, items_type
238 );
239 let _ = writeln!(
240 out,
241 " \"\"\"Execute {} query for each item in the batch.\"\"\"",
242 analyzed.name
243 );
244 let _ = writeln!(out, " async with conn.cursor() as cur:");
245 if params.is_empty() {
246 let _ = writeln!(out, " for _ in range(items):");
247 let _ = writeln!(out, " await cur.execute(\"\"\"{}\"\"\")", sql);
248 } else if params.len() == 1 {
249 let _ = writeln!(
250 out,
251 " await cur.executemany(\"\"\"{}\"\"\", [(item,) for item in items])",
252 sql
253 );
254 } else {
255 let _ = writeln!(
256 out,
257 " await cur.executemany(\"\"\"{}\"\"\", items)",
258 sql
259 );
260 }
261 let _ = writeln!(out, " await conn.commit()");
262 }
263 QueryCommand::Many => {
264 let _ = writeln!(
265 out,
266 "async def {}(conn: aiomysql.Connection{}{}) -> list[{}]:",
267 func_name, kw_sep, param_list, struct_name
268 );
269 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
270 let _ = writeln!(out, " async with conn.cursor() as cur:");
271 if params.is_empty() {
272 let _ = writeln!(out, " await cur.execute(\"\"\"{}\"\"\")", sql);
273 } else {
274 let _ = writeln!(
275 out,
276 " await cur.execute(\"\"\"{}\"\"\", {})",
277 sql, args_tuple
278 );
279 }
280 let _ = writeln!(out, " rows = await cur.fetchall()");
281 let field_assignments: Vec<String> = columns
282 .iter()
283 .enumerate()
284 .map(|(i, col)| format!("{}=r[{}]", col.field_name, i))
285 .collect();
286 let oneliner = format!(
287 " return [{}({}) for r in rows]",
288 struct_name,
289 field_assignments.join(", ")
290 );
291 if oneliner.len() <= 88 {
292 let _ = writeln!(out, "{}", oneliner);
293 } else {
294 let _ = writeln!(out, " return [");
295 let _ = writeln!(out, " {}(", struct_name);
296 for fa in &field_assignments {
297 let _ = writeln!(out, " {},", fa);
298 }
299 let _ = writeln!(out, " )");
300 let _ = writeln!(out, " for r in rows");
301 let _ = writeln!(out, " ]");
302 }
303 }
304 QueryCommand::Exec => {
305 let _ = writeln!(
306 out,
307 "async def {}(conn: aiomysql.Connection{}{}) -> None:",
308 func_name, kw_sep, param_list
309 );
310 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
311 let _ = writeln!(out, " async with conn.cursor() as cur:");
312 if params.is_empty() {
313 let _ = writeln!(out, " await cur.execute(\"\"\"{}\"\"\")", sql);
314 } else {
315 let _ = writeln!(
316 out,
317 " await cur.execute(\"\"\"{}\"\"\", {})",
318 sql, args_tuple
319 );
320 }
321 }
322 QueryCommand::ExecResult | QueryCommand::ExecRows => {
323 let _ = writeln!(
324 out,
325 "async def {}(conn: aiomysql.Connection{}{}) -> int:",
326 func_name, kw_sep, param_list
327 );
328 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
329 let _ = writeln!(out, " async with conn.cursor() as cur:");
330 if params.is_empty() {
331 let _ = writeln!(out, " await cur.execute(\"\"\"{}\"\"\")", sql);
332 } else {
333 let _ = writeln!(
334 out,
335 " await cur.execute(\"\"\"{}\"\"\", {})",
336 sql, args_tuple
337 );
338 }
339 let _ = writeln!(out, " return cur.rowcount");
340 }
341 QueryCommand::Grouped => {
342 unreachable!("Grouped is rewritten to Many before codegen")
343 }
344 }
345
346 Ok(out)
347 }
348
349 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
350 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
351 let mut out = String::new();
352 let _ = writeln!(out, "class {}(str, Enum):", type_name);
353 let _ = writeln!(
354 out,
355 " \"\"\"Database enum type {}.\"\"\"",
356 enum_info.sql_name
357 );
358 if enum_info.values.is_empty() {
359 let _ = writeln!(out, " pass");
360 } else {
361 let _ = writeln!(out);
362 for value in &enum_info.values {
363 let variant = enum_variant_name(value, &self.manifest.naming);
364 let _ = writeln!(out, " {} = \"{}\"", variant, value);
365 }
366 }
367 Ok(out)
368 }
369
370 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
371 let name = to_pascal_case(&composite.sql_name);
372 let mut out = String::new();
373 let _ = write!(out, "{}", self.row_type.decorator());
374 let _ = writeln!(out, "{}", self.row_type.class_def(&name));
375 let _ = writeln!(
376 out,
377 " \"\"\"Composite type {}.\"\"\"",
378 composite.sql_name
379 );
380 if composite.fields.is_empty() {
381 let _ = writeln!(out, " pass");
382 } else {
383 let _ = writeln!(out);
384 for field in &composite.fields {
385 let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
386 .map(|t| t.into_owned())
387 .map_err(|e| {
388 ScytheError::new(
389 ErrorCode::InternalError,
390 format!("composite field type error: {}", e),
391 )
392 })?;
393 let _ = writeln!(out, " {}: {}", to_snake_case(&field.name), py_type);
394 }
395 }
396 Ok(out)
397 }
398}