1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
17const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/rust-tokio-postgres.toml");
19
20pub struct TokioPostgresBackend {
22 manifest: BackendManifest,
23}
24
25impl TokioPostgresBackend {
26 pub fn new(engine: &str) -> Result<Self, ScytheError> {
27 match engine {
28 "postgresql" | "postgres" | "pg" => {}
29 _ => {
30 return Err(ScytheError::new(
31 ErrorCode::InternalError,
32 format!(
33 "rust-tokio-postgres only supports PostgreSQL, got engine '{}'",
34 engine
35 ),
36 ));
37 }
38 }
39 let manifest = load_tokio_postgres_manifest()?;
40 Ok(Self { manifest })
41 }
42}
43
44fn load_tokio_postgres_manifest() -> Result<BackendManifest, ScytheError> {
45 let manifest_path = Path::new("backends/rust-tokio-postgres/manifest.toml");
46 if manifest_path.exists() {
47 load_manifest(manifest_path).map_err(|e| {
48 ScytheError::new(
49 ErrorCode::InternalError,
50 format!("failed to load manifest: {e}"),
51 )
52 })
53 } else {
54 toml::from_str(DEFAULT_MANIFEST_TOML).map_err(|e| {
55 ScytheError::new(
56 ErrorCode::InternalError,
57 format!("failed to parse embedded manifest: {e}"),
58 )
59 })
60 }
61}
62
63impl CodegenBackend for TokioPostgresBackend {
64 fn name(&self) -> &str {
65 "rust-tokio-postgres"
66 }
67
68 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
69 &self.manifest
70 }
71
72 fn file_header(&self) -> String {
73 "// Auto-generated by scythe. Do not edit.\n#![allow(dead_code, unused_imports, clippy::all)]"
74 .to_string()
75 }
76
77 fn generate_row_struct(
78 &self,
79 query_name: &str,
80 columns: &[ResolvedColumn],
81 ) -> Result<String, ScytheError> {
82 let struct_name = row_struct_name(query_name, &self.manifest.naming);
83 generate_struct_with_from_row(&struct_name, columns)
84 }
85
86 fn generate_model_struct(
87 &self,
88 table_name: &str,
89 columns: &[ResolvedColumn],
90 ) -> Result<String, ScytheError> {
91 let singular = singularize(table_name);
92 let struct_name = to_pascal_case(&singular).into_owned();
93 generate_struct_with_from_row(&struct_name, columns)
94 }
95
96 fn generate_query_fn(
97 &self,
98 analyzed: &AnalyzedQuery,
99 struct_name: &str,
100 _columns: &[ResolvedColumn],
101 params: &[ResolvedParam],
102 ) -> Result<String, ScytheError> {
103 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
104 let mut out = String::new();
105
106 if let Some(ref msg) = analyzed.deprecated {
108 let _ = writeln!(out, "#[deprecated(note = \"{}\")]", msg);
109 }
110
111 let mut param_parts: Vec<String> = vec!["client: &tokio_postgres::Client".to_string()];
113 for param in params {
114 param_parts.push(format!("{}: {}", param.field_name, param.borrowed_type));
115 }
116
117 let return_type = match &analyzed.command {
119 QueryCommand::One => struct_name.to_string(),
120 QueryCommand::Many => format!("Vec<{}>", struct_name),
121 QueryCommand::Exec => "()".to_string(),
122 QueryCommand::ExecResult => "u64".to_string(),
123 QueryCommand::ExecRows => "u64".to_string(),
124 QueryCommand::Batch => format!("Vec<{}>", struct_name),
125 };
126
127 let _ = writeln!(
129 out,
130 "pub async fn {}({}) -> Result<{}, tokio_postgres::Error> {{",
131 func_name,
132 param_parts.join(", "),
133 return_type
134 );
135
136 let sql = super::clean_sql(&analyzed.sql);
138
139 let param_refs: String = if params.is_empty() {
141 "&[]".to_string()
142 } else {
143 let refs: Vec<String> = params
144 .iter()
145 .map(|p| {
146 if p.neutral_type.starts_with("enum::") {
147 format!("&{}.to_string()", p.field_name)
148 } else {
149 format!("&{}", p.field_name)
150 }
151 })
152 .collect();
153 format!("&[{}]", refs.join(", "))
154 };
155
156 match &analyzed.command {
157 QueryCommand::One => {
158 let _ = writeln!(
159 out,
160 " let row = client.query_one(r#\"{}\"#, {}).await?;",
161 sql, param_refs
162 );
163 let _ = writeln!(out, " Ok({}::from_row(&row))", struct_name);
164 }
165 QueryCommand::Many | QueryCommand::Batch => {
166 let _ = writeln!(
167 out,
168 " let rows = client.query(r#\"{}\"#, {}).await?;",
169 sql, param_refs
170 );
171 let _ = writeln!(
172 out,
173 " Ok(rows.iter().map({}::from_row).collect())",
174 struct_name
175 );
176 }
177 QueryCommand::Exec => {
178 let _ = writeln!(
179 out,
180 " client.execute(r#\"{}\"#, {}).await?;",
181 sql, param_refs
182 );
183 let _ = writeln!(out, " Ok(())");
184 }
185 QueryCommand::ExecResult | QueryCommand::ExecRows => {
186 let _ = writeln!(
187 out,
188 " let rows_affected = client.execute(r#\"{}\"#, {}).await?;",
189 sql, param_refs
190 );
191 let _ = writeln!(out, " Ok(rows_affected)");
192 }
193 }
194
195 let _ = write!(out, "}}");
196 Ok(out)
197 }
198
199 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
200 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
201 let mut out = String::with_capacity(512);
202
203 let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq)]");
204 let _ = writeln!(out, "pub enum {} {{", type_name);
205 for value in &enum_info.values {
206 let variant = enum_variant_name(value, &self.manifest.naming);
207 let _ = writeln!(out, " {},", variant);
208 }
209 let _ = writeln!(out, "}}");
210 let _ = writeln!(out);
211
212 let _ = writeln!(out, "impl std::fmt::Display for {} {{", type_name);
214 let _ = writeln!(
215 out,
216 " fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{"
217 );
218 let _ = writeln!(out, " match self {{");
219 for value in &enum_info.values {
220 let variant = enum_variant_name(value, &self.manifest.naming);
221 let _ = writeln!(
222 out,
223 " {}::{} => write!(f, \"{}\"),",
224 type_name, variant, value
225 );
226 }
227 let _ = writeln!(out, " }}");
228 let _ = writeln!(out, " }}");
229 let _ = writeln!(out, "}}");
230 let _ = writeln!(out);
231
232 let _ = writeln!(out, "impl std::str::FromStr for {} {{", type_name);
234 let _ = writeln!(out, " type Err = String;");
235 let _ = writeln!(
236 out,
237 " fn from_str(s: &str) -> Result<Self, Self::Err> {{"
238 );
239 let _ = writeln!(out, " match s {{");
240 for value in &enum_info.values {
241 let variant = enum_variant_name(value, &self.manifest.naming);
242 let _ = writeln!(
243 out,
244 " \"{}\" => Ok({}::{}),",
245 value, type_name, variant
246 );
247 }
248 let _ = writeln!(
249 out,
250 " _ => Err(format!(\"unknown variant: {{}}\", s)),"
251 );
252 let _ = writeln!(out, " }}");
253 let _ = writeln!(out, " }}");
254 let _ = write!(out, "}}");
255
256 Ok(out)
257 }
258
259 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
260 let struct_name = to_pascal_case(&composite.sql_name).into_owned();
261 let mut out = String::new();
262
263 let _ = writeln!(out, "#[derive(Debug, Clone)]");
264 let _ = writeln!(out, "pub struct {} {{", struct_name);
265 for field in &composite.fields {
266 let rust_type = resolve_type(&field.neutral_type, &self.manifest, false)
267 .map(|t| t.into_owned())
268 .map_err(|e| {
269 ScytheError::new(
270 ErrorCode::InternalError,
271 format!("composite field type error: {}", e),
272 )
273 })?;
274 let _ = writeln!(
275 out,
276 " pub {}: {},",
277 to_snake_case(&field.name),
278 rust_type
279 );
280 }
281 let _ = write!(out, "}}");
282 Ok(out)
283 }
284}
285
286fn generate_struct_with_from_row(
292 struct_name: &str,
293 columns: &[ResolvedColumn],
294) -> Result<String, ScytheError> {
295 let mut out = String::new();
296
297 let _ = writeln!(out, "#[derive(Debug, Clone)]");
298 let _ = writeln!(out, "pub struct {} {{", struct_name);
299 for col in columns {
300 let _ = writeln!(out, " pub {}: {},", col.field_name, col.full_type);
301 }
302 let _ = writeln!(out, "}}");
303 let _ = writeln!(out);
304
305 let _ = writeln!(out, "impl {} {{", struct_name);
306 let _ = writeln!(
307 out,
308 " pub fn from_row(row: &tokio_postgres::Row) -> Self {{"
309 );
310 let _ = writeln!(out, " Self {{");
311 for col in columns {
312 if col.neutral_type.starts_with("enum::") {
313 if col.nullable {
315 let _ = writeln!(
316 out,
317 " {field}: row.get::<_, Option<String>>(\"{col}\").map(|s| s.parse().unwrap_or_else(|_| panic!(\"unexpected enum value for column '{{}}': {{}}\", \"{col}\", s))),",
318 field = col.field_name,
319 col = col.name
320 );
321 } else {
322 let _ = writeln!(
323 out,
324 " {field}: {{ let val = row.get::<_, String>(\"{col}\"); val.parse().unwrap_or_else(|_| panic!(\"unexpected enum value for column '{{}}': {{}}\", \"{col}\", val)) }},",
325 field = col.field_name,
326 col = col.name
327 );
328 }
329 } else {
330 let _ = writeln!(
331 out,
332 " {}: row.get(\"{}\"),",
333 col.field_name, col.name
334 );
335 }
336 }
337 let _ = writeln!(out, " }}");
338 let _ = writeln!(out, " }}");
339 let _ = write!(out, "}}");
340
341 Ok(out)
342}