Skip to main content

scythe_codegen/backends/python_aiosqlite.rs

1use std::collections::HashMap;
2use std::fmt::Write;
3use std::path::Path;
4
5use scythe_backend::manifest::{BackendManifest, load_manifest};
6use scythe_backend::naming::{
7    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
8};
9use scythe_backend::types::resolve_type;
10
11use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
12use scythe_core::errors::{ErrorCode, ScytheError};
13use scythe_core::parser::QueryCommand;
14
15use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
16use crate::singularize;
17
18use super::python_common::PythonRowType;
19
20const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/python-aiosqlite.toml");
21
22pub struct PythonAiosqliteBackend {
23    manifest: BackendManifest,
24    row_type: PythonRowType,
25}
26
27impl PythonAiosqliteBackend {
28    pub fn new(engine: &str) -> Result<Self, ScytheError> {
29        match engine {
30            "sqlite" | "sqlite3" => {}
31            _ => {
32                return Err(ScytheError::new(
33                    ErrorCode::InternalError,
34                    format!(
35                        "python-aiosqlite only supports SQLite, got engine '{}'",
36                        engine
37                    ),
38                ));
39            }
40        }
41        let manifest_path = Path::new("backends/python-aiosqlite/manifest.toml");
42        let manifest = if manifest_path.exists() {
43            load_manifest(manifest_path)
44                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
45        } else {
46            toml::from_str(DEFAULT_MANIFEST_TOML)
47                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
48        };
49        Ok(Self {
50            manifest,
51            row_type: PythonRowType::default(),
52        })
53    }
54}
55
/// A `$N` placeholder located in SQL text.
struct DollarPlaceholder {
    /// Byte offset of the `$`.
    start: usize,
    /// Byte offset just past the last digit.
    end: usize,
    /// The 1-based parameter number `N`.
    index: usize,
}

/// Find `$N` placeholders (`N >= 1`) in `sql`, in order of appearance.
///
/// Text inside single-quoted string literals, `"`/`` ` ``/`[...]` quoted identifiers, and
/// `--` / `/* */` comments is skipped, so a `$1` that is part of a literal or comment is not
/// reported. A doubled quote (`''`) simply closes and immediately reopens the literal, which
/// keeps the scan inside it.
fn find_dollar_placeholders(sql: &str) -> Vec<DollarPlaceholder> {
    let bytes = sql.as_bytes();
    let mut found = Vec::new();
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            quote @ (b'\'' | b'"' | b'`' | b'[') => {
                let close = if quote == b'[' { b']' } else { quote };
                // Jump past the closing delimiter, or to the end if it is unterminated.
                i = bytes[i + 1..]
                    .iter()
                    .position(|&b| b == close)
                    .map_or(bytes.len(), |p| i + 1 + p + 1);
            }
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                i = sql[i..].find('\n').map_or(bytes.len(), |p| i + p);
            }
            b'/' if bytes.get(i + 1) == Some(&b'*') => {
                i = sql[i + 2..].find("*/").map_or(bytes.len(), |p| i + 2 + p + 2);
            }
            b'$' => {
                let digits_end =
                    i + 1 + bytes[i + 1..].iter().take_while(|b| b.is_ascii_digit()).count();
                // `$` without digits (e.g. SQLite's `$name`) and `$0` are not positional params.
                let number = sql[i + 1..digits_end].parse::<usize>().ok();
                if let Some(index) = number.filter(|&n| n >= 1) {
                    found.push(DollarPlaceholder {
                        start: i,
                        end: digits_end,
                        index,
                    });
                }
                i = digits_end;
            }
            _ => i += 1,
        }
    }
    found
}

