1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8
9use scythe_core::analyzer::{AnalyzedColumn, AnalyzedQuery, CompositeInfo, EnumInfo};
10use scythe_core::errors::{ErrorCode, ScytheError};
11use scythe_core::parser::QueryCommand;
12
13use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
14use crate::singularize;
15
16const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/rust-sqlx.toml");
18
19pub struct SqlxBackend {
21 manifest: BackendManifest,
22}
23
24impl SqlxBackend {
25 pub fn new(engine: &str) -> Result<Self, ScytheError> {
26 match engine {
29 "postgresql" | "postgres" | "pg" | "mysql" | "mariadb" | "sqlite" | "sqlite3" => {}
30 _ => {
31 return Err(ScytheError::new(
32 ErrorCode::InternalError,
33 format!("unsupported engine '{}' for rust-sqlx backend", engine),
34 ));
35 }
36 }
37 let manifest = load_sqlx_manifest()?;
38 Ok(Self { manifest })
39 }
40}
41
42fn load_sqlx_manifest() -> Result<BackendManifest, ScytheError> {
43 let manifest_path = Path::new("backends/rust-sqlx/manifest.toml");
44 if manifest_path.exists() {
45 load_manifest(manifest_path).map_err(|e| {
46 ScytheError::new(
47 ErrorCode::InternalError,
48 format!("failed to load manifest: {e}"),
49 )
50 })
51 } else {
52 toml::from_str(DEFAULT_MANIFEST_TOML).map_err(|e| {
53 ScytheError::new(
54 ErrorCode::InternalError,
55 format!("failed to parse embedded manifest: {e}"),
56 )
57 })
58 }
59}
60
61impl CodegenBackend for SqlxBackend {
62 fn name(&self) -> &str {
63 "rust-sqlx"
64 }
65
66 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
67 &self.manifest
68 }
69
70 fn supported_engines(&self) -> &[&str] {
71 &["postgresql", "mysql", "sqlite"]
72 }
73
74 fn file_header(&self) -> String {
75 "// Auto-generated by scythe. Do not edit.\n#![allow(dead_code, unused_imports, clippy::all)]"
76 .to_string()
77 }
78
79 fn generate_row_struct(
80 &self,
81 query_name: &str,
82 columns: &[ResolvedColumn],
83 ) -> Result<String, ScytheError> {
84 let struct_name = row_struct_name(query_name, &self.manifest.naming);
85 let mut out = String::new();
86
87 let _ = writeln!(out, "#[derive(Debug, sqlx::FromRow)]");
88 let _ = writeln!(out, "pub struct {} {{", struct_name);
89
90 for col in columns {
91 let _ = writeln!(out, " pub {}: {},", col.field_name, col.full_type);
92 }
93
94 let _ = write!(out, "}}");
95 Ok(out)
96 }
97
98 fn generate_model_struct(
99 &self,
100 table_name: &str,
101 columns: &[ResolvedColumn],
102 ) -> Result<String, ScytheError> {
103 let singular = singularize(table_name);
104 let struct_name = to_pascal_case(&singular).into_owned();
105 let mut out = String::new();
106
107 let _ = writeln!(out, "#[derive(Debug, sqlx::FromRow)]");
108 let _ = writeln!(out, "pub struct {} {{", struct_name);
109
110 for col in columns {
111 let _ = writeln!(out, " pub {}: {},", col.field_name, col.full_type);
112 }
113
114 let _ = write!(out, "}}");
115 Ok(out)
116 }
117
118 fn generate_query_fn(
119 &self,
120 analyzed: &AnalyzedQuery,
121 struct_name: &str,
122 _columns: &[ResolvedColumn],
123 params: &[ResolvedParam],
124 ) -> Result<String, ScytheError> {
125 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
126 let mut out = String::new();
127
128 if let Some(ref msg) = analyzed.deprecated {
130 let _ = writeln!(out, "#[deprecated(note = \"{}\")]", msg);
131 }
132
133 let mut param_parts: Vec<String> = vec!["pool: &sqlx::PgPool".to_string()];
135 for param in params {
136 param_parts.push(format!("{}: {}", param.field_name, param.borrowed_type));
137 }
138
139 let sql_raw = super::clean_sql_with_optional(
141 &analyzed.sql,
142 &analyzed.optional_params,
143 &analyzed.params,
144 );
145 let sql = rewrite_sql_for_enums(&sql_raw, &analyzed.columns, &self.manifest);
146
147 let bind_params: String = analyzed
149 .params
150 .iter()
151 .map(|p| {
152 let param_name = to_snake_case(&p.name);
153 if p.neutral_type.starts_with("enum::") {
154 let enum_name = p.neutral_type.strip_prefix("enum::").unwrap();
155 let rust_type = enum_type_name(enum_name, &self.manifest.naming);
156 format!(", {} as &{}", param_name, rust_type)
157 } else {
158 format!(", {}", param_name)
159 }
160 })
161 .collect();
162
163 if matches!(analyzed.command, QueryCommand::Batch) {
165 let batch_fn_name = format!("{}_batch", func_name);
166
167 if params.len() > 1 {
169 let params_struct_name = format!("{}BatchParams", struct_name);
170 let _ = writeln!(out, "#[derive(Debug, Clone)]");
171 let _ = writeln!(out, "pub struct {} {{", params_struct_name);
172 for param in params {
173 let _ = writeln!(out, " pub {}: {},", param.field_name, param.full_type);
174 }
175 let _ = writeln!(out, "}}");
176 let _ = writeln!(out);
177
178 let _ = writeln!(
180 out,
181 "pub async fn {}(pool: &sqlx::PgPool, items: &[{}]) -> Result<(), sqlx::Error> {{",
182 batch_fn_name, params_struct_name
183 );
184 let _ = writeln!(out, " let mut tx = pool.begin().await?;");
185 let _ = writeln!(out, " for item in items {{");
186
187 let struct_bind_params: String = params
189 .iter()
190 .map(|p| {
191 if p.neutral_type.starts_with("enum::") {
192 let enum_name = p.neutral_type.strip_prefix("enum::").unwrap();
193 let rust_type = enum_type_name(enum_name, &self.manifest.naming);
194 format!(", item.{} as &{}", p.field_name, rust_type)
195 } else {
196 format!(", item.{}", p.field_name)
197 }
198 })
199 .collect();
200
201 let _ = writeln!(
202 out,
203 " sqlx::query!(\"{}\"{})",
204 sql, struct_bind_params
205 );
206 let _ = writeln!(out, " .execute(&mut *tx)");
207 let _ = writeln!(out, " .await?;");
208 let _ = writeln!(out, " }}");
209 let _ = writeln!(out, " tx.commit().await?;");
210 let _ = writeln!(out, " Ok(())");
211 } else if params.len() == 1 {
212 let param = ¶ms[0];
214 let _ = writeln!(
215 out,
216 "pub async fn {}(pool: &sqlx::PgPool, items: &[{}]) -> Result<(), sqlx::Error> {{",
217 batch_fn_name, param.full_type
218 );
219 let _ = writeln!(out, " let mut tx = pool.begin().await?;");
220 let _ = writeln!(out, " for item in items {{");
221 let _ = writeln!(out, " sqlx::query!(\"{}\", item)", sql);
222 let _ = writeln!(out, " .execute(&mut *tx)");
223 let _ = writeln!(out, " .await?;");
224 let _ = writeln!(out, " }}");
225 let _ = writeln!(out, " tx.commit().await?;");
226 let _ = writeln!(out, " Ok(())");
227 } else {
228 let _ = writeln!(
230 out,
231 "pub async fn {}(pool: &sqlx::PgPool, count: usize) -> Result<(), sqlx::Error> {{",
232 batch_fn_name
233 );
234 let _ = writeln!(out, " let mut tx = pool.begin().await?;");
235 let _ = writeln!(out, " for _ in 0..count {{");
236 let _ = writeln!(out, " sqlx::query!(\"{}\")", sql);
237 let _ = writeln!(out, " .execute(&mut *tx)");
238 let _ = writeln!(out, " .await?;");
239 let _ = writeln!(out, " }}");
240 let _ = writeln!(out, " tx.commit().await?;");
241 let _ = writeln!(out, " Ok(())");
242 }
243
244 let _ = write!(out, "}}");
245 return Ok(out);
246 }
247
248 let return_type = match &analyzed.command {
250 QueryCommand::One => struct_name.to_string(),
251 QueryCommand::Many => format!("Vec<{}>", struct_name),
252 QueryCommand::Exec => "()".to_string(),
253 QueryCommand::ExecResult => "sqlx::postgres::PgQueryResult".to_string(),
254 QueryCommand::ExecRows => "u64".to_string(),
255 QueryCommand::Batch => unreachable!(),
256 };
257
258 let _ = writeln!(
260 out,
261 "pub async fn {}({}) -> Result<{}, sqlx::Error> {{",
262 func_name,
263 param_parts.join(", "),
264 return_type
265 );
266
267 let has_row_struct = matches!(analyzed.command, QueryCommand::One | QueryCommand::Many);
269
270 let is_exec_rows = matches!(analyzed.command, QueryCommand::ExecRows);
271
272 if is_exec_rows {
273 if has_row_struct && !analyzed.columns.is_empty() {
274 let _ = write!(
275 out,
276 " let result = sqlx::query_as!({}, \"{}\"{})",
277 struct_name, sql, bind_params
278 );
279 } else {
280 let _ = write!(
281 out,
282 " let result = sqlx::query!(\"{}\"{})",
283 sql, bind_params
284 );
285 }
286 } else if has_row_struct && !analyzed.columns.is_empty() {
287 let _ = write!(
288 out,
289 " sqlx::query_as!({}, \"{}\"{})",
290 struct_name, sql, bind_params
291 );
292 } else {
293 let _ = write!(out, " sqlx::query!(\"{}\"{})", sql, bind_params);
294 }
295
296 let _ = writeln!(out);
297
298 let fetch_method = match &analyzed.command {
300 QueryCommand::One => ".fetch_one(pool)",
301 QueryCommand::Many => ".fetch_all(pool)",
302 QueryCommand::Exec => ".execute(pool)",
303 QueryCommand::ExecResult => ".execute(pool)",
304 QueryCommand::ExecRows => ".execute(pool)",
305 QueryCommand::Batch => unreachable!(),
306 };
307
308 let _ = write!(out, " {}", fetch_method);
309 let _ = writeln!(out);
310
311 match &analyzed.command {
313 QueryCommand::Exec => {
314 let _ = writeln!(out, " .await?;");
315 let _ = writeln!(out, " Ok(())");
316 }
317 QueryCommand::ExecRows => {
318 let _ = writeln!(out, " .await?;");
319 let _ = writeln!(out, " Ok(result.rows_affected())");
320 }
321 _ => {
322 let _ = writeln!(out, " .await");
323 }
324 }
325
326 let _ = write!(out, "}}");
327 Ok(out)
328 }
329
330 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
331 let mut out = String::with_capacity(256);
332 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
333
334 let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq, sqlx::Type)]");
335 let _ = writeln!(
336 out,
337 "#[sqlx(type_name = \"{}\", rename_all = \"snake_case\")]",
338 enum_info.sql_name
339 );
340 let _ = writeln!(out, "pub enum {type_name} {{");
341
342 for value in &enum_info.values {
343 let variant = enum_variant_name(value, &self.manifest.naming);
344 let _ = writeln!(out, " {variant},");
345 }
346
347 let _ = write!(out, "}}");
348 Ok(out)
349 }
350
351 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
352 use scythe_backend::types::resolve_type;
353
354 let struct_name = to_pascal_case(&composite.sql_name).into_owned();
355 let mut out = String::new();
356
357 let _ = writeln!(out, "#[derive(Debug, Clone, sqlx::Type)]");
358 let _ = writeln!(out, "#[sqlx(type_name = \"{}\")]", composite.sql_name);
359 let _ = writeln!(out, "pub struct {} {{", struct_name);
360 for field in &composite.fields {
361 let rust_type = resolve_type(&field.neutral_type, &self.manifest, false)
362 .map(|t| t.into_owned())
363 .map_err(|e| {
364 ScytheError::new(
365 ErrorCode::InternalError,
366 format!("composite field type error: {}", e),
367 )
368 })?;
369 let _ = writeln!(
370 out,
371 " pub {}: {},",
372 to_snake_case(&field.name),
373 rust_type
374 );
375 }
376 let _ = write!(out, "}}");
377 Ok(out)
378 }
379}
380
381fn rewrite_sql_for_enums(
387 sql: &str,
388 columns: &[AnalyzedColumn],
389 manifest: &BackendManifest,
390) -> String {
391 let enum_cols: Vec<(&str, String)> = columns
392 .iter()
393 .filter_map(|col| {
394 if let Some(enum_name) = col.neutral_type.strip_prefix("enum::") {
395 let rust_type = enum_type_name(enum_name, &manifest.naming);
396 let annotation = if col.nullable {
397 format!("Option<{}>", rust_type)
398 } else {
399 rust_type
400 };
401 Some((col.name.as_str(), annotation))
402 } else {
403 None
404 }
405 })
406 .collect();
407
408 if enum_cols.is_empty() {
409 return sql.to_string();
410 }
411
412 let mut result = sql.to_string();
413 for (col_name, annotation) in &enum_cols {
414 let alias = format!("{} AS \\\"{}: {}\\\"", col_name, col_name, annotation);
415 if let Some(from_pos) = result.to_uppercase().find(" FROM ") {
416 let select_part = &result[..from_pos];
417 let rest = &result[from_pos..];
418 let new_select = replace_column_in_select(select_part, col_name, &alias);
419 result = format!("{}{}", new_select, rest);
420 }
421 }
422 result
423}
424
/// Replaces one occurrence of the bare column `col_name` in a select list
/// with `replacement`, returning the new text.
///
/// An occurrence counts only if it is preceded by `", "` (tried first) or a
/// single space, and followed by end of input, a comma, or whitespace. This
/// keeps longer identifiers that merely start with `col_name` (for example
/// `status_code` when looking for `status`) untouched. For each prefix the
/// occurrences are examined right to left and the first that satisfies the
/// boundary rule is rewritten, so a later longer identifier does not prevent
/// an earlier exact match from being found. If nothing qualifies, `select` is
/// returned unchanged.
fn replace_column_in_select(select: &str, col_name: &str, replacement: &str) -> String {
    // True when the text starting at byte offset `end` begins a new token.
    let ends_on_boundary = |end: usize| {
        select[end..]
            .chars()
            .next()
            .map_or(true, |c| c == ',' || c.is_whitespace())
    };

    let patterns = [format!(", {}", col_name), format!(" {}", col_name)];
    for pattern in &patterns {
        let hit = select
            .rmatch_indices(pattern.as_str())
            .map(|(start, _)| start + pattern.len())
            .find(|&end| ends_on_boundary(end));

        if let Some(end) = hit {
            // Keep the separator that preceded the column; swap only the name.
            let name_start = end - col_name.len();
            let mut rewritten = String::with_capacity(select.len() + replacement.len());
            rewritten.push_str(&select[..name_start]);
            rewritten.push_str(replacement);
            rewritten.push_str(&select[end..]);
            return rewritten;
        }
    }

    select.to_string()
}