1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
/// Manifest text embedded at compile time with `include_str!`.
///
/// `PythonAiomysqlBackend::new` parses this when no manifest file exists at its
/// on-disk lookup path.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/python-aiomysql.toml");
18
/// Code generation backend registered under the name `python-aiomysql`.
///
/// It emits Python source (dataclasses, enums and `async def` query functions)
/// that imports `aiomysql`.
pub struct PythonAiomysqlBackend {
    /// Manifest loaded in `new`; supplies naming rules and type resolution data.
    manifest: BackendManifest,
}
22
23impl PythonAiomysqlBackend {
24 pub fn new(engine: &str) -> Result<Self, ScytheError> {
25 match engine {
26 "mysql" | "mariadb" => {}
27 _ => {
28 return Err(ScytheError::new(
29 ErrorCode::InternalError,
30 format!(
31 "python-aiomysql only supports MySQL/MariaDB, got engine '{}'",
32 engine
33 ),
34 ));
35 }
36 }
37 let manifest_path = Path::new("backends/python-aiomysql/manifest.toml");
38 let manifest = if manifest_path.exists() {
39 load_manifest(manifest_path)
40 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
41 } else {
42 toml::from_str(DEFAULT_MANIFEST_TOML)
43 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
44 };
45 Ok(Self { manifest })
46 }
47}
48
/// Rewrites numbered `$N` placeholders into the `%s` placeholder style.
///
/// A placeholder is a `$` followed by a run of ASCII digits whose first digit is
/// not `0`; the whole run is consumed, so `$100` becomes a single `%s` rather than
/// `%s0`. A bare `$`, `$0`, or `$` followed by a non-digit is copied through
/// unchanged.
///
/// No SQL parsing is performed: a `$N` sequence inside a quoted literal is
/// rewritten too, and literal `%` characters are left as they are.
fn rewrite_params_to_percent_s(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len());
    let mut rest = sql;

    while let Some(pos) = rest.find('$') {
        let (before, from_dollar) = rest.split_at(pos);
        out.push_str(before);

        // `$` is a single byte, so slicing one byte past it is always valid.
        let tail = &from_dollar[1..];
        let digit_count = tail.bytes().take_while(u8::is_ascii_digit).count();

        if digit_count > 0 && !tail.starts_with('0') {
            out.push_str("%s");
            // Digits are ASCII, so `digit_count` is a valid char boundary.
            rest = &tail[digit_count..];
        } else {
            out.push('$');
            rest = tail;
        }
    }

    out.push_str(rest);
    out
}
58
59impl CodegenBackend for PythonAiomysqlBackend {
60 fn name(&self) -> &str {
61 "python-aiomysql"
62 }
63
64 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
65 &self.manifest
66 }
67
68 fn supported_engines(&self) -> &[&str] {
69 &["mysql"]
70 }
71
72 fn file_header(&self) -> String {
73 "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
74 \n\
75 import datetime # noqa: F401\n\
76 import decimal # noqa: F401\n\
77 from dataclasses import dataclass\n\
78 from enum import Enum # noqa: F401\n\
79 \n\
80 import aiomysql # noqa: F401\n\
81 \n"
82 .to_string()
83 }
84
85 fn generate_row_struct(
86 &self,
87 query_name: &str,
88 columns: &[ResolvedColumn],
89 ) -> Result<String, ScytheError> {
90 let struct_name = row_struct_name(query_name, &self.manifest.naming);
91 let mut out = String::new();
92 let _ = writeln!(out, "@dataclass");
93 let _ = writeln!(out, "class {}:", struct_name);
94 let _ = writeln!(out, " \"\"\"Row type for {} query.\"\"\"", query_name);
95 if columns.is_empty() {
96 let _ = writeln!(out, " pass");
97 } else {
98 let _ = writeln!(out);
99 for col in columns {
100 let _ = writeln!(out, " {}: {}", col.field_name, col.full_type);
101 }
102 }
103 Ok(out)
104 }
105
106 fn generate_model_struct(
107 &self,
108 table_name: &str,
109 columns: &[ResolvedColumn],
110 ) -> Result<String, ScytheError> {
111 let singular = singularize(table_name);
112 let name = to_pascal_case(&singular);
113 self.generate_row_struct(&name, columns)
114 }
115
116 fn generate_query_fn(
117 &self,
118 analyzed: &AnalyzedQuery,
119 struct_name: &str,
120 columns: &[ResolvedColumn],
121 params: &[ResolvedParam],
122 ) -> Result<String, ScytheError> {
123 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
124 let mut out = String::new();
125
126 let param_list = params
127 .iter()
128 .map(|p| format!("{}: {}", p.field_name, p.full_type))
129 .collect::<Vec<_>>()
130 .join(", ");
131 let kw_sep = if param_list.is_empty() { "" } else { ", *, " };
132
133 let sql = rewrite_params_to_percent_s(&super::clean_sql(&analyzed.sql));
134
135 let args_tuple = if params.is_empty() {
136 String::new()
137 } else {
138 let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
139 if args.len() == 1 {
140 format!("({},)", args[0])
141 } else {
142 format!("({})", args.join(", "))
143 }
144 };
145
146 match &analyzed.command {
147 QueryCommand::One => {
148 let _ = writeln!(
149 out,
150 "async def {}(conn: aiomysql.Connection{}{}) -> {} | None:",
151 func_name, kw_sep, param_list, struct_name
152 );
153 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
154 let _ = writeln!(out, " async with conn.cursor() as cur:");
155 if params.is_empty() {
156 let _ = writeln!(out, " await cur.execute(\"\"\"{}\"\"\")", sql);
157 } else {
158 let _ = writeln!(
159 out,
160 " await cur.execute(\"\"\"{}\"\"\", {})",
161 sql, args_tuple
162 );
163 }
164 let _ = writeln!(out, " row = await cur.fetchone()");
165 let _ = writeln!(out, " if row is None:");
166 let _ = writeln!(out, " return None");
167 let field_assignments: Vec<String> = columns
168 .iter()
169 .enumerate()
170 .map(|(i, col)| format!("{}=row[{}]", col.field_name, i))
171 .collect();
172 let oneliner = format!(
173 " return {}({})",
174 struct_name,
175 field_assignments.join(", ")
176 );
177 if oneliner.len() <= 88 {
178 let _ = writeln!(out, "{}", oneliner);
179 } else {
180 let _ = writeln!(out, " return {}(", struct_name);
181 for fa in &field_assignments {
182 let _ = writeln!(out, " {},", fa);
183 }
184 let _ = writeln!(out, " )");
185 }
186 }
187 QueryCommand::Many | QueryCommand::Batch => {
188 let _ = writeln!(
189 out,
190 "async def {}(conn: aiomysql.Connection{}{}) -> list[{}]:",
191 func_name, kw_sep, param_list, struct_name
192 );
193 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
194 let _ = writeln!(out, " async with conn.cursor() as cur:");
195 if params.is_empty() {
196 let _ = writeln!(out, " await cur.execute(\"\"\"{}\"\"\")", sql);
197 } else {
198 let _ = writeln!(
199 out,
200 " await cur.execute(\"\"\"{}\"\"\", {})",
201 sql, args_tuple
202 );
203 }
204 let _ = writeln!(out, " rows = await cur.fetchall()");
205 let field_assignments: Vec<String> = columns
206 .iter()
207 .enumerate()
208 .map(|(i, col)| format!("{}=r[{}]", col.field_name, i))
209 .collect();
210 let oneliner = format!(
211 " return [{}({}) for r in rows]",
212 struct_name,
213 field_assignments.join(", ")
214 );
215 if oneliner.len() <= 88 {
216 let _ = writeln!(out, "{}", oneliner);
217 } else {
218 let _ = writeln!(out, " return [");
219 let _ = writeln!(out, " {}(", struct_name);
220 for fa in &field_assignments {
221 let _ = writeln!(out, " {},", fa);
222 }
223 let _ = writeln!(out, " )");
224 let _ = writeln!(out, " for r in rows");
225 let _ = writeln!(out, " ]");
226 }
227 }
228 QueryCommand::Exec | QueryCommand::ExecResult | QueryCommand::ExecRows => {
229 let _ = writeln!(
230 out,
231 "async def {}(conn: aiomysql.Connection{}{}) -> None:",
232 func_name, kw_sep, param_list
233 );
234 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
235 let _ = writeln!(out, " async with conn.cursor() as cur:");
236 if params.is_empty() {
237 let _ = writeln!(out, " await cur.execute(\"\"\"{}\"\"\")", sql);
238 } else {
239 let _ = writeln!(
240 out,
241 " await cur.execute(\"\"\"{}\"\"\", {})",
242 sql, args_tuple
243 );
244 }
245 }
246 }
247
248 Ok(out)
249 }
250
251 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
252 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
253 let mut out = String::new();
254 let _ = writeln!(out, "class {}(str, Enum):", type_name);
255 let _ = writeln!(
256 out,
257 " \"\"\"Database enum type {}.\"\"\"",
258 enum_info.sql_name
259 );
260 if enum_info.values.is_empty() {
261 let _ = writeln!(out, " pass");
262 } else {
263 let _ = writeln!(out);
264 for value in &enum_info.values {
265 let variant = enum_variant_name(value, &self.manifest.naming);
266 let _ = writeln!(out, " {} = \"{}\"", variant, value);
267 }
268 }
269 Ok(out)
270 }
271
272 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
273 let name = to_pascal_case(&composite.sql_name);
274 let mut out = String::new();
275 let _ = writeln!(out, "@dataclass");
276 let _ = writeln!(out, "class {}:", name);
277 let _ = writeln!(
278 out,
279 " \"\"\"Composite type {}.\"\"\"",
280 composite.sql_name
281 );
282 if composite.fields.is_empty() {
283 let _ = writeln!(out, " pass");
284 } else {
285 let _ = writeln!(out);
286 for field in &composite.fields {
287 let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
288 .map(|t| t.into_owned())
289 .map_err(|e| {
290 ScytheError::new(
291 ErrorCode::InternalError,
292 format!("composite field type error: {}", e),
293 )
294 })?;
295 let _ = writeln!(out, " {}: {}", to_snake_case(&field.name), py_type);
296 }
297 }
298 Ok(out)
299 }
300}