//! rust-sqlx code generation backend (`scythe_codegen/backends/sqlx.rs`).

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8
9use scythe_core::analyzer::{AnalyzedColumn, AnalyzedQuery, CompositeInfo, EnumInfo};
10use scythe_core::errors::{ErrorCode, ScytheError};
11use scythe_core::parser::QueryCommand;
12
13use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
14use crate::singularize;
15
16/// Default embedded manifest TOML for rust-sqlx, used as fallback.
17const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/rust-sqlx.toml");
18
19/// SqlxBackend generates Rust code targeting the sqlx crate.
20pub struct SqlxBackend {
21    manifest: BackendManifest,
22}
23
24impl SqlxBackend {
25    pub fn new(engine: &str) -> Result<Self, ScytheError> {
26        // Multi-DB backend — accept all engines, load PG manifest as default
27        // TODO: Load engine-specific manifests once they exist
28        match engine {
29            "postgresql" | "postgres" | "pg" | "mysql" | "mariadb" | "sqlite" | "sqlite3" => {}
30            _ => {
31                return Err(ScytheError::new(
32                    ErrorCode::InternalError,
33                    format!("unsupported engine '{}' for rust-sqlx backend", engine),
34                ));
35            }
36        }
37        let manifest = load_sqlx_manifest()?;
38        Ok(Self { manifest })
39    }
40}
41
42fn load_sqlx_manifest() -> Result<BackendManifest, ScytheError> {
43    let manifest_path = Path::new("backends/rust-sqlx/manifest.toml");
44    if manifest_path.exists() {
45        load_manifest(manifest_path).map_err(|e| {
46            ScytheError::new(
47                ErrorCode::InternalError,
48                format!("failed to load manifest: {e}"),
49            )
50        })
51    } else {
52        toml::from_str(DEFAULT_MANIFEST_TOML).map_err(|e| {
53            ScytheError::new(
54                ErrorCode::InternalError,
55                format!("failed to parse embedded manifest: {e}"),
56            )
57        })
58    }
59}
60
61impl CodegenBackend for SqlxBackend {
62    fn name(&self) -> &str {
63        "rust-sqlx"
64    }
65
66    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
67        &self.manifest
68    }
69
70    fn supported_engines(&self) -> &[&str] {
71        &["postgresql", "mysql", "sqlite"]
72    }
73
74    fn file_header(&self) -> String {
75        "// Auto-generated by scythe. Do not edit.\n#![allow(dead_code, unused_imports, clippy::all)]"
76            .to_string()
77    }
78
79    fn generate_row_struct(
80        &self,
81        query_name: &str,
82        columns: &[ResolvedColumn],
83    ) -> Result<String, ScytheError> {
84        let struct_name = row_struct_name(query_name, &self.manifest.naming);
85        let mut out = String::new();
86
87        let _ = writeln!(out, "#[derive(Debug, sqlx::FromRow)]");
88        let _ = writeln!(out, "pub struct {} {{", struct_name);
89
90        for col in columns {
91            let _ = writeln!(out, "    pub {}: {},", col.field_name, col.full_type);
92        }
93
94        let _ = write!(out, "}}");
95        Ok(out)
96    }
97
98    fn generate_model_struct(
99        &self,
100        table_name: &str,
101        columns: &[ResolvedColumn],
102    ) -> Result<String, ScytheError> {
103        let singular = singularize(table_name);
104        let struct_name = to_pascal_case(&singular).into_owned();
105        let mut out = String::new();
106
107        let _ = writeln!(out, "#[derive(Debug, sqlx::FromRow)]");
108        let _ = writeln!(out, "pub struct {} {{", struct_name);
109
110        for col in columns {
111            let _ = writeln!(out, "    pub {}: {},", col.field_name, col.full_type);
112        }
113
114        let _ = write!(out, "}}");
115        Ok(out)
116    }
117
118    fn generate_query_fn(
119        &self,
120        analyzed: &AnalyzedQuery,
121        struct_name: &str,
122        _columns: &[ResolvedColumn],
123        params: &[ResolvedParam],
124    ) -> Result<String, ScytheError> {
125        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
126        let mut out = String::new();
127
128        // Deprecated annotation
129        if let Some(ref msg) = analyzed.deprecated {
130            let _ = writeln!(out, "#[deprecated(note = \"{}\")]", msg);
131        }
132
133        // Build parameter list
134        let mut param_parts: Vec<String> = vec!["pool: &sqlx::PgPool".to_string()];
135        for param in params {
136            param_parts.push(format!("{}: {}", param.field_name, param.borrowed_type));
137        }
138
139        // Clean SQL
140        let sql_raw = super::clean_sql_with_optional(
141            &analyzed.sql,
142            &analyzed.optional_params,
143            &analyzed.params,
144        );
145        let sql = rewrite_sql_for_enums(&sql_raw, &analyzed.columns, &self.manifest);
146
147        // Build bind params string
148        let bind_params: String = analyzed
149            .params
150            .iter()
151            .map(|p| {
152                let param_name = to_snake_case(&p.name);
153                if p.neutral_type.starts_with("enum::") {
154                    let enum_name = p.neutral_type.strip_prefix("enum::").unwrap();
155                    let rust_type = enum_type_name(enum_name, &self.manifest.naming);
156                    format!(", {} as &{}", param_name, rust_type)
157                } else {
158                    format!(", {}", param_name)
159                }
160            })
161            .collect();
162
163        // Handle :batch separately — generates a different function signature
164        if matches!(analyzed.command, QueryCommand::Batch) {
165            let batch_fn_name = format!("{}_batch", func_name);
166
167            // Generate params struct if >1 param
168            if params.len() > 1 {
169                let params_struct_name = format!("{}BatchParams", struct_name);
170                let _ = writeln!(out, "#[derive(Debug, Clone)]");
171                let _ = writeln!(out, "pub struct {} {{", params_struct_name);
172                for param in params {
173                    let _ = writeln!(out, "    pub {}: {},", param.field_name, param.full_type);
174                }
175                let _ = writeln!(out, "}}");
176                let _ = writeln!(out);
177
178                // Batch function takes &[ParamsStruct]
179                let _ = writeln!(
180                    out,
181                    "pub async fn {}(pool: &sqlx::PgPool, items: &[{}]) -> Result<(), sqlx::Error> {{",
182                    batch_fn_name, params_struct_name
183                );
184                let _ = writeln!(out, "    let mut tx = pool.begin().await?;");
185                let _ = writeln!(out, "    for item in items {{");
186
187                // Build bind params from struct fields
188                let struct_bind_params: String = params
189                    .iter()
190                    .map(|p| {
191                        if p.neutral_type.starts_with("enum::") {
192                            let enum_name = p.neutral_type.strip_prefix("enum::").unwrap();
193                            let rust_type = enum_type_name(enum_name, &self.manifest.naming);
194                            format!(", item.{} as &{}", p.field_name, rust_type)
195                        } else {
196                            format!(", item.{}", p.field_name)
197                        }
198                    })
199                    .collect();
200
201                let _ = writeln!(
202                    out,
203                    "        sqlx::query!(\"{}\"{})",
204                    sql, struct_bind_params
205                );
206                let _ = writeln!(out, "            .execute(&mut *tx)");
207                let _ = writeln!(out, "            .await?;");
208                let _ = writeln!(out, "    }}");
209                let _ = writeln!(out, "    tx.commit().await?;");
210                let _ = writeln!(out, "    Ok(())");
211            } else if params.len() == 1 {
212                // Single param — takes a slice of that type
213                let param = &params[0];
214                let _ = writeln!(
215                    out,
216                    "pub async fn {}(pool: &sqlx::PgPool, items: &[{}]) -> Result<(), sqlx::Error> {{",
217                    batch_fn_name, param.full_type
218                );
219                let _ = writeln!(out, "    let mut tx = pool.begin().await?;");
220                let _ = writeln!(out, "    for item in items {{");
221                let _ = writeln!(out, "        sqlx::query!(\"{}\", item)", sql);
222                let _ = writeln!(out, "            .execute(&mut *tx)");
223                let _ = writeln!(out, "            .await?;");
224                let _ = writeln!(out, "    }}");
225                let _ = writeln!(out, "    tx.commit().await?;");
226                let _ = writeln!(out, "    Ok(())");
227            } else {
228                // No params — just execute N times (unusual but valid)
229                let _ = writeln!(
230                    out,
231                    "pub async fn {}(pool: &sqlx::PgPool, count: usize) -> Result<(), sqlx::Error> {{",
232                    batch_fn_name
233                );
234                let _ = writeln!(out, "    let mut tx = pool.begin().await?;");
235                let _ = writeln!(out, "    for _ in 0..count {{");
236                let _ = writeln!(out, "        sqlx::query!(\"{}\")", sql);
237                let _ = writeln!(out, "            .execute(&mut *tx)");
238                let _ = writeln!(out, "            .await?;");
239                let _ = writeln!(out, "    }}");
240                let _ = writeln!(out, "    tx.commit().await?;");
241                let _ = writeln!(out, "    Ok(())");
242            }
243
244            let _ = write!(out, "}}");
245            return Ok(out);
246        }
247
248        // Return type for non-batch commands
249        let return_type = match &analyzed.command {
250            QueryCommand::One => struct_name.to_string(),
251            QueryCommand::Many => format!("Vec<{}>", struct_name),
252            QueryCommand::Exec => "()".to_string(),
253            QueryCommand::ExecResult => "sqlx::postgres::PgQueryResult".to_string(),
254            QueryCommand::ExecRows => "u64".to_string(),
255            QueryCommand::Batch | QueryCommand::Grouped => unreachable!(),
256        };
257
258        // Function signature
259        let _ = writeln!(
260            out,
261            "pub async fn {}({}) -> Result<{}, sqlx::Error> {{",
262            func_name,
263            param_parts.join(", "),
264            return_type
265        );
266
267        // Query body
268        let has_row_struct = matches!(analyzed.command, QueryCommand::One | QueryCommand::Many);
269
270        let is_exec_rows = matches!(analyzed.command, QueryCommand::ExecRows);
271
272        if is_exec_rows {
273            if has_row_struct && !analyzed.columns.is_empty() {
274                let _ = write!(
275                    out,
276                    "    let result = sqlx::query_as!({}, \"{}\"{})",
277                    struct_name, sql, bind_params
278                );
279            } else {
280                let _ = write!(
281                    out,
282                    "    let result = sqlx::query!(\"{}\"{})",
283                    sql, bind_params
284                );
285            }
286        } else if has_row_struct && !analyzed.columns.is_empty() {
287            let _ = write!(
288                out,
289                "    sqlx::query_as!({}, \"{}\"{})",
290                struct_name, sql, bind_params
291            );
292        } else {
293            let _ = write!(out, "    sqlx::query!(\"{}\"{})", sql, bind_params);
294        }
295
296        let _ = writeln!(out);
297
298        // Fetch method
299        let fetch_method = match &analyzed.command {
300            QueryCommand::One => ".fetch_one(pool)",
301            QueryCommand::Many => ".fetch_all(pool)",
302            QueryCommand::Exec => ".execute(pool)",
303            QueryCommand::ExecResult => ".execute(pool)",
304            QueryCommand::ExecRows => ".execute(pool)",
305            QueryCommand::Batch | QueryCommand::Grouped => unreachable!(),
306        };
307
308        let _ = write!(out, "        {}", fetch_method);
309        let _ = writeln!(out);
310
311        // Post-processing for exec variants
312        match &analyzed.command {
313            QueryCommand::Exec => {
314                let _ = writeln!(out, "        .await?;");
315                let _ = writeln!(out, "    Ok(())");
316            }
317            QueryCommand::ExecRows => {
318                let _ = writeln!(out, "        .await?;");
319                let _ = writeln!(out, "    Ok(result.rows_affected())");
320            }
321            _ => {
322                let _ = writeln!(out, "        .await");
323            }
324        }
325
326        let _ = write!(out, "}}");
327        Ok(out)
328    }
329
330    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
331        let mut out = String::with_capacity(256);
332        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
333
334        let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq, sqlx::Type)]");
335        let _ = writeln!(
336            out,
337            "#[sqlx(type_name = \"{}\", rename_all = \"snake_case\")]",
338            enum_info.sql_name
339        );
340        let _ = writeln!(out, "pub enum {type_name} {{");
341
342        for value in &enum_info.values {
343            let variant = enum_variant_name(value, &self.manifest.naming);
344            let _ = writeln!(out, "    {variant},");
345        }
346
347        let _ = write!(out, "}}");
348        Ok(out)
349    }
350
351    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
352        use scythe_backend::types::resolve_type;
353
354        let struct_name = to_pascal_case(&composite.sql_name).into_owned();
355        let mut out = String::new();
356
357        let _ = writeln!(out, "#[derive(Debug, Clone, sqlx::Type)]");
358        let _ = writeln!(out, "#[sqlx(type_name = \"{}\")]", composite.sql_name);
359        let _ = writeln!(out, "pub struct {} {{", struct_name);
360        for field in &composite.fields {
361            let rust_type = resolve_type(&field.neutral_type, &self.manifest, false)
362                .map(|t| t.into_owned())
363                .map_err(|e| {
364                    ScytheError::new(
365                        ErrorCode::InternalError,
366                        format!("composite field type error: {}", e),
367                    )
368                })?;
369            let _ = writeln!(
370                out,
371                "    pub {}: {},",
372                to_snake_case(&field.name),
373                rust_type
374            );
375        }
376        let _ = write!(out, "}}");
377        Ok(out)
378    }
379}
380
381// ---------------------------------------------------------------------------
382// Internal helpers (moved from old modules)
383// ---------------------------------------------------------------------------
384
385/// Rewrite SQL to add enum type annotations for sqlx.
386fn rewrite_sql_for_enums(
387    sql: &str,
388    columns: &[AnalyzedColumn],
389    manifest: &BackendManifest,
390) -> String {
391    let enum_cols: Vec<(&str, String)> = columns
392        .iter()
393        .filter_map(|col| {
394            if let Some(enum_name) = col.neutral_type.strip_prefix("enum::") {
395                let rust_type = enum_type_name(enum_name, &manifest.naming);
396                let annotation = if col.nullable {
397                    format!("Option<{}>", rust_type)
398                } else {
399                    rust_type
400                };
401                Some((col.name.as_str(), annotation))
402            } else {
403                None
404            }
405        })
406        .collect();
407
408    if enum_cols.is_empty() {
409        return sql.to_string();
410    }
411
412    let mut result = sql.to_string();
413    for (col_name, annotation) in &enum_cols {
414        let alias = format!("{} AS \\\"{}: {}\\\"", col_name, col_name, annotation);
415        if let Some(from_pos) = result.to_uppercase().find(" FROM ") {
416            let select_part = &result[..from_pos];
417            let rest = &result[from_pos..];
418            let new_select = replace_column_in_select(select_part, col_name, &alias);
419            result = format!("{}{}", new_select, rest);
420        }
421    }
422    result
423}
424
/// Replace one standalone reference to `col_name` in a SELECT list with
/// `replacement`, returning the rewritten text.
///
/// An occurrence counts as a standalone column reference when it is preceded
/// by ASCII whitespace or `,` and followed by ASCII whitespace, `,`, or the
/// end of `select`. Occurrences that are part of a longer identifier
/// (`status_code`) or qualified (`u.status`) are skipped. Every occurrence is
/// examined and the rightmost valid one is replaced. If there is none, or
/// `col_name` is empty, `select` is returned unchanged.
fn replace_column_in_select(select: &str, col_name: &str, replacement: &str) -> String {
    fn is_boundary(c: char) -> bool {
        c.is_ascii_whitespace() || c == ','
    }

    if col_name.is_empty() {
        return select.to_string();
    }

    let found = select.rmatch_indices(col_name).map(|(pos, _)| pos).find(|&pos| {
        let end = pos + col_name.len();
        let before_ok = matches!(select[..pos].chars().next_back(), Some(c) if is_boundary(c));
        let after_ok = select[end..].chars().next().map_or(true, is_boundary);
        before_ok && after_ok
    });

    match found {
        Some(pos) => format!(
            "{}{}{}",
            &select[..pos],
            replacement,
            &select[pos + col_name.len()..]
        ),
        None => select.to_string(),
    }
}