// scythe_codegen/backends/python_aiosqlite.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
17const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/python-aiosqlite.toml");
18
19pub struct PythonAiosqliteBackend {
20    manifest: BackendManifest,
21}
22
23impl PythonAiosqliteBackend {
24    pub fn new(engine: &str) -> Result<Self, ScytheError> {
25        match engine {
26            "sqlite" | "sqlite3" => {}
27            _ => {
28                return Err(ScytheError::new(
29                    ErrorCode::InternalError,
30                    format!(
31                        "python-aiosqlite only supports SQLite, got engine '{}'",
32                        engine
33                    ),
34                ));
35            }
36        }
37        let manifest_path = Path::new("backends/python-aiosqlite/manifest.toml");
38        let manifest = if manifest_path.exists() {
39            load_manifest(manifest_path)
40                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
41        } else {
42            toml::from_str(DEFAULT_MANIFEST_TOML)
43                .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))?
44        };
45        Ok(Self { manifest })
46    }
47}
48
/// Rewrite `$1`, `$2`, ... positional params to `?` for SQLite.
///
/// The SQL text is scanned once. A `$` immediately followed by a digit in
/// `1..=9` starts a placeholder; the whole run of ASCII digits after it is
/// consumed and a single `?` is emitted in its place. This handles
/// placeholders of any width (`$100` becomes `?`, not `?0`). A `$` that is
/// not followed by a non-zero digit (for example `$0` or a bare `$`) is
/// copied through unchanged.
///
/// NOTE(review): the scan is purely lexical and does not skip string
/// literals or comments, and every placeholder collapses to an anonymous `?`,
/// so a query that reuses or reorders `$N` relies on the caller supplying
/// arguments in textual order — confirm against the analyzer's parameter list.
fn rewrite_params_to_qmark(sql: &str) -> String {
    // The output is never longer than the input.
    let mut result = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();

    while let Some(ch) = chars.next() {
        let starts_placeholder =
            ch == '$' && matches!(chars.peek().copied(), Some('1'..='9'));
        if !starts_placeholder {
            result.push(ch);
            continue;
        }

        // Swallow the complete placeholder number, however many digits it has.
        while matches!(chars.peek().copied(), Some(d) if d.is_ascii_digit()) {
            chars.next();
        }
        result.push('?');
    }

    result
}
58
59impl CodegenBackend for PythonAiosqliteBackend {
60    fn name(&self) -> &str {
61        "python-aiosqlite"
62    }
63
64    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
65        &self.manifest
66    }
67
68    fn supported_engines(&self) -> &[&str] {
69        &["sqlite"]
70    }
71
72    fn file_header(&self) -> String {
73        "\"\"\"Auto-generated by scythe. Do not edit.\"\"\"\n\
74         \n\
75         from dataclasses import dataclass\n\
76         from enum import Enum  # noqa: F401\n\
77         \n\
78         import aiosqlite  # noqa: F401\n\
79         \n"
80        .to_string()
81    }
82
83    fn generate_row_struct(
84        &self,
85        query_name: &str,
86        columns: &[ResolvedColumn],
87    ) -> Result<String, ScytheError> {
88        let struct_name = row_struct_name(query_name, &self.manifest.naming);
89        let mut out = String::new();
90        let _ = writeln!(out, "@dataclass");
91        let _ = writeln!(out, "class {}:", struct_name);
92        let _ = writeln!(out, "    \"\"\"Row type for {} query.\"\"\"", query_name);
93        if columns.is_empty() {
94            let _ = writeln!(out, "    pass");
95        } else {
96            let _ = writeln!(out);
97            for col in columns {
98                let _ = writeln!(out, "    {}: {}", col.field_name, col.full_type);
99            }
100        }
101        Ok(out)
102    }
103
104    fn generate_model_struct(
105        &self,
106        table_name: &str,
107        columns: &[ResolvedColumn],
108    ) -> Result<String, ScytheError> {
109        let singular = singularize(table_name);
110        let name = to_pascal_case(&singular);
111        self.generate_row_struct(&name, columns)
112    }
113
114    fn generate_query_fn(
115        &self,
116        analyzed: &AnalyzedQuery,
117        struct_name: &str,
118        columns: &[ResolvedColumn],
119        params: &[ResolvedParam],
120    ) -> Result<String, ScytheError> {
121        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
122        let mut out = String::new();
123
124        let param_list = params
125            .iter()
126            .map(|p| format!("{}: {}", p.field_name, p.full_type))
127            .collect::<Vec<_>>()
128            .join(", ");
129        let kw_sep = if param_list.is_empty() { "" } else { ", *, " };
130
131        let sql = rewrite_params_to_qmark(&super::clean_sql(&analyzed.sql));
132
133        let args_list = if params.is_empty() {
134            String::new()
135        } else {
136            let args: Vec<String> = params.iter().map(|p| p.field_name.clone()).collect();
137            if args.len() == 1 {
138                format!("({},)", args[0])
139            } else {
140                format!("({})", args.join(", "))
141            }
142        };
143
144        match &analyzed.command {
145            QueryCommand::One => {
146                let _ = writeln!(
147                    out,
148                    "async def {}(conn: aiosqlite.Connection{}{}) -> {} | None:",
149                    func_name, kw_sep, param_list, struct_name
150                );
151                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
152                if params.is_empty() {
153                    let _ = writeln!(
154                        out,
155                        "    async with conn.execute(\"\"\"{}\"\"\") as cursor:",
156                        sql
157                    );
158                } else {
159                    let _ = writeln!(
160                        out,
161                        "    async with conn.execute(\"\"\"{}\"\"\", {}) as cursor:",
162                        sql, args_list
163                    );
164                }
165                let _ = writeln!(out, "        row = await cursor.fetchone()");
166                let _ = writeln!(out, "    if row is None:");
167                let _ = writeln!(out, "        return None");
168                let field_assignments: Vec<String> = columns
169                    .iter()
170                    .enumerate()
171                    .map(|(i, col)| format!("{}=row[{}]", col.field_name, i))
172                    .collect();
173                let oneliner = format!(
174                    "    return {}({})",
175                    struct_name,
176                    field_assignments.join(", ")
177                );
178                if oneliner.len() <= 88 {
179                    let _ = writeln!(out, "{}", oneliner);
180                } else {
181                    let _ = writeln!(out, "    return {}(", struct_name);
182                    for fa in &field_assignments {
183                        let _ = writeln!(out, "        {},", fa);
184                    }
185                    let _ = writeln!(out, "    )");
186                }
187            }
188            QueryCommand::Many | QueryCommand::Batch => {
189                let _ = writeln!(
190                    out,
191                    "async def {}(conn: aiosqlite.Connection{}{}) -> list[{}]:",
192                    func_name, kw_sep, param_list, struct_name
193                );
194                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
195                if params.is_empty() {
196                    let _ = writeln!(
197                        out,
198                        "    async with conn.execute(\"\"\"{}\"\"\") as cursor:",
199                        sql
200                    );
201                } else {
202                    let _ = writeln!(
203                        out,
204                        "    async with conn.execute(\"\"\"{}\"\"\", {}) as cursor:",
205                        sql, args_list
206                    );
207                }
208                let _ = writeln!(out, "        rows = await cursor.fetchall()");
209                let field_assignments: Vec<String> = columns
210                    .iter()
211                    .enumerate()
212                    .map(|(i, col)| format!("{}=r[{}]", col.field_name, i))
213                    .collect();
214                let oneliner = format!(
215                    "    return [{}({}) for r in rows]",
216                    struct_name,
217                    field_assignments.join(", ")
218                );
219                if oneliner.len() <= 88 {
220                    let _ = writeln!(out, "{}", oneliner);
221                } else {
222                    let _ = writeln!(out, "    return [");
223                    let _ = writeln!(out, "        {}(", struct_name);
224                    for fa in &field_assignments {
225                        let _ = writeln!(out, "            {},", fa);
226                    }
227                    let _ = writeln!(out, "        )");
228                    let _ = writeln!(out, "        for r in rows");
229                    let _ = writeln!(out, "    ]");
230                }
231            }
232            QueryCommand::Exec | QueryCommand::ExecResult | QueryCommand::ExecRows => {
233                let _ = writeln!(
234                    out,
235                    "async def {}(conn: aiosqlite.Connection{}{}) -> None:",
236                    func_name, kw_sep, param_list
237                );
238                let _ = writeln!(out, "    \"\"\"Execute {} query.\"\"\"", analyzed.name);
239                if params.is_empty() {
240                    let _ = writeln!(out, "    await conn.execute(\"\"\"{}\"\"\") ", sql);
241                } else {
242                    let _ = writeln!(
243                        out,
244                        "    await conn.execute(\"\"\"{}\"\"\", {})",
245                        sql, args_list
246                    );
247                }
248            }
249        }
250
251        Ok(out)
252    }
253
254    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
255        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
256        let mut out = String::new();
257        let _ = writeln!(out, "class {}(str, Enum):", type_name);
258        let _ = writeln!(
259            out,
260            "    \"\"\"Database enum type {}.\"\"\"",
261            enum_info.sql_name
262        );
263        if enum_info.values.is_empty() {
264            let _ = writeln!(out, "    pass");
265        } else {
266            let _ = writeln!(out);
267            for value in &enum_info.values {
268                let variant = enum_variant_name(value, &self.manifest.naming);
269                let _ = writeln!(out, "    {} = \"{}\"", variant, value);
270            }
271        }
272        Ok(out)
273    }
274
275    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
276        let name = to_pascal_case(&composite.sql_name);
277        let mut out = String::new();
278        let _ = writeln!(out, "@dataclass");
279        let _ = writeln!(out, "class {}:", name);
280        let _ = writeln!(
281            out,
282            "    \"\"\"Composite type {}.\"\"\"",
283            composite.sql_name
284        );
285        if composite.fields.is_empty() {
286            let _ = writeln!(out, "    pass");
287        } else {
288            let _ = writeln!(out);
289            for field in &composite.fields {
290                let py_type = resolve_type(&field.neutral_type, &self.manifest, false)
291                    .map(|t| t.into_owned())
292                    .map_err(|e| {
293                        ScytheError::new(
294                            ErrorCode::InternalError,
295                            format!("composite field type error: {}", e),
296                        )
297                    })?;
298                let _ = writeln!(out, "    {}: {}", to_snake_case(&field.name), py_type);
299            }
300        }
301        Ok(out)
302    }
303}