1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
/// Manifest TOML embedded into the binary at compile time via `include_str!`.
///
/// `load_tokio_postgres_manifest` parses this text when the on-disk manifest
/// file does not exist.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/rust-tokio-postgres.toml");
19
/// Code generation backend that emits Rust source written against the
/// `tokio_postgres` client API.
pub struct TokioPostgresBackend {
    /// Backend manifest; its `naming` section is used to derive generated
    /// identifiers and the manifest as a whole is passed to `resolve_type`.
    manifest: BackendManifest,
}
24
25impl TokioPostgresBackend {
26 pub fn new(engine: &str) -> Result<Self, ScytheError> {
27 match engine {
28 "postgresql" | "postgres" | "pg" => {}
29 _ => {
30 return Err(ScytheError::new(
31 ErrorCode::InternalError,
32 format!(
33 "rust-tokio-postgres only supports PostgreSQL, got engine '{}'",
34 engine
35 ),
36 ));
37 }
38 }
39 let manifest = load_tokio_postgres_manifest()?;
40 Ok(Self { manifest })
41 }
42}
43
44fn load_tokio_postgres_manifest() -> Result<BackendManifest, ScytheError> {
45 let manifest_path = Path::new("backends/rust-tokio-postgres/manifest.toml");
46 if manifest_path.exists() {
47 load_manifest(manifest_path).map_err(|e| {
48 ScytheError::new(
49 ErrorCode::InternalError,
50 format!("failed to load manifest: {e}"),
51 )
52 })
53 } else {
54 toml::from_str(DEFAULT_MANIFEST_TOML).map_err(|e| {
55 ScytheError::new(
56 ErrorCode::InternalError,
57 format!("failed to parse embedded manifest: {e}"),
58 )
59 })
60 }
61}
62
63impl CodegenBackend for TokioPostgresBackend {
64 fn name(&self) -> &str {
65 "rust-tokio-postgres"
66 }
67
68 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
69 &self.manifest
70 }
71
72 fn file_header(&self) -> String {
73 "// Auto-generated by scythe. Do not edit.\n#![allow(dead_code, unused_imports, clippy::all)]"
74 .to_string()
75 }
76
77 fn generate_row_struct(
78 &self,
79 query_name: &str,
80 columns: &[ResolvedColumn],
81 ) -> Result<String, ScytheError> {
82 let struct_name = row_struct_name(query_name, &self.manifest.naming);
83 generate_struct_with_from_row(&struct_name, columns)
84 }
85
86 fn generate_model_struct(
87 &self,
88 table_name: &str,
89 columns: &[ResolvedColumn],
90 ) -> Result<String, ScytheError> {
91 let singular = singularize(table_name);
92 let struct_name = to_pascal_case(&singular).into_owned();
93 generate_struct_with_from_row(&struct_name, columns)
94 }
95
96 fn generate_query_fn(
97 &self,
98 analyzed: &AnalyzedQuery,
99 struct_name: &str,
100 _columns: &[ResolvedColumn],
101 params: &[ResolvedParam],
102 ) -> Result<String, ScytheError> {
103 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
104 let mut out = String::new();
105
106 if let Some(ref msg) = analyzed.deprecated {
108 let _ = writeln!(out, "#[deprecated(note = \"{}\")]", msg);
109 }
110
111 let mut param_parts: Vec<String> = vec!["client: &tokio_postgres::Client".to_string()];
113 for param in params {
114 param_parts.push(format!("{}: {}", param.field_name, param.borrowed_type));
115 }
116
117 let sql = super::clean_sql_with_optional(
119 &analyzed.sql,
120 &analyzed.optional_params,
121 &analyzed.params,
122 );
123
124 if matches!(analyzed.command, QueryCommand::Batch) {
126 let batch_fn_name = format!("{}_batch", func_name);
127
128 if params.len() > 1 {
129 let params_struct_name = format!("{}BatchParams", struct_name);
130 let _ = writeln!(out, "#[derive(Debug, Clone)]");
131 let _ = writeln!(out, "pub struct {} {{", params_struct_name);
132 for param in params {
133 let _ = writeln!(out, " pub {}: {},", param.field_name, param.full_type);
134 }
135 let _ = writeln!(out, "}}");
136 let _ = writeln!(out);
137 let _ = writeln!(
138 out,
139 "pub async fn {}(client: &tokio_postgres::Client, items: &[{}]) -> Result<(), tokio_postgres::Error> {{",
140 batch_fn_name, params_struct_name
141 );
142 let _ = writeln!(out, " let stmt = client.prepare(r#\"{}\"#).await?;", sql);
143 let _ = writeln!(out, " let tx = client.transaction().await?;");
144 let _ = writeln!(out, " for item in items {{");
145 let refs: Vec<String> = params
146 .iter()
147 .map(|p| {
148 if p.neutral_type.starts_with("enum::") {
149 format!("&item.{}.to_string()", p.field_name)
150 } else {
151 format!("&item.{}", p.field_name)
152 }
153 })
154 .collect();
155 let _ = writeln!(
156 out,
157 " tx.execute(&stmt, &[{}]).await?;",
158 refs.join(", ")
159 );
160 let _ = writeln!(out, " }}");
161 let _ = writeln!(out, " tx.commit().await?;");
162 let _ = writeln!(out, " Ok(())");
163 } else if params.len() == 1 {
164 let param = ¶ms[0];
165 let _ = writeln!(
166 out,
167 "pub async fn {}(client: &tokio_postgres::Client, items: &[{}]) -> Result<(), tokio_postgres::Error> {{",
168 batch_fn_name, param.full_type
169 );
170 let _ = writeln!(out, " let stmt = client.prepare(r#\"{}\"#).await?;", sql);
171 let _ = writeln!(out, " let tx = client.transaction().await?;");
172 let _ = writeln!(out, " for item in items {{");
173 let _ = writeln!(out, " tx.execute(&stmt, &[item]).await?;");
174 let _ = writeln!(out, " }}");
175 let _ = writeln!(out, " tx.commit().await?;");
176 let _ = writeln!(out, " Ok(())");
177 } else {
178 let _ = writeln!(
179 out,
180 "pub async fn {}(client: &tokio_postgres::Client, count: usize) -> Result<(), tokio_postgres::Error> {{",
181 batch_fn_name
182 );
183 let _ = writeln!(out, " let stmt = client.prepare(r#\"{}\"#).await?;", sql);
184 let _ = writeln!(out, " let tx = client.transaction().await?;");
185 let _ = writeln!(out, " for _ in 0..count {{");
186 let _ = writeln!(out, " tx.execute(&stmt, &[]).await?;");
187 let _ = writeln!(out, " }}");
188 let _ = writeln!(out, " tx.commit().await?;");
189 let _ = writeln!(out, " Ok(())");
190 }
191
192 let _ = write!(out, "}}");
193 return Ok(out);
194 }
195
196 let return_type = match &analyzed.command {
198 QueryCommand::One => struct_name.to_string(),
199 QueryCommand::Many => format!("Vec<{}>", struct_name),
200 QueryCommand::Exec => "()".to_string(),
201 QueryCommand::ExecResult => "u64".to_string(),
202 QueryCommand::ExecRows => "u64".to_string(),
203 QueryCommand::Batch | QueryCommand::Grouped => unreachable!(),
204 };
205
206 let _ = writeln!(
208 out,
209 "pub async fn {}({}) -> Result<{}, tokio_postgres::Error> {{",
210 func_name,
211 param_parts.join(", "),
212 return_type
213 );
214
215 let param_refs: String = if params.is_empty() {
217 "&[]".to_string()
218 } else {
219 let refs: Vec<String> = params
220 .iter()
221 .map(|p| {
222 if p.neutral_type.starts_with("enum::") {
223 format!("&{}.to_string()", p.field_name)
224 } else {
225 format!("&{}", p.field_name)
226 }
227 })
228 .collect();
229 format!("&[{}]", refs.join(", "))
230 };
231
232 match &analyzed.command {
233 QueryCommand::One => {
234 let _ = writeln!(
235 out,
236 " let row = client.query_one(r#\"{}\"#, {}).await?;",
237 sql, param_refs
238 );
239 let _ = writeln!(out, " Ok({}::from_row(&row))", struct_name);
240 }
241 QueryCommand::Many => {
242 let _ = writeln!(
243 out,
244 " let rows = client.query(r#\"{}\"#, {}).await?;",
245 sql, param_refs
246 );
247 let _ = writeln!(
248 out,
249 " Ok(rows.iter().map({}::from_row).collect())",
250 struct_name
251 );
252 }
253 QueryCommand::Exec => {
254 let _ = writeln!(
255 out,
256 " client.execute(r#\"{}\"#, {}).await?;",
257 sql, param_refs
258 );
259 let _ = writeln!(out, " Ok(())");
260 }
261 QueryCommand::ExecResult | QueryCommand::ExecRows => {
262 let _ = writeln!(
263 out,
264 " let rows_affected = client.execute(r#\"{}\"#, {}).await?;",
265 sql, param_refs
266 );
267 let _ = writeln!(out, " Ok(rows_affected)");
268 }
269 QueryCommand::Batch | QueryCommand::Grouped => unreachable!(),
270 }
271
272 let _ = write!(out, "}}");
273 Ok(out)
274 }
275
276 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
277 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
278 let mut out = String::with_capacity(512);
279
280 let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq)]");
281 let _ = writeln!(out, "pub enum {} {{", type_name);
282 for value in &enum_info.values {
283 let variant = enum_variant_name(value, &self.manifest.naming);
284 let _ = writeln!(out, " {},", variant);
285 }
286 let _ = writeln!(out, "}}");
287 let _ = writeln!(out);
288
289 let _ = writeln!(out, "impl std::fmt::Display for {} {{", type_name);
291 let _ = writeln!(
292 out,
293 " fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{"
294 );
295 let _ = writeln!(out, " match self {{");
296 for value in &enum_info.values {
297 let variant = enum_variant_name(value, &self.manifest.naming);
298 let _ = writeln!(
299 out,
300 " {}::{} => write!(f, \"{}\"),",
301 type_name, variant, value
302 );
303 }
304 let _ = writeln!(out, " }}");
305 let _ = writeln!(out, " }}");
306 let _ = writeln!(out, "}}");
307 let _ = writeln!(out);
308
309 let _ = writeln!(out, "impl std::str::FromStr for {} {{", type_name);
311 let _ = writeln!(out, " type Err = String;");
312 let _ = writeln!(
313 out,
314 " fn from_str(s: &str) -> Result<Self, Self::Err> {{"
315 );
316 let _ = writeln!(out, " match s {{");
317 for value in &enum_info.values {
318 let variant = enum_variant_name(value, &self.manifest.naming);
319 let _ = writeln!(
320 out,
321 " \"{}\" => Ok({}::{}),",
322 value, type_name, variant
323 );
324 }
325 let _ = writeln!(
326 out,
327 " _ => Err(format!(\"unknown variant: {{}}\", s)),"
328 );
329 let _ = writeln!(out, " }}");
330 let _ = writeln!(out, " }}");
331 let _ = write!(out, "}}");
332
333 Ok(out)
334 }
335
336 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
337 let struct_name = to_pascal_case(&composite.sql_name).into_owned();
338 let mut out = String::new();
339
340 let _ = writeln!(out, "#[derive(Debug, Clone)]");
341 let _ = writeln!(out, "pub struct {} {{", struct_name);
342 for field in &composite.fields {
343 let rust_type = resolve_type(&field.neutral_type, &self.manifest, false)
344 .map(|t| t.into_owned())
345 .map_err(|e| {
346 ScytheError::new(
347 ErrorCode::InternalError,
348 format!("composite field type error: {}", e),
349 )
350 })?;
351 let _ = writeln!(
352 out,
353 " pub {}: {},",
354 to_snake_case(&field.name),
355 rust_type
356 );
357 }
358 let _ = write!(out, "}}");
359 Ok(out)
360 }
361}
362
363fn generate_struct_with_from_row(
369 struct_name: &str,
370 columns: &[ResolvedColumn],
371) -> Result<String, ScytheError> {
372 let mut out = String::new();
373
374 let _ = writeln!(out, "#[derive(Debug, Clone)]");
375 let _ = writeln!(out, "pub struct {} {{", struct_name);
376 for col in columns {
377 let _ = writeln!(out, " pub {}: {},", col.field_name, col.full_type);
378 }
379 let _ = writeln!(out, "}}");
380 let _ = writeln!(out);
381
382 let _ = writeln!(out, "impl {} {{", struct_name);
383 let _ = writeln!(
384 out,
385 " pub fn from_row(row: &tokio_postgres::Row) -> Self {{"
386 );
387 let _ = writeln!(out, " Self {{");
388 for col in columns {
389 if col.neutral_type.starts_with("enum::") {
390 if col.nullable {
392 let _ = writeln!(
393 out,
394 " {field}: row.get::<_, Option<String>>(\"{col}\").map(|s| s.parse().unwrap_or_else(|_| panic!(\"unexpected enum value for column '{{}}': {{}}\", \"{col}\", s))),",
395 field = col.field_name,
396 col = col.name
397 );
398 } else {
399 let _ = writeln!(
400 out,
401 " {field}: {{ let val = row.get::<_, String>(\"{col}\"); val.parse().unwrap_or_else(|_| panic!(\"unexpected enum value for column '{{}}': {{}}\", \"{col}\", val)) }},",
402 field = col.field_name,
403 col = col.name
404 );
405 }
406 } else {
407 let _ = writeln!(
408 out,
409 " {}: row.get(\"{}\"),",
410 col.field_name, col.name
411 );
412 }
413 }
414 let _ = writeln!(out, " }}");
415 let _ = writeln!(out, " }}");
416 let _ = write!(out, "}}");
417
418 Ok(out)
419}