Skip to main content

scythe_codegen/backends/
tokio_postgres.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
17/// Default embedded manifest TOML for rust-tokio-postgres, used as fallback.
18const DEFAULT_MANIFEST_TOML: &str =
19    include_str!("../../manifests/rust-tokio-postgres.toml");
20
21/// TokioPostgresBackend generates Rust code targeting the tokio-postgres crate.
22pub struct TokioPostgresBackend {
23    manifest: BackendManifest,
24}
25
26impl TokioPostgresBackend {
27    pub fn new() -> Result<Self, ScytheError> {
28        let manifest = load_tokio_postgres_manifest()?;
29        Ok(Self { manifest })
30    }
31
32    pub fn manifest(&self) -> &BackendManifest {
33        &self.manifest
34    }
35}
36
37fn load_tokio_postgres_manifest() -> Result<BackendManifest, ScytheError> {
38    let manifest_path = Path::new("backends/rust-tokio-postgres/manifest.toml");
39    if manifest_path.exists() {
40        load_manifest(manifest_path).map_err(|e| {
41            ScytheError::new(
42                ErrorCode::InternalError,
43                format!("failed to load manifest: {e}"),
44            )
45        })
46    } else {
47        toml::from_str(DEFAULT_MANIFEST_TOML).map_err(|e| {
48            ScytheError::new(
49                ErrorCode::InternalError,
50                format!("failed to parse embedded manifest: {e}"),
51            )
52        })
53    }
54}
55
56impl CodegenBackend for TokioPostgresBackend {
57    fn name(&self) -> &str {
58        "rust-tokio-postgres"
59    }
60
61    fn generate_row_struct(
62        &self,
63        query_name: &str,
64        columns: &[ResolvedColumn],
65    ) -> Result<String, ScytheError> {
66        let struct_name = row_struct_name(query_name, &self.manifest.naming);
67        generate_struct_with_from_row(&struct_name, columns)
68    }
69
70    fn generate_model_struct(
71        &self,
72        table_name: &str,
73        columns: &[ResolvedColumn],
74    ) -> Result<String, ScytheError> {
75        let singular = singularize(table_name);
76        let struct_name = to_pascal_case(&singular).into_owned();
77        generate_struct_with_from_row(&struct_name, columns)
78    }
79
80    fn generate_query_fn(
81        &self,
82        analyzed: &AnalyzedQuery,
83        struct_name: &str,
84        _columns: &[ResolvedColumn],
85        params: &[ResolvedParam],
86    ) -> Result<String, ScytheError> {
87        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
88        let mut out = String::new();
89
90        // Deprecated annotation
91        if let Some(ref msg) = analyzed.deprecated {
92            let _ = writeln!(out, "#[deprecated(note = \"{}\")]", msg);
93        }
94
95        // Build parameter list
96        let mut param_parts: Vec<String> = vec!["client: &tokio_postgres::Client".to_string()];
97        for param in params {
98            param_parts.push(format!("{}: {}", param.field_name, param.borrowed_type));
99        }
100
101        // Return type
102        let return_type = match &analyzed.command {
103            QueryCommand::One => struct_name.to_string(),
104            QueryCommand::Many => format!("Vec<{}>", struct_name),
105            QueryCommand::Exec => "()".to_string(),
106            QueryCommand::ExecResult => "u64".to_string(),
107            QueryCommand::ExecRows => "u64".to_string(),
108            QueryCommand::Batch => format!("Vec<{}>", struct_name),
109        };
110
111        // Function signature
112        let _ = writeln!(
113            out,
114            "pub async fn {}({}) -> Result<{}, tokio_postgres::Error> {{",
115            func_name,
116            param_parts.join(", "),
117            return_type
118        );
119
120        // Clean SQL
121        let sql = super::clean_sql(&analyzed.sql);
122
123        // Build param references for the query call
124        let param_refs: String = if params.is_empty() {
125            "&[]".to_string()
126        } else {
127            let refs: Vec<String> = params
128                .iter()
129                .map(|p| {
130                    if p.neutral_type.starts_with("enum::") {
131                        format!("&{}.to_string()", p.field_name)
132                    } else {
133                        format!("&{}", p.field_name)
134                    }
135                })
136                .collect();
137            format!("&[{}]", refs.join(", "))
138        };
139
140        match &analyzed.command {
141            QueryCommand::One => {
142                let _ = writeln!(
143                    out,
144                    "    let row = client.query_one(r#\"{}\"#, {}).await?;",
145                    sql, param_refs
146                );
147                let _ = writeln!(out, "    Ok({}::from_row(&row))", struct_name);
148            }
149            QueryCommand::Many | QueryCommand::Batch => {
150                let _ = writeln!(
151                    out,
152                    "    let rows = client.query(r#\"{}\"#, {}).await?;",
153                    sql, param_refs
154                );
155                let _ = writeln!(
156                    out,
157                    "    Ok(rows.iter().map({}::from_row).collect())",
158                    struct_name
159                );
160            }
161            QueryCommand::Exec => {
162                let _ = writeln!(
163                    out,
164                    "    client.execute(r#\"{}\"#, {}).await?;",
165                    sql, param_refs
166                );
167                let _ = writeln!(out, "    Ok(())");
168            }
169            QueryCommand::ExecResult | QueryCommand::ExecRows => {
170                let _ = writeln!(
171                    out,
172                    "    let rows_affected = client.execute(r#\"{}\"#, {}).await?;",
173                    sql, param_refs
174                );
175                let _ = writeln!(out, "    Ok(rows_affected)");
176            }
177        }
178
179        let _ = write!(out, "}}");
180        Ok(out)
181    }
182
183    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
184        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
185        let mut out = String::with_capacity(512);
186
187        let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq)]");
188        let _ = writeln!(out, "pub enum {} {{", type_name);
189        for value in &enum_info.values {
190            let variant = enum_variant_name(value, &self.manifest.naming);
191            let _ = writeln!(out, "    {},", variant);
192        }
193        let _ = writeln!(out, "}}");
194        let _ = writeln!(out);
195
196        // impl Display for serialization
197        let _ = writeln!(out, "impl std::fmt::Display for {} {{", type_name);
198        let _ = writeln!(
199            out,
200            "    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{"
201        );
202        let _ = writeln!(out, "        match self {{");
203        for value in &enum_info.values {
204            let variant = enum_variant_name(value, &self.manifest.naming);
205            let _ = writeln!(
206                out,
207                "            {}::{} => write!(f, \"{}\"),",
208                type_name, variant, value
209            );
210        }
211        let _ = writeln!(out, "        }}");
212        let _ = writeln!(out, "    }}");
213        let _ = writeln!(out, "}}");
214        let _ = writeln!(out);
215
216        // impl FromStr for deserialization
217        let _ = writeln!(out, "impl std::str::FromStr for {} {{", type_name);
218        let _ = writeln!(out, "    type Err = String;");
219        let _ = writeln!(
220            out,
221            "    fn from_str(s: &str) -> Result<Self, Self::Err> {{"
222        );
223        let _ = writeln!(out, "        match s {{");
224        for value in &enum_info.values {
225            let variant = enum_variant_name(value, &self.manifest.naming);
226            let _ = writeln!(
227                out,
228                "            \"{}\" => Ok({}::{}),",
229                value, type_name, variant
230            );
231        }
232        let _ = writeln!(
233            out,
234            "            _ => Err(format!(\"unknown variant: {{}}\", s)),"
235        );
236        let _ = writeln!(out, "        }}");
237        let _ = writeln!(out, "    }}");
238        let _ = write!(out, "}}");
239
240        Ok(out)
241    }
242
243    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
244        let struct_name = to_pascal_case(&composite.sql_name).into_owned();
245        let mut out = String::new();
246
247        let _ = writeln!(out, "#[derive(Debug, Clone)]");
248        let _ = writeln!(out, "pub struct {} {{", struct_name);
249        for field in &composite.fields {
250            let rust_type = resolve_type(&field.neutral_type, &self.manifest, false)
251                .map(|t| t.into_owned())
252                .map_err(|e| {
253                    ScytheError::new(
254                        ErrorCode::InternalError,
255                        format!("composite field type error: {}", e),
256                    )
257                })?;
258            let _ = writeln!(
259                out,
260                "    pub {}: {},",
261                to_snake_case(&field.name),
262                rust_type
263            );
264        }
265        let _ = write!(out, "}}");
266        Ok(out)
267    }
268}
269
270// ---------------------------------------------------------------------------
271// Internal helpers
272// ---------------------------------------------------------------------------
273
274/// Generate a struct with a `from_row` method for tokio-postgres.
275fn generate_struct_with_from_row(
276    struct_name: &str,
277    columns: &[ResolvedColumn],
278) -> Result<String, ScytheError> {
279    let mut out = String::new();
280
281    let _ = writeln!(out, "#[derive(Debug, Clone)]");
282    let _ = writeln!(out, "pub struct {} {{", struct_name);
283    for col in columns {
284        let _ = writeln!(out, "    pub {}: {},", col.field_name, col.full_type);
285    }
286    let _ = writeln!(out, "}}");
287    let _ = writeln!(out);
288
289    let _ = writeln!(out, "impl {} {{", struct_name);
290    let _ = writeln!(
291        out,
292        "    pub fn from_row(row: &tokio_postgres::Row) -> Self {{"
293    );
294    let _ = writeln!(out, "        Self {{");
295    for col in columns {
296        if col.neutral_type.starts_with("enum::") {
297            // Enum columns need string conversion
298            if col.nullable {
299                let _ = writeln!(
300                    out,
301                    "            {}: row.get::<_, Option<String>>(\"{}\").map(|s| s.parse().unwrap()),",
302                    col.field_name, col.name
303                );
304            } else {
305                let _ = writeln!(
306                    out,
307                    "            {}: row.get::<_, String>(\"{}\").parse().unwrap(),",
308                    col.field_name, col.name
309                );
310            }
311        } else {
312            let _ = writeln!(
313                out,
314                "            {}: row.get(\"{}\"),",
315                col.field_name, col.name
316            );
317        }
318    }
319    let _ = writeln!(out, "        }}");
320    let _ = writeln!(out, "    }}");
321    let _ = write!(out, "}}");
322
323    Ok(out)
324}