Skip to main content

scythe_codegen/backends/
tokio_postgres.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
17/// Default embedded manifest TOML for rust-tokio-postgres, used as fallback.
18const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/rust-tokio-postgres.toml");
19
20/// TokioPostgresBackend generates Rust code targeting the tokio-postgres crate.
21pub struct TokioPostgresBackend {
22    manifest: BackendManifest,
23}
24
25impl TokioPostgresBackend {
26    pub fn new() -> Result<Self, ScytheError> {
27        let manifest = load_tokio_postgres_manifest()?;
28        Ok(Self { manifest })
29    }
30
31    pub fn manifest(&self) -> &BackendManifest {
32        &self.manifest
33    }
34}
35
36fn load_tokio_postgres_manifest() -> Result<BackendManifest, ScytheError> {
37    let manifest_path = Path::new("backends/rust-tokio-postgres/manifest.toml");
38    if manifest_path.exists() {
39        load_manifest(manifest_path).map_err(|e| {
40            ScytheError::new(
41                ErrorCode::InternalError,
42                format!("failed to load manifest: {e}"),
43            )
44        })
45    } else {
46        toml::from_str(DEFAULT_MANIFEST_TOML).map_err(|e| {
47            ScytheError::new(
48                ErrorCode::InternalError,
49                format!("failed to parse embedded manifest: {e}"),
50            )
51        })
52    }
53}
54
55impl CodegenBackend for TokioPostgresBackend {
56    fn name(&self) -> &str {
57        "rust-tokio-postgres"
58    }
59
60    fn generate_row_struct(
61        &self,
62        query_name: &str,
63        columns: &[ResolvedColumn],
64    ) -> Result<String, ScytheError> {
65        let struct_name = row_struct_name(query_name, &self.manifest.naming);
66        generate_struct_with_from_row(&struct_name, columns)
67    }
68
69    fn generate_model_struct(
70        &self,
71        table_name: &str,
72        columns: &[ResolvedColumn],
73    ) -> Result<String, ScytheError> {
74        let singular = singularize(table_name);
75        let struct_name = to_pascal_case(&singular).into_owned();
76        generate_struct_with_from_row(&struct_name, columns)
77    }
78
79    fn generate_query_fn(
80        &self,
81        analyzed: &AnalyzedQuery,
82        struct_name: &str,
83        _columns: &[ResolvedColumn],
84        params: &[ResolvedParam],
85    ) -> Result<String, ScytheError> {
86        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
87        let mut out = String::new();
88
89        // Deprecated annotation
90        if let Some(ref msg) = analyzed.deprecated {
91            let _ = writeln!(out, "#[deprecated(note = \"{}\")]", msg);
92        }
93
94        // Build parameter list
95        let mut param_parts: Vec<String> = vec!["client: &tokio_postgres::Client".to_string()];
96        for param in params {
97            param_parts.push(format!("{}: {}", param.field_name, param.borrowed_type));
98        }
99
100        // Return type
101        let return_type = match &analyzed.command {
102            QueryCommand::One => struct_name.to_string(),
103            QueryCommand::Many => format!("Vec<{}>", struct_name),
104            QueryCommand::Exec => "()".to_string(),
105            QueryCommand::ExecResult => "u64".to_string(),
106            QueryCommand::ExecRows => "u64".to_string(),
107            QueryCommand::Batch => format!("Vec<{}>", struct_name),
108        };
109
110        // Function signature
111        let _ = writeln!(
112            out,
113            "pub async fn {}({}) -> Result<{}, tokio_postgres::Error> {{",
114            func_name,
115            param_parts.join(", "),
116            return_type
117        );
118
119        // Clean SQL
120        let sql = super::clean_sql(&analyzed.sql);
121
122        // Build param references for the query call
123        let param_refs: String = if params.is_empty() {
124            "&[]".to_string()
125        } else {
126            let refs: Vec<String> = params
127                .iter()
128                .map(|p| {
129                    if p.neutral_type.starts_with("enum::") {
130                        format!("&{}.to_string()", p.field_name)
131                    } else {
132                        format!("&{}", p.field_name)
133                    }
134                })
135                .collect();
136            format!("&[{}]", refs.join(", "))
137        };
138
139        match &analyzed.command {
140            QueryCommand::One => {
141                let _ = writeln!(
142                    out,
143                    "    let row = client.query_one(r#\"{}\"#, {}).await?;",
144                    sql, param_refs
145                );
146                let _ = writeln!(out, "    Ok({}::from_row(&row))", struct_name);
147            }
148            QueryCommand::Many | QueryCommand::Batch => {
149                let _ = writeln!(
150                    out,
151                    "    let rows = client.query(r#\"{}\"#, {}).await?;",
152                    sql, param_refs
153                );
154                let _ = writeln!(
155                    out,
156                    "    Ok(rows.iter().map({}::from_row).collect())",
157                    struct_name
158                );
159            }
160            QueryCommand::Exec => {
161                let _ = writeln!(
162                    out,
163                    "    client.execute(r#\"{}\"#, {}).await?;",
164                    sql, param_refs
165                );
166                let _ = writeln!(out, "    Ok(())");
167            }
168            QueryCommand::ExecResult | QueryCommand::ExecRows => {
169                let _ = writeln!(
170                    out,
171                    "    let rows_affected = client.execute(r#\"{}\"#, {}).await?;",
172                    sql, param_refs
173                );
174                let _ = writeln!(out, "    Ok(rows_affected)");
175            }
176        }
177
178        let _ = write!(out, "}}");
179        Ok(out)
180    }
181
182    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
183        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
184        let mut out = String::with_capacity(512);
185
186        let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq)]");
187        let _ = writeln!(out, "pub enum {} {{", type_name);
188        for value in &enum_info.values {
189            let variant = enum_variant_name(value, &self.manifest.naming);
190            let _ = writeln!(out, "    {},", variant);
191        }
192        let _ = writeln!(out, "}}");
193        let _ = writeln!(out);
194
195        // impl Display for serialization
196        let _ = writeln!(out, "impl std::fmt::Display for {} {{", type_name);
197        let _ = writeln!(
198            out,
199            "    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{"
200        );
201        let _ = writeln!(out, "        match self {{");
202        for value in &enum_info.values {
203            let variant = enum_variant_name(value, &self.manifest.naming);
204            let _ = writeln!(
205                out,
206                "            {}::{} => write!(f, \"{}\"),",
207                type_name, variant, value
208            );
209        }
210        let _ = writeln!(out, "        }}");
211        let _ = writeln!(out, "    }}");
212        let _ = writeln!(out, "}}");
213        let _ = writeln!(out);
214
215        // impl FromStr for deserialization
216        let _ = writeln!(out, "impl std::str::FromStr for {} {{", type_name);
217        let _ = writeln!(out, "    type Err = String;");
218        let _ = writeln!(
219            out,
220            "    fn from_str(s: &str) -> Result<Self, Self::Err> {{"
221        );
222        let _ = writeln!(out, "        match s {{");
223        for value in &enum_info.values {
224            let variant = enum_variant_name(value, &self.manifest.naming);
225            let _ = writeln!(
226                out,
227                "            \"{}\" => Ok({}::{}),",
228                value, type_name, variant
229            );
230        }
231        let _ = writeln!(
232            out,
233            "            _ => Err(format!(\"unknown variant: {{}}\", s)),"
234        );
235        let _ = writeln!(out, "        }}");
236        let _ = writeln!(out, "    }}");
237        let _ = write!(out, "}}");
238
239        Ok(out)
240    }
241
242    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
243        let struct_name = to_pascal_case(&composite.sql_name).into_owned();
244        let mut out = String::new();
245
246        let _ = writeln!(out, "#[derive(Debug, Clone)]");
247        let _ = writeln!(out, "pub struct {} {{", struct_name);
248        for field in &composite.fields {
249            let rust_type = resolve_type(&field.neutral_type, &self.manifest, false)
250                .map(|t| t.into_owned())
251                .map_err(|e| {
252                    ScytheError::new(
253                        ErrorCode::InternalError,
254                        format!("composite field type error: {}", e),
255                    )
256                })?;
257            let _ = writeln!(
258                out,
259                "    pub {}: {},",
260                to_snake_case(&field.name),
261                rust_type
262            );
263        }
264        let _ = write!(out, "}}");
265        Ok(out)
266    }
267}
268
269// ---------------------------------------------------------------------------
270// Internal helpers
271// ---------------------------------------------------------------------------
272
273/// Generate a struct with a `from_row` method for tokio-postgres.
274fn generate_struct_with_from_row(
275    struct_name: &str,
276    columns: &[ResolvedColumn],
277) -> Result<String, ScytheError> {
278    let mut out = String::new();
279
280    let _ = writeln!(out, "#[derive(Debug, Clone)]");
281    let _ = writeln!(out, "pub struct {} {{", struct_name);
282    for col in columns {
283        let _ = writeln!(out, "    pub {}: {},", col.field_name, col.full_type);
284    }
285    let _ = writeln!(out, "}}");
286    let _ = writeln!(out);
287
288    let _ = writeln!(out, "impl {} {{", struct_name);
289    let _ = writeln!(
290        out,
291        "    pub fn from_row(row: &tokio_postgres::Row) -> Self {{"
292    );
293    let _ = writeln!(out, "        Self {{");
294    for col in columns {
295        if col.neutral_type.starts_with("enum::") {
296            // Enum columns need string conversion
297            if col.nullable {
298                let _ = writeln!(
299                    out,
300                    "            {}: row.get::<_, Option<String>>(\"{}\").map(|s| s.parse().unwrap()),",
301                    col.field_name, col.name
302                );
303            } else {
304                let _ = writeln!(
305                    out,
306                    "            {}: row.get::<_, String>(\"{}\").parse().unwrap(),",
307                    col.field_name, col.name
308                );
309            }
310        } else {
311            let _ = writeln!(
312                out,
313                "            {}: row.get(\"{}\"),",
314                col.field_name, col.name
315            );
316        }
317    }
318    let _ = writeln!(out, "        }}");
319    let _ = writeln!(out, "    }}");
320    let _ = write!(out, "}}");
321
322    Ok(out)
323}