Skip to main content

sqlcx_core/generator/python/
psycopg.rs

1// psycopg (psycopg3) driver. Emits queries.py only. Named %(name)s
2// placeholders + dict params arg. Sync functions against psycopg.Connection.
3
4use crate::error::Result;
5use crate::generator::python::common::{
6    PyBodyCtx, PyDriverShape, PyTypeMap, generate_driver_files,
7};
8use crate::generator::{DriverGenerator, GeneratedFile};
9use crate::ir::{ParamDef, QueryCommand, QueryDef, SqlcxIR};
10
/// psycopg (psycopg3) driver generator for the Python backend.
///
/// A stateless unit struct: all behaviour comes from its `PyTypeMap`,
/// `PyDriverShape` and `DriverGenerator` impls in this module.
pub struct PsycopgGenerator;

// Empty impl: every item of `PyTypeMap` uses the trait's provided (default)
// implementation. NOTE(review): presumably this is the SQL → Python type
// mapping; confirm the defaults suit psycopg if driver-specific types are added.
impl PyTypeMap for PsycopgGenerator {}
14
15fn rewrite_named(sql: &str, params: &[ParamDef]) -> String {
16    let mut result = String::with_capacity(sql.len());
17    let mut chars = sql.chars().peekable();
18    while let Some(c) = chars.next() {
19        if c == '$' && chars.peek().is_some_and(|ch| ch.is_ascii_digit()) {
20            let mut num_str = String::new();
21            while chars.peek().is_some_and(|ch| ch.is_ascii_digit()) {
22                num_str.push(chars.next().unwrap());
23            }
24            let idx: u32 = num_str.parse().unwrap_or(0);
25            let name = params
26                .iter()
27                .find(|p| p.index == idx)
28                .map(|p| p.name.as_str())
29                .unwrap_or("unknown");
30            result.push_str(&format!("%({name})s"));
31        } else {
32            result.push(c);
33        }
34    }
35    result
36}
37
38impl PyDriverShape for PsycopgGenerator {
39    fn driver_import(&self) -> &'static str {
40        "from psycopg import Connection"
41    }
42    fn connection_type(&self) -> &'static str {
43        "Connection"
44    }
45    fn is_async(&self) -> bool {
46        false
47    }
48    fn rewrite_sql(&self, query: &QueryDef) -> String {
49        rewrite_named(&query.sql, &query.params)
50    }
51    fn build_params_arg(&self, query: &QueryDef) -> String {
52        if query.params.is_empty() {
53            return "{}".to_string();
54        }
55        let entries: Vec<String> = query
56            .params
57            .iter()
58            .map(|p| format!("\"{}\": params.{}", p.name, p.name))
59            .collect();
60        format!("{{{}}}", entries.join(", "))
61    }
62    fn render_body(&self, ctx: &PyBodyCtx<'_>) -> (String, String) {
63        let (sc, rt, pa) = (ctx.sql_const, ctx.row_type, ctx.params_arg);
64        match ctx.command {
65            QueryCommand::One => (
66                format!("{rt} | None"),
67                format!(
68                    "    cur = conn.execute({sc}, {pa})\n    row = cur.fetchone()\n    if row is None:\n        return None\n    return {rt}(*row)"
69                ),
70            ),
71            QueryCommand::Many => (
72                format!("list[{rt}]"),
73                format!(
74                    "    cur = conn.execute({sc}, {pa})\n    return [{rt}(*row) for row in cur.fetchall()]"
75                ),
76            ),
77            QueryCommand::Exec => ("None".to_string(), format!("    conn.execute({sc}, {pa})")),
78            QueryCommand::ExecResult => (
79                "int".to_string(),
80                format!("    cur = conn.execute({sc}, {pa})\n    return cur.rowcount"),
81            ),
82        }
83    }
84}
85
86impl DriverGenerator for PsycopgGenerator {
87    fn generate(&self, ir: &SqlcxIR) -> Result<Vec<GeneratedFile>> {
88        generate_driver_files(self, ir)
89    }
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::python::common::generate_queries_file;
    use crate::parser::DatabaseParser;
    use crate::parser::postgres::PostgresParser;

    /// Parses the shared Postgres fixtures (the schema plus
    /// `queries/users.sql`) into an IR for the generator tests.
    fn parse_fixture_ir() -> SqlcxIR {
        const SCHEMA_SQL: &str = include_str!("../../../../../tests/fixtures/schema.sql");
        const USERS_SQL: &str = include_str!("../../../../../tests/fixtures/queries/users.sql");

        let parser = PostgresParser::new();
        let (tables, enums) = parser.parse_schema(SCHEMA_SQL).unwrap();
        let queries = parser
            .parse_queries(USERS_SQL, &tables, &enums, "queries/users.sql")
            .unwrap();
        SqlcxIR { tables, queries, enums }
    }

    #[test]
    fn generates_psycopg_query_functions() {
        let fixture = parse_fixture_ir();
        // `content` is kept as the variable name on purpose: the assert
        // messages and the insta snapshot metadata record the expression text.
        let content = generate_queries_file(&PsycopgGenerator, &fixture.queries);

        // Spot-check the driver import, a known query function and a
        // rewritten named placeholder, then compare against the full snapshot.
        assert!(content.contains("from psycopg import Connection"));
        assert!(content.contains("def get_user"));
        assert!(content.contains("%(id)s"));
        insta::assert_snapshot!("psycopg_queries", content);
    }
}
119        assert!(content.contains("def get_user"));
120        assert!(content.contains("%(id)s"));
121        insta::assert_snapshot!("psycopg_queries", content);
122    }
123}