1use std::collections::HashMap;
2use std::fmt::Write;
3use std::path::Path;
4
5use scythe_backend::manifest::{BackendManifest, load_manifest};
6use scythe_backend::naming::{
7 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
8};
9use scythe_backend::types::resolve_type;
10
11use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
12use scythe_core::errors::{ErrorCode, ScytheError};
13use scythe_core::parser::QueryCommand;
14
15use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
16use crate::singularize;
17
18use super::python_common::PythonRowType;
19
/// Manifest TOML embedded into the binary at compile time.
///
/// `PythonAiosqliteBackend::new` parses this text when
/// `backends/python-aiosqlite/manifest.toml` does not exist on disk.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/python-aiosqlite.toml");
21
/// Code generation backend registered as `python-aiosqlite`.
///
/// Emits async Python functions that take an `aiosqlite.Connection`, plus the
/// row, enum and composite classes those functions refer to.
pub struct PythonAiosqliteBackend {
    /// Backend manifest; its `naming` section drives generated identifiers and
    /// the manifest as a whole is handed to `resolve_type` for composite fields.
    manifest: BackendManifest,
    /// Flavor of the generated row classes. Starts as `PythonRowType::default()`
    /// and can be replaced through the `row_type` option.
    row_type: PythonRowType,
}
26
27impl PythonAiosqliteBackend {
28 pub fn new(engine: &str) -> Result<Self, ScytheError> {
29 match engine {
30 "sqlite" | "sqlite3" => {}
31 _ => {
32 return Err(ScytheError::new(
33 ErrorCode::InternalError,
34 format!(
35 "python-aiosqlite only supports SQLite, got engine '{}'",
36 engine
37 ),
38 ));
39 }
40 }
41 let manifest_path = Path::new("backends/python-aiosqlite/manifest.toml");
42 let manifest = if manifest_path.exists() {
43 load_manifest(manifest_path)
44 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
45 } else {
46 toml::from_str(DEFAULT_MANIFEST_TOML)
47 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
48 };
49 Ok(Self {
50 manifest,
51 row_type: PythonRowType::default(),
52 })
53 }
54}
55
/// Rewrites numbered `$N` placeholders into anonymous `?` markers.
///
/// A placeholder is a `$` followed by a run of ASCII digits whose first digit is
/// `1`-`9`; the whole run is consumed, so `$100` becomes a single `?`. A `$` that
/// is followed by `0`, by a non-digit, or by nothing is copied through unchanged.
///
/// The text is scanned once, left to right.
///
/// NOTE(review): the scan is purely textual, so a `$N` inside a quoted SQL
/// literal is rewritten as well, and the numbering is discarded — a query that
/// repeats or reorders `$N` ends up with one `?` per occurrence, in order of
/// appearance. Confirm the caller's SQL never relies on either.
fn rewrite_params_to_qmark(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        let starts_placeholder = c == '$' && matches!(chars.peek().copied(), Some('1'..='9'));
        if starts_placeholder {
            // Swallow every digit of the index so multi-digit placeholders leave
            // no stray digits behind.
            while chars.next_if(|d| d.is_ascii_digit()).is_some() {}
            out.push('?');
        } else {
            out.push(c);
        }
    }

    out
}
65
66impl CodegenBackend for PythonAiosqliteBackend {
    /// Returns the backend identifier, `"python-aiosqlite"`.
    fn name(&self) -> &str {
        "python-aiosqlite"
    }
70
    /// Returns a reference to the manifest held by this backend.
    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
        &self.manifest
    }
74
    /// Lists the engine names this backend reports as supported: only `"sqlite"`.
    // NOTE(review): `new` additionally accepts "sqlite3"; confirm whether that
    // alias should be reported here too.
    fn supported_engines(&self) -> &[&str] {
        &["sqlite"]
    }
