1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8use scythe_backend::types::resolve_type;
9
10use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
11use scythe_core::errors::{ErrorCode, ScytheError};
12use scythe_core::parser::QueryCommand;
13
14use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
15use crate::singularize;
16
17const DEFAULT_MANIFEST_TOML: &str =
19 include_str!("../../manifests/rust-tokio-postgres.toml");
20
21pub struct TokioPostgresBackend {
23 manifest: BackendManifest,
24}
25
26impl TokioPostgresBackend {
27 pub fn new() -> Result<Self, ScytheError> {
28 let manifest = load_tokio_postgres_manifest()?;
29 Ok(Self { manifest })
30 }
31
32 pub fn manifest(&self) -> &BackendManifest {
33 &self.manifest
34 }
35}
36
37fn load_tokio_postgres_manifest() -> Result<BackendManifest, ScytheError> {
38 let manifest_path = Path::new("backends/rust-tokio-postgres/manifest.toml");
39 if manifest_path.exists() {
40 load_manifest(manifest_path).map_err(|e| {
41 ScytheError::new(
42 ErrorCode::InternalError,
43 format!("failed to load manifest: {e}"),
44 )
45 })
46 } else {
47 toml::from_str(DEFAULT_MANIFEST_TOML).map_err(|e| {
48 ScytheError::new(
49 ErrorCode::InternalError,
50 format!("failed to parse embedded manifest: {e}"),
51 )
52 })
53 }
54}
55
56impl CodegenBackend for TokioPostgresBackend {
57 fn name(&self) -> &str {
58 "rust-tokio-postgres"
59 }
60
61 fn generate_row_struct(
62 &self,
63 query_name: &str,
64 columns: &[ResolvedColumn],
65 ) -> Result<String, ScytheError> {
66 let struct_name = row_struct_name(query_name, &self.manifest.naming);
67 generate_struct_with_from_row(&struct_name, columns)
68 }
69
70 fn generate_model_struct(
71 &self,
72 table_name: &str,
73 columns: &[ResolvedColumn],
74 ) -> Result<String, ScytheError> {
75 let singular = singularize(table_name);
76 let struct_name = to_pascal_case(&singular).into_owned();
77 generate_struct_with_from_row(&struct_name, columns)
78 }
79
80 fn generate_query_fn(
81 &self,
82 analyzed: &AnalyzedQuery,
83 struct_name: &str,
84 _columns: &[ResolvedColumn],
85 params: &[ResolvedParam],
86 ) -> Result<String, ScytheError> {
87 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
88 let mut out = String::new();
89
90 if let Some(ref msg) = analyzed.deprecated {
92 let _ = writeln!(out, "#[deprecated(note = \"{}\")]", msg);
93 }
94
95 let mut param_parts: Vec<String> = vec!["client: &tokio_postgres::Client".to_string()];
97 for param in params {
98 param_parts.push(format!("{}: {}", param.field_name, param.borrowed_type));
99 }
100
101 let return_type = match &analyzed.command {
103 QueryCommand::One => struct_name.to_string(),
104 QueryCommand::Many => format!("Vec<{}>", struct_name),
105 QueryCommand::Exec => "()".to_string(),
106 QueryCommand::ExecResult => "u64".to_string(),
107 QueryCommand::ExecRows => "u64".to_string(),
108 QueryCommand::Batch => format!("Vec<{}>", struct_name),
109 };
110
111 let _ = writeln!(
113 out,
114 "pub async fn {}({}) -> Result<{}, tokio_postgres::Error> {{",
115 func_name,
116 param_parts.join(", "),
117 return_type
118 );
119
120 let sql = super::clean_sql(&analyzed.sql);
122
123 let param_refs: String = if params.is_empty() {
125 "&[]".to_string()
126 } else {
127 let refs: Vec<String> = params
128 .iter()
129 .map(|p| {
130 if p.neutral_type.starts_with("enum::") {
131 format!("&{}.to_string()", p.field_name)
132 } else {
133 format!("&{}", p.field_name)
134 }
135 })
136 .collect();
137 format!("&[{}]", refs.join(", "))
138 };
139
140 match &analyzed.command {
141 QueryCommand::One => {
142 let _ = writeln!(
143 out,
144 " let row = client.query_one(r#\"{}\"#, {}).await?;",
145 sql, param_refs
146 );
147 let _ = writeln!(out, " Ok({}::from_row(&row))", struct_name);
148 }
149 QueryCommand::Many | QueryCommand::Batch => {
150 let _ = writeln!(
151 out,
152 " let rows = client.query(r#\"{}\"#, {}).await?;",
153 sql, param_refs
154 );
155 let _ = writeln!(
156 out,
157 " Ok(rows.iter().map({}::from_row).collect())",
158 struct_name
159 );
160 }
161 QueryCommand::Exec => {
162 let _ = writeln!(
163 out,
164 " client.execute(r#\"{}\"#, {}).await?;",
165 sql, param_refs
166 );
167 let _ = writeln!(out, " Ok(())");
168 }
169 QueryCommand::ExecResult | QueryCommand::ExecRows => {
170 let _ = writeln!(
171 out,
172 " let rows_affected = client.execute(r#\"{}\"#, {}).await?;",
173 sql, param_refs
174 );
175 let _ = writeln!(out, " Ok(rows_affected)");
176 }
177 }
178
179 let _ = write!(out, "}}");
180 Ok(out)
181 }
182
183 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
184 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
185 let mut out = String::with_capacity(512);
186
187 let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq)]");
188 let _ = writeln!(out, "pub enum {} {{", type_name);
189 for value in &enum_info.values {
190 let variant = enum_variant_name(value, &self.manifest.naming);
191 let _ = writeln!(out, " {},", variant);
192 }
193 let _ = writeln!(out, "}}");
194 let _ = writeln!(out);
195
196 let _ = writeln!(out, "impl std::fmt::Display for {} {{", type_name);
198 let _ = writeln!(
199 out,
200 " fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{"
201 );
202 let _ = writeln!(out, " match self {{");
203 for value in &enum_info.values {
204 let variant = enum_variant_name(value, &self.manifest.naming);
205 let _ = writeln!(
206 out,
207 " {}::{} => write!(f, \"{}\"),",
208 type_name, variant, value
209 );
210 }
211 let _ = writeln!(out, " }}");
212 let _ = writeln!(out, " }}");
213 let _ = writeln!(out, "}}");
214 let _ = writeln!(out);
215
216 let _ = writeln!(out, "impl std::str::FromStr for {} {{", type_name);
218 let _ = writeln!(out, " type Err = String;");
219 let _ = writeln!(
220 out,
221 " fn from_str(s: &str) -> Result<Self, Self::Err> {{"
222 );
223 let _ = writeln!(out, " match s {{");
224 for value in &enum_info.values {
225 let variant = enum_variant_name(value, &self.manifest.naming);
226 let _ = writeln!(
227 out,
228 " \"{}\" => Ok({}::{}),",
229 value, type_name, variant
230 );
231 }
232 let _ = writeln!(
233 out,
234 " _ => Err(format!(\"unknown variant: {{}}\", s)),"
235 );
236 let _ = writeln!(out, " }}");
237 let _ = writeln!(out, " }}");
238 let _ = write!(out, "}}");
239
240 Ok(out)
241 }
242
243 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
244 let struct_name = to_pascal_case(&composite.sql_name).into_owned();
245 let mut out = String::new();
246
247 let _ = writeln!(out, "#[derive(Debug, Clone)]");
248 let _ = writeln!(out, "pub struct {} {{", struct_name);
249 for field in &composite.fields {
250 let rust_type = resolve_type(&field.neutral_type, &self.manifest, false)
251 .map(|t| t.into_owned())
252 .map_err(|e| {
253 ScytheError::new(
254 ErrorCode::InternalError,
255 format!("composite field type error: {}", e),
256 )
257 })?;
258 let _ = writeln!(
259 out,
260 " pub {}: {},",
261 to_snake_case(&field.name),
262 rust_type
263 );
264 }
265 let _ = write!(out, "}}");
266 Ok(out)
267 }
268}
269
270fn generate_struct_with_from_row(
276 struct_name: &str,
277 columns: &[ResolvedColumn],
278) -> Result<String, ScytheError> {
279 let mut out = String::new();
280
281 let _ = writeln!(out, "#[derive(Debug, Clone)]");
282 let _ = writeln!(out, "pub struct {} {{", struct_name);
283 for col in columns {
284 let _ = writeln!(out, " pub {}: {},", col.field_name, col.full_type);
285 }
286 let _ = writeln!(out, "}}");
287 let _ = writeln!(out);
288
289 let _ = writeln!(out, "impl {} {{", struct_name);
290 let _ = writeln!(
291 out,
292 " pub fn from_row(row: &tokio_postgres::Row) -> Self {{"
293 );
294 let _ = writeln!(out, " Self {{");
295 for col in columns {
296 if col.neutral_type.starts_with("enum::") {
297 if col.nullable {
299 let _ = writeln!(
300 out,
301 " {}: row.get::<_, Option<String>>(\"{}\").map(|s| s.parse().unwrap()),",
302 col.field_name, col.name
303 );
304 } else {
305 let _ = writeln!(
306 out,
307 " {}: row.get::<_, String>(\"{}\").parse().unwrap(),",
308 col.field_name, col.name
309 );
310 }
311 } else {
312 let _ = writeln!(
313 out,
314 " {}: row.get(\"{}\"),",
315 col.field_name, col.name
316 );
317 }
318 }
319 let _ = writeln!(out, " }}");
320 let _ = writeln!(out, " }}");
321 let _ = write!(out, "}}");
322
323 Ok(out)
324}