/// Rewrite PostgreSQL-style `$1, $2, ...` placeholders into SQLite placeholders.
///
/// The generated code binds parameters as one positional tuple (presumably ordered
/// `$1, $2, ...` — it is built from the resolved params list). Plain `?` binds by
/// occurrence, so it is only emitted when every placeholder appears exactly once, in
/// ascending order starting at `$1`; this keeps the output for ordinary queries unchanged.
/// If any placeholder is reused, skipped or out of order, all placeholders are emitted as
/// numbered `?N`, which SQLite binds to the N-th tuple element wherever it appears.
///
/// Placeholders inside string literals, quoted identifiers and comments are left untouched,
/// and indices of any width (including `$100` and above) are handled.
fn rewrite_params_to_qmark(sql: &str) -> String {
    let placeholders = find_dollar_placeholders(sql);
    let sequential = placeholders
        .iter()
        .enumerate()
        .all(|(pos, p)| p.index == pos + 1);

    let mut out = String::with_capacity(sql.len());
    let mut copied_up_to = 0;
    for p in &placeholders {
        out.push_str(&sql[copied_up_to..p.start]);
        if sequential {
            out.push('?');
        } else {
            let _ = write!(out, "?{}", p.index);
        }
        copied_up_to = p.end;
    }
    out.push_str(&sql[copied_up_to..]);
    out
}
65
66impl CodegenBackend for PythonAiosqliteBackend {
67    fn name(&self) -> &str {
68        "python-aiosqlite"
69    }
70
71    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
72        &self.manifest
73    }
74
75    fn supported_engines(&self) -> &[&str] {
76        &["sqlite"]
77    }
78
79    fn apply_options(&mut self, options: &HashMap<String, String>) -> Result<(), ScytheError> {
80        if let Some(rt) = options.get("row_type") {
81            self.row_type = PythonRowType::from_option(rt)?;
82        }
83        Ok(())
84    }
85
86    fn file_header(&self) -> String {
87        let import_line = self.row_type.import_line();
88        if self.row_type.is_stdlib_import() {
89            format!(
90                "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
91                 \n\
92                 {import_line}\n\
93                 from enum import Enum  # noqa: F401\n\
94                 \n\
95                 import aiosqlite  # noqa: F401\n\
96                 \n",
97            )
98        } else {
99            let third_party = self
100                .row_type
101                .sorted_third_party_imports("import aiosqlite  # noqa: F401");
102            format!(
103                "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
104                 \n\
105                 from enum import Enum  # noqa: F401\n\
106                 \n\
107                 {third_party}\n\
108                 \n",
109            )
110        }
111    }
112
113    fn generate_row_struct(
114        &self,
115        query_name: &str,
116        columns: &[ResolvedColumn],
117    ) -> Result<String, ScytheError> {
118        let struct_name = row_struct_name(query_name, &self.manifest.naming);
119        let mut out = String::new();
120        let _ = write!(out, "{}", self.row_type.decorator());
121        let _ = writeln!(out, "{}", self.row_type.class_def(&struct_name));
122        let _ = writeln!(out, "    \"\"\"Row type for {} query.\"\"\"", query_name);
123        if columns.is_empty() {
124            let _ = writeln!(out, "    pass");
125        } else {
126            let _ = writeln!(out);
127            for col in columns {
128                let _ = writeln!(out, "    {}: {}", col.field_name, col.full_type);
129            }
130        }
131        Ok(out)
132    }
133
134    fn generate_model_struct(
135        &self,
136        table_name: &str,
137        columns: &[ResolvedColumn],
138    ) -> Result<String, ScytheError> {
139        let singular = singularize(table_name);
140        let name = to_pascal_case(&singular);
141        self.generate_row_struct(&name, columns)
142    }
143
144    fn generate_query_fn(
145        &self,
146        analyzed: &AnalyzedQuery,
147        struct_name: &str,
148        columns: &[ResolvedColumn],
149        params: &[ResolvedParam],
150    ) -> Result<String, ScytheError> {
151        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
152        let mut out = String::new();
153
154        let param_list = params
155            .iter()
156            .map(|p| format!("{}: {}", p.field_name, p.full_type))
157            .collect::<Vec<_>>()
158            .join(", ");
159        let kw_sep = if param_list.is_empty() { "" } else { ", *, " };
160
161        let sql = rewrite_params_to_qmark(&super::clean_sql_with_optional(
162            &analyzed.sql,
163            &analyzed.optional_params,
164            &analyzed.params,
165        ));
166
167        let args_list = if params.is_empty() {
168            String::new()
169        } else {
170            let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
171            if args.len() == 1 {
172                format!("({},)", args[0])
173            } else {
174                format!("({})", args.join(", "))
175            }
176        };
177
178        match &analyzed.command {
179            QueryCommand::One => {
180                let _ = writeln!(
181                    out,
182                    "async def {}(conn: aiosqlite.Connection{}{}) -> {} | None:",
183                    func_name, kw_sep, param_list, struct_name
184                );
185                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
186                if params.is_empty() {
187                    let _ = writeln!(
188                        out,
189                        "    async with conn.execute(\"\"\"{}\"\"\") as cursor:",
190                        sql
191                    );
192                } else {
193                    let _ = writeln!(
194                        out,
195                        "    async with conn.execute(\"\"\"{}\"\"\", {}) as cursor:",
196                        sql, args_list
197                    );
198                }
199                let _ = writeln!(out, "        row = await cursor.fetchone()");
200                let _ = writeln!(out, "    if row is None:");
201                let _ = writeln!(out, "        return None");
202                let field_assignments: Vec<String> = columns
203                    .iter()
204                    .enumerate()
205                    .map(|(i, col)| format!("{}=row[{}]", col.field_name, i))
206                    .collect();
207                let oneliner = format!(
208                    "    return {}({})",
209                    struct_name,
210                    field_assignments.join(", ")
211                );
212                if oneliner.len() <= 88 {
213                    let _ = writeln!(out, "{}", oneliner);
214                } else {
215                    let _ = writeln!(out, "    return {}(", struct_name);
216                    for fa in &field_assignments {
217                        let _ = writeln!(out, "        {},", fa);
218                    }
219                    let _ = writeln!(out, "    )");
220                }
221            }
222            QueryCommand::Batch => {
223                let batch_fn_name = format!("{}_batch", func_name);
224                let items_type = if params.len() > 1 {
225                    let tuple_types: Vec<String> =
226                        params.iter().map(|p| p.full_type.clone()).collect();
227                    format!("list[tuple[{}]]", tuple_types.join(", "))
228                } else if params.len() == 1 {
229                    format!("list[{}]", params[0].full_type)
230                } else {
231                    "int".to_string()
232                };
233                let _ = writeln!(
234                    out,
235                    "async def {}(conn: aiosqlite.Connection, *, items: {}) -> None:",
236                    batch_fn_name, items_type
237                );
238                let _ = writeln!(
239                    out,
240                    "    \"\"\"Execute {} query for each item in the batch.\"\"\"",
241                    analyzed.name
242                );
243                if params.is_empty() {
244                    let _ = writeln!(out, "    for _ in range(items):");
245                    let _ = writeln!(out, "        await conn.execute(\"\"\"{}\"\"\") ", sql);
246                } else if params.len() == 1 {
247                    let _ = writeln!(
248                        out,
249                        "    await conn.executemany(\"\"\"{}\"\"\", [(item,) for item in items])",
250                        sql
251                    );
252                } else {
253                    let _ = writeln!(
254                        out,
255                        "    await conn.executemany(\"\"\"{}\"\"\", items)",
256                        sql
257                    );
258                }
259                let _ = writeln!(out, "    await conn.commit()");
260            }
261            QueryCommand::Many => {
262                let _ = writeln!(
263                    out,
264                    "async def {}(conn: aiosqlite.Connection{}{}) -> list[{}]:",
265                    func_name, kw_sep, param_list, struct_name
266                );
267                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
268                if params.is_empty() {
269                    let _ = writeln!(
270                        out,
271                        "    async with conn.execute(\"\"\"{}\"\"\") as cursor:",
272                        sql
273                    );
274                } else {
275                    let _ = writeln!(
276                        out,
277                        "    async with conn.execute(\"\"\"{}\"\"\", {}) as cursor:",
278                        sql, args_list
279                    );
280                }
281                let _ = writeln!(out, "        rows = await cursor.fetchall()");
282                let field_assignments: Vec<String> = columns
283                    .iter()
284                    .enumerate()
285                    .map(|(i, col)| format!("{}=r[{}]", col.field_name, i))
286                    .collect();
287                let oneliner = format!(
288                    "    return [{}({}) for r in rows]",
289                    struct_name,
290                    field_assignments.join(", ")
291                );
292                if oneliner.len() <= 88 {
293                    let _ = writeln!(out, "{}", oneliner);
294                } else {
295                    let _ = writeln!(out, "    return [");
296                    let _ = writeln!(out, "        {}(", struct_name);
297                    for fa in &field_assignments {
298                        let _ = writeln!(out, "            {},", fa);
299                    }
300                    let _ = writeln!(out, "        )");
301                    let _ = writeln!(out, "        for r in rows");
302                    let _ = writeln!(out, "    ]");
303                }
304            }
305            QueryCommand::Exec => {
306                let _ = writeln!(
307                    out,
308                    "async def {}(conn: aiosqlite.Connection{}{}) -> None:",
309                    func_name, kw_sep, param_list
310                );
311                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
312                if params.is_empty() {
313                    let _ = writeln!(out, "    await conn.execute(\"\"\"{}\"\"\") ", sql);
314                } else {
315                    let _ = writeln!(
316                        out,
317                        "    await conn.execute(\"\"\"{}\"\"\", {})",
318                        sql, args_list
319                    );
320                }
321            }
322            QueryCommand::ExecResult | QueryCommand::ExecRows => {
323                let _ = writeln!(
324                    out,
325                    "async def {}(conn: aiosqlite.Connection{}{}) -> int:",
326                    func_name, kw_sep, param_list
327                );
328                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
329                if params.is_empty() {
330                    let _ = writeln!(out, "    cursor = await conn.execute(\"\"\"{}\"\"\") ", sql);
331                } else {
332                    let _ = writeln!(
333                        out,
334                        "    cursor = await conn.execute(\"\"\"{}\"\"\", {})",
335                        sql, args_list
336                    );
337                }
338                let _ = writeln!(out, "    return cursor.rowcount");
339            }
340            QueryCommand::Grouped => {
341                unreachable!("Grouped is rewritten to Many before codegen")
342            }
343        }
344
345        Ok(out)
346    }
347
348    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
349        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
350        let mut out = String::new();
351        let _ = writeln!(out, "class {}(str, Enum):", type_name);
352        let _ = writeln!(
353            out,
354            "    \"\"\"Database enum type {}.\"\"\"",
355            enum_info.sql_name
356        );
357        if enum_info.values.is_empty() {
358            let _ = writeln!(out, "    pass");
359        } else {
360            let _ = writeln!(out);
361            for value in &enum_info.values {
362                let variant = enum_variant_name(value, &self.manifest.naming);
363                let _ = writeln!(out, "    {} = \"{}\"", variant, value);
364            }
365        }
366        Ok(out)
367    }
368
369    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
370        let name = to_pascal_case(&composite.sql_name);
371        let mut out = String::new();
372        let _ = write!(out, "{}", self.row_type.decorator());
373        let _ = writeln!(out, "{}", self.row_type.class_def(&name));
374        let _ = writeln!(
375            out,
376            "    \"\"\"Composite type {}.\"\"\"",
377            composite.sql_name
378        );
379        if composite.fields.is_empty() {
380            let _ = writeln!(out, "    pass");
381        } else {
382            let _ = writeln!(out);
383            for field in &composite.fields {
384                let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
385                    .map(|t| t.into_owned())
386                    .map_err(|e| {
387                        ScytheError::new(
388                            ErrorCode::InternalError,
389                            format!("composite field type error: {}", e),
390                        )
391                    })?;
392                let _ = writeln!(out, "    {}: {}", to_snake_case(&field.name), py_type);
393            }
394        }
395        Ok(out)
396    }
397}