Skip to main content

scythe_codegen/backends/
tokio_postgres.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
/// Manifest TOML for rust-tokio-postgres, embedded into the binary at compile time.
///
/// `load_tokio_postgres_manifest` parses this text when
/// `backends/rust-tokio-postgres/manifest.toml` does not exist relative to the
/// current working directory.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/rust-tokio-postgres.toml");
19
20/// TokioPostgresBackend generates Rust code targeting the tokio-postgres crate.
21pub struct TokioPostgresBackend {
22    manifest: BackendManifest,
23}
24
25impl TokioPostgresBackend {
26    pub fn new(engine: &str) -> Result<Self, ScytheError> {
27        match engine {
28            "postgresql" | "postgres" | "pg" => {}
29            _ => {
30                return Err(ScytheError::new(
31                    ErrorCode::InternalError,
32                    format!(
33                        "rust-tokio-postgres only supports PostgreSQL, got engine '{}'",
34                        engine
35                    ),
36                ));
37            }
38        }
39        let manifest = load_tokio_postgres_manifest()?;
40        Ok(Self { manifest })
41    }
42}
43
44fn load_tokio_postgres_manifest() -> Result<BackendManifest, ScytheError> {
45    let manifest_path = Path::new("backends/rust-tokio-postgres/manifest.toml");
46    if manifest_path.exists() {
47        load_manifest(manifest_path).map_err(|e| {
48            ScytheError::new(
49                ErrorCode::InternalError,
50                format!("failed to load manifest: {e}"),
51            )
52        })
53    } else {
54        toml::from_str(DEFAULT_MANIFEST_TOML).map_err(|e| {
55            ScytheError::new(
56                ErrorCode::InternalError,
57                format!("failed to parse embedded manifest: {e}"),
58            )
59        })
60    }
61}
62
63impl CodegenBackend for TokioPostgresBackend {
64    fn name(&self) -> &str {
65        "rust-tokio-postgres"
66    }
67
68    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
69        &self.manifest
70    }
71
72    fn file_header(&self) -> String {
73        "// Auto-generated by scythe. Do not edit.\n#![allow(dead_code, unused_imports, clippy::all)]"
74            .to_string()
75    }
76
77    fn generate_row_struct(
78        &self,
79        query_name: &str,
80        columns: &[ResolvedColumn],
81    ) -> Result<String, ScytheError> {
82        let struct_name = row_struct_name(query_name, &self.manifest.naming);
83        generate_struct_with_from_row(&struct_name, columns)
84    }
85
86    fn generate_model_struct(
87        &self,
88        table_name: &str,
89        columns: &[ResolvedColumn],
90    ) -> Result<String, ScytheError> {
91        let singular = singularize(table_name);
92        let struct_name = to_pascal_case(&singular).into_owned();
93        generate_struct_with_from_row(&struct_name, columns)
94    }
95
96    fn generate_query_fn(
97        &self,
98        analyzed: &AnalyzedQuery,
99        struct_name: &str,
100        _columns: &[ResolvedColumn],
101        params: &[ResolvedParam],
102    ) -> Result<String, ScytheError> {
103        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
104        let mut out = String::new();
105
106        // Deprecated annotation
107        if let Some(ref msg) = analyzed.deprecated {
108            let _ = writeln!(out, "#[deprecated(note = \"{}\")]", msg);
109        }
110
111        // Build parameter list
112        let mut param_parts: Vec<String> = vec!["client: &tokio_postgres::Client".to_string()];
113        for param in params {
114            param_parts.push(format!("{}: {}", param.field_name, param.borrowed_type));
115        }
116
117        // Clean SQL
118        let sql = super::clean_sql_with_optional(
119            &analyzed.sql,
120            &analyzed.optional_params,
121            &analyzed.params,
122        );
123
124        // Handle :batch separately
125        if matches!(analyzed.command, QueryCommand::Batch) {
126            let batch_fn_name = format!("{}_batch", func_name);
127
128            if params.len() > 1 {
129                let params_struct_name = format!("{}BatchParams", struct_name);
130                let _ = writeln!(out, "#[derive(Debug, Clone)]");
131                let _ = writeln!(out, "pub struct {} {{", params_struct_name);
132                for param in params {
133                    let _ = writeln!(out, "    pub {}: {},", param.field_name, param.full_type);
134                }
135                let _ = writeln!(out, "}}");
136                let _ = writeln!(out);
137                let _ = writeln!(
138                    out,
139                    "pub async fn {}(client: &tokio_postgres::Client, items: &[{}]) -> Result<(), tokio_postgres::Error> {{",
140                    batch_fn_name, params_struct_name
141                );
142                let _ = writeln!(out, "    let stmt = client.prepare(r#\"{}\"#).await?;", sql);
143                let _ = writeln!(out, "    let tx = client.transaction().await?;");
144                let _ = writeln!(out, "    for item in items {{");
145                let refs: Vec<String> = params
146                    .iter()
147                    .map(|p| {
148                        if p.neutral_type.starts_with("enum::") {
149                            format!("&item.{}.to_string()", p.field_name)
150                        } else {
151                            format!("&item.{}", p.field_name)
152                        }
153                    })
154                    .collect();
155                let _ = writeln!(
156                    out,
157                    "        tx.execute(&stmt, &[{}]).await?;",
158                    refs.join(", ")
159                );
160                let _ = writeln!(out, "    }}");
161                let _ = writeln!(out, "    tx.commit().await?;");
162                let _ = writeln!(out, "    Ok(())");
163            } else if params.len() == 1 {
164                let param = &params[0];
165                let _ = writeln!(
166                    out,
167                    "pub async fn {}(client: &tokio_postgres::Client, items: &[{}]) -> Result<(), tokio_postgres::Error> {{",
168                    batch_fn_name, param.full_type
169                );
170                let _ = writeln!(out, "    let stmt = client.prepare(r#\"{}\"#).await?;", sql);
171                let _ = writeln!(out, "    let tx = client.transaction().await?;");
172                let _ = writeln!(out, "    for item in items {{");
173                let _ = writeln!(out, "        tx.execute(&stmt, &[item]).await?;");
174                let _ = writeln!(out, "    }}");
175                let _ = writeln!(out, "    tx.commit().await?;");
176                let _ = writeln!(out, "    Ok(())");
177            } else {
178                let _ = writeln!(
179                    out,
180                    "pub async fn {}(client: &tokio_postgres::Client, count: usize) -> Result<(), tokio_postgres::Error> {{",
181                    batch_fn_name
182                );
183                let _ = writeln!(out, "    let stmt = client.prepare(r#\"{}\"#).await?;", sql);
184                let _ = writeln!(out, "    let tx = client.transaction().await?;");
185                let _ = writeln!(out, "    for _ in 0..count {{");
186                let _ = writeln!(out, "        tx.execute(&stmt, &[]).await?;");
187                let _ = writeln!(out, "    }}");
188                let _ = writeln!(out, "    tx.commit().await?;");
189                let _ = writeln!(out, "    Ok(())");
190            }
191
192            let _ = write!(out, "}}");
193            return Ok(out);
194        }
195
196        // Return type for non-batch commands
197        let return_type = match &analyzed.command {
198            QueryCommand::One => struct_name.to_string(),
199            QueryCommand::Many => format!("Vec<{}>", struct_name),
200            QueryCommand::Exec => "()".to_string(),
201            QueryCommand::ExecResult => "u64".to_string(),
202            QueryCommand::ExecRows => "u64".to_string(),
203            QueryCommand::Batch => unreachable!(),
204        };
205
206        // Function signature
207        let _ = writeln!(
208            out,
209            "pub async fn {}({}) -> Result<{}, tokio_postgres::Error> {{",
210            func_name,
211            param_parts.join(", "),
212            return_type
213        );
214
215        // Build param references for the query call
216        let param_refs: String = if params.is_empty() {
217            "&[]".to_string()
218        } else {
219            let refs: Vec<String> = params
220                .iter()
221                .map(|p| {
222                    if p.neutral_type.starts_with("enum::") {
223                        format!("&{}.to_string()", p.field_name)
224                    } else {
225                        format!("&{}", p.field_name)
226                    }
227                })
228                .collect();
229            format!("&[{}]", refs.join(", "))
230        };
231
232        match &analyzed.command {
233            QueryCommand::One => {
234                let _ = writeln!(
235                    out,
236                    "    let row = client.query_one(r#\"{}\"#, {}).await?;",
237                    sql, param_refs
238                );
239                let _ = writeln!(out, "    Ok({}::from_row(&row))", struct_name);
240            }
241            QueryCommand::Many => {
242                let _ = writeln!(
243                    out,
244                    "    let rows = client.query(r#\"{}\"#, {}).await?;",
245                    sql, param_refs
246                );
247                let _ = writeln!(
248                    out,
249                    "    Ok(rows.iter().map({}::from_row).collect())",
250                    struct_name
251                );
252            }
253            QueryCommand::Exec => {
254                let _ = writeln!(
255                    out,
256                    "    client.execute(r#\"{}\"#, {}).await?;",
257                    sql, param_refs
258                );
259                let _ = writeln!(out, "    Ok(())");
260            }
261            QueryCommand::ExecResult | QueryCommand::ExecRows => {
262                let _ = writeln!(
263                    out,
264                    "    let rows_affected = client.execute(r#\"{}\"#, {}).await?;",
265                    sql, param_refs
266                );
267                let _ = writeln!(out, "    Ok(rows_affected)");
268            }
269            QueryCommand::Batch => unreachable!(),
270        }
271
272        let _ = write!(out, "}}");
273        Ok(out)
274    }
275
276    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
277        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
278        let mut out = String::with_capacity(512);
279
280        let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq)]");
281        let _ = writeln!(out, "pub enum {} {{", type_name);
282        for value in &enum_info.values {
283            let variant = enum_variant_name(value, &self.manifest.naming);
284            let _ = writeln!(out, "    {},", variant);
285        }
286        let _ = writeln!(out, "}}");
287        let _ = writeln!(out);
288
289        // impl Display for serialization
290        let _ = writeln!(out, "impl std::fmt::Display for {} {{", type_name);
291        let _ = writeln!(
292            out,
293            "    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{"
294        );
295        let _ = writeln!(out, "        match self {{");
296        for value in &enum_info.values {
297            let variant = enum_variant_name(value, &self.manifest.naming);
298            let _ = writeln!(
299                out,
300                "            {}::{} => write!(f, \"{}\"),",
301                type_name, variant, value
302            );
303        }
304        let _ = writeln!(out, "        }}");
305        let _ = writeln!(out, "    }}");
306        let _ = writeln!(out, "}}");
307        let _ = writeln!(out);
308
309        // impl FromStr for deserialization
310        let _ = writeln!(out, "impl std::str::FromStr for {} {{", type_name);
311        let _ = writeln!(out, "    type Err = String;");
312        let _ = writeln!(
313            out,
314            "    fn from_str(s: &str) -> Result<Self, Self::Err> {{"
315        );
316        let _ = writeln!(out, "        match s {{");
317        for value in &enum_info.values {
318            let variant = enum_variant_name(value, &self.manifest.naming);
319            let _ = writeln!(
320                out,
321                "            \"{}\" => Ok({}::{}),",
322                value, type_name, variant
323            );
324        }
325        let _ = writeln!(
326            out,
327            "            _ => Err(format!(\"unknown variant: {{}}\", s)),"
328        );
329        let _ = writeln!(out, "        }}");
330        let _ = writeln!(out, "    }}");
331        let _ = write!(out, "}}");
332
333        Ok(out)
334    }
335
336    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
337        let struct_name = to_pascal_case(&composite.sql_name).into_owned();
338        let mut out = String::new();
339
340        let _ = writeln!(out, "#[derive(Debug, Clone)]");
341        let _ = writeln!(out, "pub struct {} {{", struct_name);
342        for field in &composite.fields {
343            let rust_type = resolve_type(&field.neutral_type, &self.manifest, false)
344                .map(|t| t.into_owned())
345                .map_err(|e| {
346                    ScytheError::new(
347                        ErrorCode::InternalError,
348                        format!("composite field type error: {}", e),
349                    )
350                })?;
351            let _ = writeln!(
352                out,
353                "    pub {}: {},",
354                to_snake_case(&field.name),
355                rust_type
356            );
357        }
358        let _ = write!(out, "}}");
359        Ok(out)
360    }
361}
362
363// ---------------------------------------------------------------------------
364// Internal helpers
365// ---------------------------------------------------------------------------
366
367/// Generate a struct with a `from_row` method for tokio-postgres.
368fn generate_struct_with_from_row(
369    struct_name: &str,
370    columns: &[ResolvedColumn],
371) -> Result<String, ScytheError> {
372    let mut out = String::new();
373
374    let _ = writeln!(out, "#[derive(Debug, Clone)]");
375    let _ = writeln!(out, "pub struct {} {{", struct_name);
376    for col in columns {
377        let _ = writeln!(out, "    pub {}: {},", col.field_name, col.full_type);
378    }
379    let _ = writeln!(out, "}}");
380    let _ = writeln!(out);
381
382    let _ = writeln!(out, "impl {} {{", struct_name);
383    let _ = writeln!(
384        out,
385        "    pub fn from_row(row: &tokio_postgres::Row) -> Self {{"
386    );
387    let _ = writeln!(out, "        Self {{");
388    for col in columns {
389        if col.neutral_type.starts_with("enum::") {
390            // Enum columns need string conversion
391            if col.nullable {
392                let _ = writeln!(
393                    out,
394                    "            {field}: row.get::<_, Option<String>>(\"{col}\").map(|s| s.parse().unwrap_or_else(|_| panic!(\"unexpected enum value for column '{{}}': {{}}\", \"{col}\", s))),",
395                    field = col.field_name,
396                    col = col.name
397                );
398            } else {
399                let _ = writeln!(
400                    out,
401                    "            {field}: {{ let val = row.get::<_, String>(\"{col}\"); val.parse().unwrap_or_else(|_| panic!(\"unexpected enum value for column '{{}}': {{}}\", \"{col}\", val)) }},",
402                    field = col.field_name,
403                    col = col.name
404                );
405            }
406        } else {
407            let _ = writeln!(
408                out,
409                "            {}: row.get(\"{}\"),",
410                col.field_name, col.name
411            );
412        }
413    }
414    let _ = writeln!(out, "        }}");
415    let _ = writeln!(out, "    }}");
416    let _ = write!(out, "}}");
417
418    Ok(out)
419}