Skip to main content

sqlcx_core/generator/python/
sqlite3_driver.rs

1// sqlite3 (Python stdlib) driver. Emits queries.py only. ? positional
2// placeholders, tuple params arg, sync functions against sqlite3.Connection.
3// SQLite lacks native Boolean/Date/Json — the type map surfaces them as
4// int/str/str and binary as bytes.
5
6use crate::error::Result;
7use crate::generator::python::common::{
8    PyBodyCtx, PyDriverShape, PyTypeMap, generate_driver_files,
9};
10use crate::generator::{DriverGenerator, GeneratedFile};
11use crate::ir::{QueryCommand, QueryDef, SqlcxIR};
12
/// Driver generator for Python's stdlib `sqlite3` module.
///
/// Stateless marker type: all behavior lives in its `PyTypeMap`,
/// `PyDriverShape` and `DriverGenerator` impls. It derives the cheap standard
/// traits so it can be debug-printed, defaulted and copied freely.
#[derive(Debug, Clone, Copy, Default)]
pub struct Sqlite3Generator;
14
15impl PyTypeMap for Sqlite3Generator {
16    fn boolean_ty(&self) -> &'static str {
17        "int"
18    }
19    fn date_ty(&self) -> &'static str {
20        "str"
21    }
22    fn json_ty(&self) -> &'static str {
23        "str"
24    }
25}
26
/// Normalize placeholders to sqlite3's `?` positional form and return the
/// param indices in occurrence order. Accepts Postgres-style `$N` (stored
/// by the PG parser) and native `?` (stored by the SQLite/MySQL parsers) —
/// for `?` the occurrence index is the 1-based count, matching the parser's
/// `extract_param_indices`.
///
/// Quoted text (`'...'` string literals, `"..."` and `` `...` `` identifiers)
/// and comments (`-- ...` to end of line, `/* ... */`) are copied through
/// verbatim. A `?` or `$N` inside them is not a bind parameter to SQLite.
/// Rewriting it would change a literal's value, and counting it would make
/// the generated params tuple longer than the statement's binding count.
/// A doubled quote (`'it''s'`) works naturally as close-then-reopen;
/// backslash escapes (MySQL-style) are not recognized.
///
/// A `$` not followed by a digit is copied unchanged. A `$N` whose number
/// does not fit in `u32` yields index 0.
fn rewrite_qmark(sql: &str) -> (String, Vec<u32>) {
    let mut result = String::with_capacity(sql.len());
    let mut indices = Vec::new();
    let mut chars = sql.chars().peekable();
    while let Some(c) = chars.next() {
        match c {
            // Quoted literal or identifier: copy through the matching close quote.
            '\'' | '"' | '`' => {
                result.push(c);
                for ch in chars.by_ref() {
                    result.push(ch);
                    if ch == c {
                        break;
                    }
                }
            }
            // Line comment: copy through the end of the line.
            '-' if chars.peek() == Some(&'-') => {
                result.push(c);
                for ch in chars.by_ref() {
                    result.push(ch);
                    if ch == '\n' {
                        break;
                    }
                }
            }
            // Block comment: copy through the closing `*/`. The opening `*` is
            // consumed first so `/*/` is not mistaken for a closed comment.
            '/' if chars.peek() == Some(&'*') => {
                chars.next();
                result.push_str("/*");
                let mut prev = '\0';
                for ch in chars.by_ref() {
                    result.push(ch);
                    if prev == '*' && ch == '/' {
                        break;
                    }
                    prev = ch;
                }
            }
            // Postgres-style `$N`: the stored index is N itself.
            '$' if chars.peek().is_some_and(|ch| ch.is_ascii_digit()) => {
                let mut num_str = String::new();
                while let Some(d) = chars.next_if(|ch| ch.is_ascii_digit()) {
                    num_str.push(d);
                }
                result.push('?');
                indices.push(num_str.parse::<u32>().unwrap_or(0));
            }
            // Native `?`: the stored index is the 1-based occurrence count.
            '?' => {
                result.push('?');
                indices.push(indices.len() as u32 + 1);
            }
            _ => result.push(c),
        }
    }
    (result, indices)
}
53
54impl PyDriverShape for Sqlite3Generator {
55    fn driver_import(&self) -> &'static str {
56        "from sqlite3 import Connection"
57    }
58    fn connection_type(&self) -> &'static str {
59        "Connection"
60    }
61    fn is_async(&self) -> bool {
62        false
63    }
64    fn rewrite_sql(&self, query: &QueryDef) -> String {
65        rewrite_qmark(&query.sql).0
66    }
67    fn build_params_arg(&self, query: &QueryDef) -> String {
68        if query.params.is_empty() {
69            return "()".to_string();
70        }
71        let indices = rewrite_qmark(&query.sql).1;
72        let args: Vec<String> = indices
73            .iter()
74            .map(|idx| {
75                query
76                    .params
77                    .iter()
78                    .find(|p| p.index == *idx)
79                    .map(|p| format!("params.{}", p.name))
80                    .unwrap_or_else(|| "None".to_string())
81            })
82            .collect();
83        let trailing = if args.len() == 1 { "," } else { "" };
84        format!("({}{trailing})", args.join(", "))
85    }
86    fn render_body(&self, ctx: &PyBodyCtx<'_>) -> (String, String) {
87        let (sc, rt, pa) = (ctx.sql_const, ctx.row_type, ctx.params_arg);
88        match ctx.command {
89            QueryCommand::One => (
90                format!("{rt} | None"),
91                format!(
92                    "    cur = conn.execute({sc}, {pa})\n    row = cur.fetchone()\n    if row is None:\n        return None\n    return {rt}(*row)"
93                ),
94            ),
95            QueryCommand::Many => (
96                format!("list[{rt}]"),
97                format!(
98                    "    cur = conn.execute({sc}, {pa})\n    return [{rt}(*row) for row in cur.fetchall()]"
99                ),
100            ),
101            QueryCommand::Exec => ("None".to_string(), format!("    conn.execute({sc}, {pa})")),
102            QueryCommand::ExecResult => (
103                "int".to_string(),
104                format!("    cur = conn.execute({sc}, {pa})\n    return cur.rowcount"),
105            ),
106        }
107    }
108}
109
110impl DriverGenerator for Sqlite3Generator {
111    fn generate(&self, ir: &SqlcxIR) -> Result<Vec<GeneratedFile>> {
112        generate_driver_files(self, ir)
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::python::common::generate_queries_file;
    use crate::parser::DatabaseParser;
    use crate::parser::sqlite::SqliteParser;

    /// Parse the shared SQLite fixture schema and `users.sql` queries into an IR.
    fn parse_fixture_ir() -> SqlcxIR {
        let schema_sql = include_str!("../../../../../tests/fixtures/sqlite_schema.sql");
        let queries_sql = include_str!("../../../../../tests/fixtures/sqlite_queries/users.sql");
        let parser = SqliteParser::new();
        let (tables, enums) = parser.parse_schema(schema_sql).unwrap();
        let queries = parser
            .parse_queries(queries_sql, &tables, &enums, "sqlite_queries/users.sql")
            .unwrap();
        SqlcxIR {
            tables,
            queries,
            enums,
        }
    }

    #[test]
    fn generates_sqlite3_query_functions() {
        let ir = parse_fixture_ir();
        let content = generate_queries_file(&Sqlite3Generator, &ir.queries);
        assert!(content.contains("from sqlite3 import Connection"));
        assert!(content.contains("def get_user"));
        assert!(!content.contains("$1"));
        insta::assert_snapshot!("sqlite3_queries", content);
    }

    #[test]
    fn native_qmark_input_tracks_occurrence_indices() {
        // SQLite parser stores SQL with native `?` placeholders. rewrite_qmark
        // must still emit 1-based occurrence indices so build_params_arg
        // doesn't produce an empty tuple.
        let (sql, idx) = rewrite_qmark("WHERE a = ? AND b = ?");
        assert_eq!(sql, "WHERE a = ? AND b = ?");
        assert_eq!(idx, vec![1, 2]);
    }

    #[test]
    fn pg_dollar_input_keeps_numbers_per_occurrence() {
        // PG-parsed SQL stores `$N`. Each occurrence becomes `?` and carries its
        // own N, so out-of-order and repeated references each get bound.
        let (sql, idx) = rewrite_qmark("WHERE a = $2 AND b = $1 OR c = $2");
        assert_eq!(sql, "WHERE a = ? AND b = ? OR c = ?");
        assert_eq!(idx, vec![2, 1, 2]);
    }

    #[test]
    fn bare_dollar_is_not_a_placeholder() {
        // A `$` without a following digit is ordinary text.
        let (sql, idx) = rewrite_qmark("SELECT $x FROM t");
        assert_eq!(sql, "SELECT $x FROM t");
        assert!(idx.is_empty());
    }
}
153        let (sql, idx) = rewrite_qmark("WHERE a = ? AND b = ?");
154        assert_eq!(sql, "WHERE a = ? AND b = ?");
155        assert_eq!(idx, vec![1, 2]);
156    }
157}