1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
17const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/python-aiosqlite.toml");
18
19pub struct PythonAiosqliteBackend {
20 manifest: BackendManifest,
21}
22
23impl PythonAiosqliteBackend {
24 pub fn new(engine: &str) -> Result<Self, ScytheError> {
25 match engine {
26 "sqlite" | "sqlite3" => {}
27 _ => {
28 return Err(ScytheError::new(
29 ErrorCode::InternalError,
30 format!(
31 "python-aiosqlite only supports SQLite, got engine '{}'",
32 engine
33 ),
34 ));
35 }
36 }
37 let manifest_path = Path::new("backends/python-aiosqlite/manifest.toml");
38 let manifest = if manifest_path.exists() {
39 load_manifest(manifest_path)
40 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
41 } else {
42 toml::from_str(DEFAULT_MANIFEST_TOML)
43 .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
44 };
45 Ok(Self { manifest })
46 }
47}
48
/// Rewrites PostgreSQL-style `$N` placeholders into SQLite's numbered `?N` form.
///
/// Keeping the number (instead of emitting a bare `?`) makes each placeholder
/// bind to the N-th element of the parameter tuple. SQL that references
/// parameters out of order (`b = $2 AND a = $1`) or reuses one (`$1 ... $1`)
/// therefore still binds correctly. Placeholders of any width are handled
/// (`$100` becomes `?100`).
///
/// Text inside single-quoted string literals or double-quoted identifiers is
/// copied unchanged. A `$` that is not followed by a digit is left as-is.
fn rewrite_params_to_qmark(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len());
    // Quote character of the literal/identifier currently open, if any.
    let mut open_quote: Option<char> = None;
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        match open_quote {
            Some(q) => {
                // A doubled quote ('') closes and immediately reopens the
                // literal, so SQL's escaped quotes need no special case.
                if c == q {
                    open_quote = None;
                }
                out.push(c);
            }
            None if c == '\'' || c == '"' => {
                open_quote = Some(c);
                out.push(c);
            }
            None if c == '$' && matches!(chars.peek(), Some(d) if d.is_ascii_digit()) => {
                out.push('?');
                while let Some(digit) = chars.next_if(char::is_ascii_digit) {
                    out.push(digit);
                }
            }
            None => out.push(c),
        }
    }

    out
}
58
59impl CodegenBackend for PythonAiosqliteBackend {
60 fn name(&self) -> &str {
61 "python-aiosqlite"
62 }
63
64 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
65 &self.manifest
66 }
67
68 fn supported_engines(&self) -> &[&str] {
69 &["sqlite"]
70 }
71
72 fn file_header(&self) -> String {
73 "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
74 \n\
75 from dataclasses import dataclass\n\
76 from enum import Enum # noqa: F401\n\
77 \n\
78 import aiosqlite # noqa: F401\n\
79 \n"
80 .to_string()
81 }
82
83 fn generate_row_struct(
84 &self,
85 query_name: &str,
86 columns: &[ResolvedColumn],
87 ) -> Result<String, ScytheError> {
88 let struct_name = row_struct_name(query_name, &self.manifest.naming);
89 let mut out = String::new();
90 let _ = writeln!(out, "@dataclass");
91 let _ = writeln!(out, "class {}:", struct_name);
92 let _ = writeln!(out, " \"\"\"Row type for {} query.\"\"\"", query_name);
93 if columns.is_empty() {
94 let _ = writeln!(out, " pass");
95 } else {
96 let _ = writeln!(out);
97 for col in columns {
98 let _ = writeln!(out, " {}: {}", col.field_name, col.full_type);
99 }
100 }
101 Ok(out)
102 }
103
104 fn generate_model_struct(
105 &self,
106 table_name: &str,
107 columns: &[ResolvedColumn],
108 ) -> Result<String, ScytheError> {
109 let singular = singularize(table_name);
110 let name = to_pascal_case(&singular);
111 self.generate_row_struct(&name, columns)
112 }
113
114 fn generate_query_fn(
115 &self,
116 analyzed: &AnalyzedQuery,
117 struct_name: &str,
118 columns: &[ResolvedColumn],
119 params: &[ResolvedParam],
120 ) -> Result<String, ScytheError> {
121 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
122 let mut out = String::new();
123
124 let param_list = params
125 .iter()
126 .map(|p| format!("{}: {}", p.field_name, p.full_type))
127 .collect::<Vec<_>>()
128 .join(", ");
129 let kw_sep = if param_list.is_empty() { "" } else { ", *, " };
130
131 let sql = rewrite_params_to_qmark(&super::clean_sql(&analyzed.sql));
132
133 let args_list = if params.is_empty() {
134 String::new()
135 } else {
136 let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
137 if args.len() == 1 {
138 format!("({},)", args[0])
139 } else {
140 format!("({})", args.join(", "))
141 }
142 };
143
144 match &analyzed.command {
145 QueryCommand::One => {
146 let _ = writeln!(
147 out,
148 "async def {}(conn: aiosqlite.Connection{}{}) -> {} | None:",
149 func_name, kw_sep, param_list, struct_name
150 );
151 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
152 if params.is_empty() {
153 let _ = writeln!(
154 out,
155 " async with conn.execute(\"\"\"{}\"\"\") as cursor:",
156 sql
157 );
158 } else {
159 let _ = writeln!(
160 out,
161 " async with conn.execute(\"\"\"{}\"\"\", {}) as cursor:",
162 sql, args_list
163 );
164 }
165 let _ = writeln!(out, " row = await cursor.fetchone()");
166 let _ = writeln!(out, " if row is None:");
167 let _ = writeln!(out, " return None");
168 let field_assignments: Vec<String> = columns
169 .iter()
170 .enumerate()
171 .map(|(i, col)| format!("{}=row[{}]", col.field_name, i))
172 .collect();
173 let oneliner = format!(
174 " return {}({})",
175 struct_name,
176 field_assignments.join(", ")
177 );
178 if oneliner.len() <= 88 {
179 let _ = writeln!(out, "{}", oneliner);
180 } else {
181 let _ = writeln!(out, " return {}(", struct_name);
182 for fa in &field_assignments {
183 let _ = writeln!(out, " {},", fa);
184 }
185 let _ = writeln!(out, " )");
186 }
187 }
188 QueryCommand::Many | QueryCommand::Batch => {
189 let _ = writeln!(
190 out,
191 "async def {}(conn: aiosqlite.Connection{}{}) -> list[{}]:",
192 func_name, kw_sep, param_list, struct_name
193 );
194 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
195 if params.is_empty() {
196 let _ = writeln!(
197 out,
198 " async with conn.execute(\"\"\"{}\"\"\") as cursor:",
199 sql
200 );
201 } else {
202 let _ = writeln!(
203 out,
204 " async with conn.execute(\"\"\"{}\"\"\", {}) as cursor:",
205 sql, args_list
206 );
207 }
208 let _ = writeln!(out, " rows = await cursor.fetchall()");
209 let field_assignments: Vec<String> = columns
210 .iter()
211 .enumerate()
212 .map(|(i, col)| format!("{}=r[{}]", col.field_name, i))
213 .collect();
214 let oneliner = format!(
215 " return [{}({}) for r in rows]",
216 struct_name,
217 field_assignments.join(", ")
218 );
219 if oneliner.len() <= 88 {
220 let _ = writeln!(out, "{}", oneliner);
221 } else {
222 let _ = writeln!(out, " return [");
223 let _ = writeln!(out, " {}(", struct_name);
224 for fa in &field_assignments {
225 let _ = writeln!(out, " {},", fa);
226 }
227 let _ = writeln!(out, " )");
228 let _ = writeln!(out, " for r in rows");
229 let _ = writeln!(out, " ]");
230 }
231 }
232 QueryCommand::Exec | QueryCommand::ExecResult | QueryCommand::ExecRows => {
233 let _ = writeln!(
234 out,
235 "async def {}(conn: aiosqlite.Connection{}{}) -> None:",
236 func_name, kw_sep, param_list
237 );
238 let _ = writeln!(out, " \"\"\"Execute {} query.\"\"\"", analyzed.name);
239 if params.is_empty() {
240 let _ = writeln!(out, " await conn.execute(\"\"\"{}\"\"\") ", sql);
241 } else {
242 let _ = writeln!(
243 out,
244 " await conn.execute(\"\"\"{}\"\"\", {})",
245 sql, args_list
246 );
247 }
248 }
249 }
250
251 Ok(out)
252 }
253
254 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
255 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
256 let mut out = String::new();
257 let _ = writeln!(out, "class {}(str, Enum):", type_name);
258 let _ = writeln!(
259 out,
260 " \"\"\"Database enum type {}.\"\"\"",
261 enum_info.sql_name
262 );
263 if enum_info.values.is_empty() {
264 let _ = writeln!(out, " pass");
265 } else {
266 let _ = writeln!(out);
267 for value in &enum_info.values {
268 let variant = enum_variant_name(value, &self.manifest.naming);
269 let _ = writeln!(out, " {} = \"{}\"", variant, value);
270 }
271 }
272 Ok(out)
273 }
274
275 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
276 let name = to_pascal_case(&composite.sql_name);
277 let mut out = String::new();
278 let _ = writeln!(out, "@dataclass");
279 let _ = writeln!(out, "class {}:", name);
280 let _ = writeln!(
281 out,
282 " \"\"\"Composite type {}.\"\"\"",
283 composite.sql_name
284 );
285 if composite.fields.is_empty() {
286 let _ = writeln!(out, " pass");
287 } else {
288 let _ = writeln!(out);
289 for field in &composite.fields {
290 let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
291 .map(|t| t.into_owned())
292 .map_err(|e| {
293 ScytheError::new(
294 ErrorCode::InternalError,
295 format!("composite field type error: {}", e),
296 )
297 })?;
298 let _ = writeln!(out, " {}: {}", to_snake_case(&field.name), py_type);
299 }
300 }
301 Ok(out)
302 }
303}