78
79 fn apply_options(&mut self, options: &HashMap<String, String>) -> Result<(), ScytheError> {
80 if let Some(rt) = options.get("row_type") {
81 self.row_type = PythonRowType::from_option(rt)?;
82 }
83 Ok(())
84 }
85
86 fn file_header(&self) -> String {
87 let import_line = self.row_type.import_line();
88 if self.row_type.is_stdlib_import() {
89 format!(
90 "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
91 \n\
92 {import_line}\n\
93 from enum import Enum # noqa: F401\n\
94 \n\
95 import aiosqlite # noqa: F401\n\
96 \n",
97 )
98 } else {
99 let third_party = self
100 .row_type
101 .sorted_third_party_imports("import aiosqlite # noqa: F401");
102 format!(
103 "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
104 \n\
105 from enum import Enum # noqa: F401\n\
106 \n\
107 {third_party}\n\
108 \n",
109 )
110 }
111 }
112
113 fn generate_row_struct(
114 &self,
115 query_name: &str,
116 columns: &[ResolvedColumn],
117 ) -> Result<String, ScytheError> {
118 let struct_name = row_struct_name(query_name, &self.manifest.naming);
119 let mut out = String::new();
120 let _ = write!(out, "{}", self.row_type.decorator());
121 let _ = writeln!(out, "{}", self.row_type.class_def(&struct_name));
122 let _ = writeln!(out, " \"\"\"Row type for {} query.\"\"\"", query_name);
123 if columns.is_empty() {
124 let _ = writeln!(out, " pass");
125 } else {
126 let _ = writeln!(out);
127 for col in columns {
128 let _ = writeln!(out, " {}: {}", col.field_name, col.full_type);
129 }
130 }
131 Ok(out)
132 }
133
134 fn generate_model_struct(
135 &self,
136 table_name: &str,
137 columns: &[ResolvedColumn],
138 ) -> Result<String, ScytheError> {
139 let singular = singularize(table_name);
140 let name = to_pascal_case(&singular);
141 self.generate_row_struct(&name, columns)
142 }
143
144 fn generate_query_fn(
145 &self,
146 analyzed: &AnalyzedQuery,
147 struct_name: &str,
148 columns: &[ResolvedColumn],
149 params: &[ResolvedParam],
150 ) -> Result<String, ScytheError> {
151 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
152 let mut out = String::new();
153
154 let param_list = params
155 .iter()
156 .map(|p| format!("{}: {}", p.field_name, p.full_type))
157 .collect::<Vec<_>>()
158 .join(", ");
159 let kw_sep = if param_list.is_empty() { "" } else { ", *, " };
160
161 let sql = rewrite_params_to_qmark(&super::clean_sql_with_optional(
162 &analyzed.sql,
163 &analyzed.optional_params,
164 &analyzed.params,
165 ));
166
167 let args_list = if params.is_empty() {
168 String::new()
169 } else {
170 let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
171 if args.len() == 1 {
172 format!("({},)", args[0])
173 } else {
174 format!("({})", args.join(", "))
175 }
176 };
177
178 match &analyzed.command {
179 QueryCommand::One => {
180 let _ = writeln!(
181 out,
182 "async def {}(conn: aiosqlite.Connection{}{}) -> {} | None:",
183 func_name, kw_sep, param_list, struct_name
184 );
185 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
186 if params.is_empty() {
187 let _ = writeln!(
188 out,
189 " async with conn.execute(\"\"\"{}\"\"\") as cursor:",
190 sql
191 );
192 } else {
193 let _ = writeln!(
194 out,
195 " async with conn.execute(\"\"\"{}\"\"\", {}) as cursor:",
196 sql, args_list
197 );
198 }
199 let _ = writeln!(out, " row = await cursor.fetchone()");
200 let _ = writeln!(out, " if row is None:");
201 let _ = writeln!(out, " return None");
202 let field_assignments: Vec<String> = columns
203 .iter()
204 .enumerate()
205 .map(|(i, col)| format!("{}=row[{}]", col.field_name, i))
206 .collect();
207 let oneliner = format!(
208 " return {}({})",
209 struct_name,
210 field_assignments.join(", ")
211 );
212 if oneliner.len() <= 88 {
213 let _ = writeln!(out, "{}", oneliner);
214 } else {
215 let _ = writeln!(out, " return {}(", struct_name);
216 for fa in &field_assignments {
217 let _ = writeln!(out, " {},", fa);
218 }
219 let _ = writeln!(out, " )");
220 }
221 }
222 QueryCommand::Batch => {
223 let batch_fn_name = format!("{}_batch", func_name);
224 let items_type = if params.len() > 1 {
225 let tuple_types: Vec<String> =
226 params.iter().map(|p| p.full_type.clone()).collect();
227 format!("list[tuple[{}]]", tuple_types.join(", "))
228 } else if params.len() == 1 {
229 format!("list[{}]", params[0].full_type)
230 } else {
231 "int".to_string()
232 };
233 let _ = writeln!(
234 out,
235 "async def {}(conn: aiosqlite.Connection, *, items: {}) -> None:",
236 batch_fn_name, items_type
237 );
238 let _ = writeln!(
239 out,
240 " \"\"\"Execute {} query for each item in the batch.\"\"\"",
241 analyzed.name
242 );
243 if params.is_empty() {
244 let _ = writeln!(out, " for _ in range(items):");
245 let _ = writeln!(out, " await conn.execute(\"\"\"{}\"\"\") ", sql);
246 } else if params.len() == 1 {
247 let _ = writeln!(
248 out,
249 " await conn.executemany(\"\"\"{}\"\"\", [(item,) for item in items])",
250 sql
251 );
252 } else {
253 let _ = writeln!(
254 out,
255 " await conn.executemany(\"\"\"{}\"\"\", items)",
256 sql
257 );
258 }
259 let _ = writeln!(out, " await conn.commit()");
260 }
261 QueryCommand::Many => {
262 let _ = writeln!(
263 out,
264 "async def {}(conn: aiosqlite.Connection{}{}) -> list[{}]:",
265 func_name, kw_sep, param_list, struct_name
266 );
267 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
268 if params.is_empty() {
269 let _ = writeln!(
270 out,
271 " async with conn.execute(\"\"\"{}\"\"\") as cursor:",
272 sql
273 );
274 } else {
275 let _ = writeln!(
276 out,
277 " async with conn.execute(\"\"\"{}\"\"\", {}) as cursor:",
278 sql, args_list
279 );
280 }
281 let _ = writeln!(out, " rows = await cursor.fetchall()");
282 let field_assignments: Vec<String> = columns
283 .iter()
284 .enumerate()
285 .map(|(i, col)| format!("{}=r[{}]", col.field_name, i))
286 .collect();
287 let oneliner = format!(
288 " return [{}({}) for r in rows]",
289 struct_name,
290 field_assignments.join(", ")
291 );
292 if oneliner.len() <= 88 {
293 let _ = writeln!(out, "{}", oneliner);
294 } else {
295 let _ = writeln!(out, " return [");
296 let _ = writeln!(out, " {}(", struct_name);
297 for fa in &field_assignments {
298 let _ = writeln!(out, " {},", fa);
299 }
300 let _ = writeln!(out, " )");
301 let _ = writeln!(out, " for r in rows");
302 let _ = writeln!(out, " ]");
303 }
304 }
305 QueryCommand::Exec => {
306 let _ = writeln!(
307 out,
308 "async def {}(conn: aiosqlite.Connection{}{}) -> None:",
309 func_name, kw_sep, param_list
310 );
311 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
312 if params.is_empty() {
313 let _ = writeln!(out, " await conn.execute(\"\"\"{}\"\"\") ", sql);
314 } else {
315 let _ = writeln!(
316 out,
317 " await conn.execute(\"\"\"{}\"\"\", {})",
318 sql, args_list
319 );
320 }
321 }
322 QueryCommand::ExecResult | QueryCommand::ExecRows => {
323 let _ = writeln!(
324 out,
325 "async def {}(conn: aiosqlite.Connection{}{}) -> int:",
326 func_name, kw_sep, param_list
327 );
328 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
329 if params.is_empty() {
330 let _ = writeln!(out, " cursor = await conn.execute(\"\"\"{}\"\"\") ", sql);
331 } else {
332 let _ = writeln!(
333 out,
334 " cursor = await conn.execute(\"\"\"{}\"\"\", {})",
335 sql, args_list
336 );
337 }
338 let _ = writeln!(out, " return cursor.rowcount");
339 }
340 QueryCommand::Grouped => {
341 unreachable!("Grouped is rewritten to Many before codegen")
342 }
343 }
344
345 Ok(out)
346 }
347
348 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
349 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
350 let mut out = String::new();
351 let _ = writeln!(out, "class {}(str, Enum):", type_name);
352 let _ = writeln!(
353 out,
354 " \"\"\"Database enum type {}.\"\"\"",
355 enum_info.sql_name
356 );
357 if enum_info.values.is_empty() {
358 let _ = writeln!(out, " pass");
359 } else {
360 let _ = writeln!(out);
361 for value in &enum_info.values {
362 let variant = enum_variant_name(value, &self.manifest.naming);
363 let _ = writeln!(out, " {} = \"{}\"", variant, value);
364 }
365 }
366 Ok(out)
367 }
368
369 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
370 let name = to_pascal_case(&composite.sql_name);
371 let mut out = String::new();
372 let _ = write!(out, "{}", self.row_type.decorator());
373 let _ = writeln!(out, "{}", self.row_type.class_def(&name));
374 let _ = writeln!(
375 out,
376 " \"\"\"Composite type {}.\"\"\"",
377 composite.sql_name
378 );
379 if composite.fields.is_empty() {
380 let _ = writeln!(out, " pass");
381 } else {
382 let _ = writeln!(out);
383 for field in &composite.fields {
384 let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
385 .map(|t| t.into_owned())
386 .map_err(|e| {
387 ScytheError::new(
388 ErrorCode::InternalError,
389 format!("composite field type error: {}", e),
390 )
391 })?;
392 let _ = writeln!(out, " {}: {}", to_snake_case(&field.name), py_type);
393 }
394 }
395 Ok(out)
396 }
397}