Skip to main content

scythe_codegen/backends/
tokio_postgres.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
/// Default embedded manifest TOML for rust-tokio-postgres, used as fallback.
///
/// The file contents are embedded at compile time via `include_str!`.
/// `load_tokio_postgres_manifest` parses this text when no on-disk manifest exists.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/rust-tokio-postgres.toml");
19
/// TokioPostgresBackend generates Rust code targeting the tokio-postgres crate.
///
/// Construct it with [`TokioPostgresBackend::new`], which validates the engine name
/// and loads the backend manifest.
pub struct TokioPostgresBackend {
    /// Backend manifest; its `naming` section and type mappings drive the generated names
    /// and field types.
    manifest: BackendManifest,
}
24
25impl TokioPostgresBackend {
26    pub fn new(engine: &str) -> Result<Self, ScytheError> {
27        match engine {
28            "postgresql" | "postgres" | "pg" => {}
29            _ => {
30                return Err(ScytheError::new(
31                    ErrorCode::InternalError,
32                    format!(
33                        "rust-tokio-postgres only supports PostgreSQL, got engine '{}'",
34                        engine
35                    ),
36                ));
37            }
38        }
39        let manifest = load_tokio_postgres_manifest()?;
40        Ok(Self { manifest })
41    }
42}
43
44fn load_tokio_postgres_manifest() -> Result<BackendManifest, ScytheError> {
45    let manifest_path = Path::new("backends/rust-tokio-postgres/manifest.toml");
46    if manifest_path.exists() {
47        load_manifest(manifest_path).map_err(|e| {
48            ScytheError::new(
49                ErrorCode::InternalError,
50                format!("failed to load manifest: {e}"),
51            )
52        })
53    } else {
54        toml::from_str(DEFAULT_MANIFEST_TOML).map_err(|e| {
55            ScytheError::new(
56                ErrorCode::InternalError,
57                format!("failed to parse embedded manifest: {e}"),
58            )
59        })
60    }
61}
62
63impl CodegenBackend for TokioPostgresBackend {
64    fn name(&self) -> &str {
65        "rust-tokio-postgres"
66    }
67
68    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
69        &self.manifest
70    }
71
72    fn file_header(&self) -> String {
73        "// Auto-generated by scythe. Do not edit.\n#![allow(dead_code, unused_imports, clippy::all)]"
74            .to_string()
75    }
76
77    fn generate_row_struct(
78        &self,
79        query_name: &str,
80        columns: &[ResolvedColumn],
81    ) -> Result<String, ScytheError> {
82        let struct_name = row_struct_name(query_name, &self.manifest.naming);
83        generate_struct_with_from_row(&struct_name, columns)
84    }
85
86    fn generate_model_struct(
87        &self,
88        table_name: &str,
89        columns: &[ResolvedColumn],
90    ) -> Result<String, ScytheError> {
91        let singular = singularize(table_name);
92        let struct_name = to_pascal_case(&singular).into_owned();
93        generate_struct_with_from_row(&struct_name, columns)
94    }
95
96    fn generate_query_fn(
97        &self,
98        analyzed: &AnalyzedQuery,
99        struct_name: &str,
100        _columns: &[ResolvedColumn],
101        params: &[ResolvedParam],
102    ) -> Result<String, ScytheError> {
103        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
104        let mut out = String::new();
105
106        // Deprecated annotation
107        if let Some(ref msg) = analyzed.deprecated {
108            let _ = writeln!(out, "#[deprecated(note = \"{}\")]", msg);
109        }
110
111        // Build parameter list
112        let mut param_parts: Vec<String> = vec!["client: &tokio_postgres::Client".to_string()];
113        for param in params {
114            param_parts.push(format!("{}: {}", param.field_name, param.borrowed_type));
115        }
116
117        // Return type
118        let return_type = match &analyzed.command {
119            QueryCommand::One => struct_name.to_string(),
120            QueryCommand::Many => format!("Vec<{}>", struct_name),
121            QueryCommand::Exec => "()".to_string(),
122            QueryCommand::ExecResult => "u64".to_string(),
123            QueryCommand::ExecRows => "u64".to_string(),
124            QueryCommand::Batch => format!("Vec<{}>", struct_name),
125        };
126
127        // Function signature
128        let _ = writeln!(
129            out,
130            "pub async fn {}({}) -> Result<{}, tokio_postgres::Error> {{",
131            func_name,
132            param_parts.join(", "),
133            return_type
134        );
135
136        // Clean SQL
137        let sql = super::clean_sql(&analyzed.sql);
138
139        // Build param references for the query call
140        let param_refs: String = if params.is_empty() {
141            "&[]".to_string()
142        } else {
143            let refs: Vec<String> = params
144                .iter()
145                .map(|p| {
146                    if p.neutral_type.starts_with("enum::") {
147                        format!("&{}.to_string()", p.field_name)
148                    } else {
149                        format!("&{}", p.field_name)
150                    }
151                })
152                .collect();
153            format!("&[{}]", refs.join(", "))
154        };
155
156        match &analyzed.command {
157            QueryCommand::One => {
158                let _ = writeln!(
159                    out,
160                    "    let row = client.query_one(r#\"{}\"#, {}).await?;",
161                    sql, param_refs
162                );
163                let _ = writeln!(out, "    Ok({}::from_row(&row))", struct_name);
164            }
165            QueryCommand::Many | QueryCommand::Batch => {
166                let _ = writeln!(
167                    out,
168                    "    let rows = client.query(r#\"{}\"#, {}).await?;",
169                    sql, param_refs
170                );
171                let _ = writeln!(
172                    out,
173                    "    Ok(rows.iter().map({}::from_row).collect())",
174                    struct_name
175                );
176            }
177            QueryCommand::Exec => {
178                let _ = writeln!(
179                    out,
180                    "    client.execute(r#\"{}\"#, {}).await?;",
181                    sql, param_refs
182                );
183                let _ = writeln!(out, "    Ok(())");
184            }
185            QueryCommand::ExecResult | QueryCommand::ExecRows => {
186                let _ = writeln!(
187                    out,
188                    "    let rows_affected = client.execute(r#\"{}\"#, {}).await?;",
189                    sql, param_refs
190                );
191                let _ = writeln!(out, "    Ok(rows_affected)");
192            }
193        }
194
195        let _ = write!(out, "}}");
196        Ok(out)
197    }
198
199    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
200        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
201        let mut out = String::with_capacity(512);
202
203        let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq)]");
204        let _ = writeln!(out, "pub enum {} {{", type_name);
205        for value in &enum_info.values {
206            let variant = enum_variant_name(value, &self.manifest.naming);
207            let _ = writeln!(out, "    {},", variant);
208        }
209        let _ = writeln!(out, "}}");
210        let _ = writeln!(out);
211
212        // impl Display for serialization
213        let _ = writeln!(out, "impl std::fmt::Display for {} {{", type_name);
214        let _ = writeln!(
215            out,
216            "    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{"
217        );
218        let _ = writeln!(out, "        match self {{");
219        for value in &enum_info.values {
220            let variant = enum_variant_name(value, &self.manifest.naming);
221            let _ = writeln!(
222                out,
223                "            {}::{} => write!(f, \"{}\"),",
224                type_name, variant, value
225            );
226        }
227        let _ = writeln!(out, "        }}");
228        let _ = writeln!(out, "    }}");
229        let _ = writeln!(out, "}}");
230        let _ = writeln!(out);
231
232        // impl FromStr for deserialization
233        let _ = writeln!(out, "impl std::str::FromStr for {} {{", type_name);
234        let _ = writeln!(out, "    type Err = String;");
235        let _ = writeln!(
236            out,
237            "    fn from_str(s: &str) -> Result<Self, Self::Err> {{"
238        );
239        let _ = writeln!(out, "        match s {{");
240        for value in &enum_info.values {
241            let variant = enum_variant_name(value, &self.manifest.naming);
242            let _ = writeln!(
243                out,
244                "            \"{}\" => Ok({}::{}),",
245                value, type_name, variant
246            );
247        }
248        let _ = writeln!(
249            out,
250            "            _ => Err(format!(\"unknown variant: {{}}\", s)),"
251        );
252        let _ = writeln!(out, "        }}");
253        let _ = writeln!(out, "    }}");
254        let _ = write!(out, "}}");
255
256        Ok(out)
257    }
258
259    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
260        let struct_name = to_pascal_case(&composite.sql_name).into_owned();
261        let mut out = String::new();
262
263        let _ = writeln!(out, "#[derive(Debug, Clone)]");
264        let _ = writeln!(out, "pub struct {} {{", struct_name);
265        for field in &composite.fields {
266            let rust_type = resolve_type(&field.neutral_type, &self.manifest, false)
267                .map(|t| t.into_owned())
268                .map_err(|e| {
269                    ScytheError::new(
270                        ErrorCode::InternalError,
271                        format!("composite field type error: {}", e),
272                    )
273                })?;
274            let _ = writeln!(
275                out,
276                "    pub {}: {},",
277                to_snake_case(&field.name),
278                rust_type
279            );
280        }
281        let _ = write!(out, "}}");
282        Ok(out)
283    }
284}
285
286// ---------------------------------------------------------------------------
287// Internal helpers
288// ---------------------------------------------------------------------------
289
290/// Generate a struct with a `from_row` method for tokio-postgres.
291fn generate_struct_with_from_row(
292    struct_name: &str,
293    columns: &[ResolvedColumn],
294) -> Result<String, ScytheError> {
295    let mut out = String::new();
296
297    let _ = writeln!(out, "#[derive(Debug, Clone)]");
298    let _ = writeln!(out, "pub struct {} {{", struct_name);
299    for col in columns {
300        let _ = writeln!(out, "    pub {}: {},", col.field_name, col.full_type);
301    }
302    let _ = writeln!(out, "}}");
303    let _ = writeln!(out);
304
305    let _ = writeln!(out, "impl {} {{", struct_name);
306    let _ = writeln!(
307        out,
308        "    pub fn from_row(row: &tokio_postgres::Row) -> Self {{"
309    );
310    let _ = writeln!(out, "        Self {{");
311    for col in columns {
312        if col.neutral_type.starts_with("enum::") {
313            // Enum columns need string conversion
314            if col.nullable {
315                let _ = writeln!(
316                    out,
317                    "            {}: row.get::<_, Option<String>>(\"{}\").map(|s| s.parse().unwrap()),",
318                    col.field_name, col.name
319                );
320            } else {
321                let _ = writeln!(
322                    out,
323                    "            {}: row.get::<_, String>(\"{}\").parse().unwrap(),",
324                    col.field_name, col.name
325                );
326            }
327        } else {
328            let _ = writeln!(
329                out,
330                "            {}: row.get(\"{}\"),",
331                col.field_name, col.name
332            );
333        }
334    }
335    let _ = writeln!(out, "        }}");
336    let _ = writeln!(out, "    }}");
337    let _ = write!(out, "}}");
338
339    Ok(out)
340}