1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
17const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/rust-tokio-postgres.toml");
19
20pub struct TokioPostgresBackend {
22 manifest: BackendManifest,
23}
24
25impl TokioPostgresBackend {
26 pub fn new() -> Result<Self, ScytheError> {
27 let manifest = load_tokio_postgres_manifest()?;
28 Ok(Self { manifest })
29 }
30
31 pub fn manifest(&self) -> &BackendManifest {
32 &self.manifest
33 }
34}
35
36fn load_tokio_postgres_manifest() -> Result<BackendManifest, ScytheError> {
37 let manifest_path = Path::new("backends/rust-tokio-postgres/manifest.toml");
38 if manifest_path.exists() {
39 load_manifest(manifest_path).map_err(|e| {
40 ScytheError::new(
41 ErrorCode::InternalError,
42 format!("failed to load manifest: {e}"),
43 )
44 })
45 } else {
46 toml::from_str(DEFAULT_MANIFEST_TOML).map_err(|e| {
47 ScytheError::new(
48 ErrorCode::InternalError,
49 format!("failed to parse embedded manifest: {e}"),
50 )
51 })
52 }
53}
54
55impl CodegenBackend for TokioPostgresBackend {
56 fn name(&self) -> &str {
57 "rust-tokio-postgres"
58 }
59
60 fn generate_row_struct(
61 &self,
62 query_name: &str,
63 columns: &[ResolvedColumn],
64 ) -> Result<String, ScytheError> {
65 let struct_name = row_struct_name(query_name, &self.manifest.naming);
66 generate_struct_with_from_row(&struct_name, columns)
67 }
68
69 fn generate_model_struct(
70 &self,
71 table_name: &str,
72 columns: &[ResolvedColumn],
73 ) -> Result<String, ScytheError> {
74 let singular = singularize(table_name);
75 let struct_name = to_pascal_case(&singular).into_owned();
76 generate_struct_with_from_row(&struct_name, columns)
77 }
78
79 fn generate_query_fn(
80 &self,
81 analyzed: &AnalyzedQuery,
82 struct_name: &str,
83 _columns: &[ResolvedColumn],
84 params: &[ResolvedParam],
85 ) -> Result<String, ScytheError> {
86 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
87 let mut out = String::new();
88
89 if let Some(ref msg) = analyzed.deprecated {
91 let _ = writeln!(out, "#[deprecated(note = \"{}\")]", msg);
92 }
93
94 let mut param_parts: Vec<String> = vec!["client: &tokio_postgres::Client".to_string()];
96 for param in params {
97 param_parts.push(format!("{}: {}", param.field_name, param.borrowed_type));
98 }
99
100 let return_type = match &analyzed.command {
102 QueryCommand::One => struct_name.to_string(),
103 QueryCommand::Many => format!("Vec<{}>", struct_name),
104 QueryCommand::Exec => "()".to_string(),
105 QueryCommand::ExecResult => "u64".to_string(),
106 QueryCommand::ExecRows => "u64".to_string(),
107 QueryCommand::Batch => format!("Vec<{}>", struct_name),
108 };
109
110 let _ = writeln!(
112 out,
113 "pub async fn {}({}) -> Result<{}, tokio_postgres::Error> {{",
114 func_name,
115 param_parts.join(", "),
116 return_type
117 );
118
119 let sql = super::clean_sql(&analyzed.sql);
121
122 let param_refs: String = if params.is_empty() {
124 "&[]".to_string()
125 } else {
126 let refs: Vec<String> = params
127 .iter()
128 .map(|p| {
129 if p.neutral_type.starts_with("enum::") {
130 format!("&{}.to_string()", p.field_name)
131 } else {
132 format!("&{}", p.field_name)
133 }
134 })
135 .collect();
136 format!("&[{}]", refs.join(", "))
137 };
138
139 match &analyzed.command {
140 QueryCommand::One => {
141 let _ = writeln!(
142 out,
143 " let row = client.query_one(r#\"{}\"#, {}).await?;",
144 sql, param_refs
145 );
146 let _ = writeln!(out, " Ok({}::from_row(&row))", struct_name);
147 }
148 QueryCommand::Many | QueryCommand::Batch => {
149 let _ = writeln!(
150 out,
151 " let rows = client.query(r#\"{}\"#, {}).await?;",
152 sql, param_refs
153 );
154 let _ = writeln!(
155 out,
156 " Ok(rows.iter().map({}::from_row).collect())",
157 struct_name
158 );
159 }
160 QueryCommand::Exec => {
161 let _ = writeln!(
162 out,
163 " client.execute(r#\"{}\"#, {}).await?;",
164 sql, param_refs
165 );
166 let _ = writeln!(out, " Ok(())");
167 }
168 QueryCommand::ExecResult | QueryCommand::ExecRows => {
169 let _ = writeln!(
170 out,
171 " let rows_affected = client.execute(r#\"{}\"#, {}).await?;",
172 sql, param_refs
173 );
174 let _ = writeln!(out, " Ok(rows_affected)");
175 }
176 }
177
178 let _ = write!(out, "}}");
179 Ok(out)
180 }
181
182 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
183 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
184 let mut out = String::with_capacity(512);
185
186 let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq)]");
187 let _ = writeln!(out, "pub enum {} {{", type_name);
188 for value in &enum_info.values {
189 let variant = enum_variant_name(value, &self.manifest.naming);
190 let _ = writeln!(out, " {},", variant);
191 }
192 let _ = writeln!(out, "}}");
193 let _ = writeln!(out);
194
195 let _ = writeln!(out, "impl std::fmt::Display for {} {{", type_name);
197 let _ = writeln!(
198 out,
199 " fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{"
200 );
201 let _ = writeln!(out, " match self {{");
202 for value in &enum_info.values {
203 let variant = enum_variant_name(value, &self.manifest.naming);
204 let _ = writeln!(
205 out,
206 " {}::{} => write!(f, \"{}\"),",
207 type_name, variant, value
208 );
209 }
210 let _ = writeln!(out, " }}");
211 let _ = writeln!(out, " }}");
212 let _ = writeln!(out, "}}");
213 let _ = writeln!(out);
214
215 let _ = writeln!(out, "impl std::str::FromStr for {} {{", type_name);
217 let _ = writeln!(out, " type Err = String;");
218 let _ = writeln!(
219 out,
220 " fn from_str(s: &str) -> Result<Self, Self::Err> {{"
221 );
222 let _ = writeln!(out, " match s {{");
223 for value in &enum_info.values {
224 let variant = enum_variant_name(value, &self.manifest.naming);
225 let _ = writeln!(
226 out,
227 " \"{}\" => Ok({}::{}),",
228 value, type_name, variant
229 );
230 }
231 let _ = writeln!(
232 out,
233 " _ => Err(format!(\"unknown variant: {{}}\", s)),"
234 );
235 let _ = writeln!(out, " }}");
236 let _ = writeln!(out, " }}");
237 let _ = write!(out, "}}");
238
239 Ok(out)
240 }
241
242 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
243 let struct_name = to_pascal_case(&composite.sql_name).into_owned();
244 let mut out = String::new();
245
246 let _ = writeln!(out, "#[derive(Debug, Clone)]");
247 let _ = writeln!(out, "pub struct {} {{", struct_name);
248 for field in &composite.fields {
249 let rust_type = resolve_type(&field.neutral_type, &self.manifest, false)
250 .map(|t| t.into_owned())
251 .map_err(|e| {
252 ScytheError::new(
253 ErrorCode::InternalError,
254 format!("composite field type error: {}", e),
255 )
256 })?;
257 let _ = writeln!(
258 out,
259 " pub {}: {},",
260 to_snake_case(&field.name),
261 rust_type
262 );
263 }
264 let _ = write!(out, "}}");
265 Ok(out)
266 }
267}
268
269fn generate_struct_with_from_row(
275 struct_name: &str,
276 columns: &[ResolvedColumn],
277) -> Result<String, ScytheError> {
278 let mut out = String::new();
279
280 let _ = writeln!(out, "#[derive(Debug, Clone)]");
281 let _ = writeln!(out, "pub struct {} {{", struct_name);
282 for col in columns {
283 let _ = writeln!(out, " pub {}: {},", col.field_name, col.full_type);
284 }
285 let _ = writeln!(out, "}}");
286 let _ = writeln!(out);
287
288 let _ = writeln!(out, "impl {} {{", struct_name);
289 let _ = writeln!(
290 out,
291 " pub fn from_row(row: &tokio_postgres::Row) -> Self {{"
292 );
293 let _ = writeln!(out, " Self {{");
294 for col in columns {
295 if col.neutral_type.starts_with("enum::") {
296 if col.nullable {
298 let _ = writeln!(
299 out,
300 " {}: row.get::<_, Option<String>>(\"{}\").map(|s| s.parse().unwrap()),",
301 col.field_name, col.name
302 );
303 } else {
304 let _ = writeln!(
305 out,
306 " {}: row.get::<_, String>(\"{}\").parse().unwrap(),",
307 col.field_name, col.name
308 );
309 }
310 } else {
311 let _ = writeln!(
312 out,
313 " {}: row.get(\"{}\"),",
314 col.field_name, col.name
315 );
316 }
317 }
318 let _ = writeln!(out, " }}");
319 let _ = writeln!(out, " }}");
320 let _ = write!(out, "}}");
321
322 Ok(out)
323}