Skip to main content

scythe_codegen/backends/
sqlx.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8
9use scythe_core::analyzer::{AnalyzedColumn, AnalyzedQuery, CompositeInfo, EnumInfo};
10use scythe_core::errors::{ErrorCode, ScytheError};
11use scythe_core::parser::QueryCommand;
12
13use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
14use crate::singularize;
15
16/// Default embedded manifest TOML for rust-sqlx, used as fallback.
17const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/rust-sqlx.toml");
18
19/// SqlxBackend generates Rust code targeting the sqlx crate.
20pub struct SqlxBackend {
21    manifest: BackendManifest,
22}
23
24impl SqlxBackend {
25    pub fn new(engine: &str) -> Result<Self, ScytheError> {
26        // Multi-DB backend — accept all engines, load PG manifest as default
27        // TODO: Load engine-specific manifests once they exist
28        match engine {
29            "postgresql" | "postgres" | "pg" | "mysql" | "mariadb" | "sqlite" | "sqlite3" => {}
30            _ => {
31                return Err(ScytheError::new(
32                    ErrorCode::InternalError,
33                    format!("unsupported engine '{}' for rust-sqlx backend", engine),
34                ));
35            }
36        }
37        let manifest = load_sqlx_manifest()?;
38        Ok(Self { manifest })
39    }
40}
41
42fn load_sqlx_manifest() -> Result<BackendManifest, ScytheError> {
43    let manifest_path = Path::new("backends/rust-sqlx/manifest.toml");
44    if manifest_path.exists() {
45        load_manifest(manifest_path).map_err(|e| {
46            ScytheError::new(
47                ErrorCode::InternalError,
48                format!("failed to load manifest: {e}"),
49            )
50        })
51    } else {
52        toml::from_str(DEFAULT_MANIFEST_TOML).map_err(|e| {
53            ScytheError::new(
54                ErrorCode::InternalError,
55                format!("failed to parse embedded manifest: {e}"),
56            )
57        })
58    }
59}
60
61impl CodegenBackend for SqlxBackend {
62    fn name(&self) -> &str {
63        "rust-sqlx"
64    }
65
66    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
67        &self.manifest
68    }
69
70    fn supported_engines(&self) -> &[&str] {
71        &["postgresql", "mysql", "sqlite"]
72    }
73
74    fn file_header(&self) -> String {
75        "// Auto-generated by scythe. Do not edit.\n#![allow(dead_code, unused_imports, clippy::all)]"
76            .to_string()
77    }
78
79    fn generate_row_struct(
80        &self,
81        query_name: &str,
82        columns: &[ResolvedColumn],
83    ) -> Result<String, ScytheError> {
84        let struct_name = row_struct_name(query_name, &self.manifest.naming);
85        let mut out = String::new();
86
87        let _ = writeln!(out, "#[derive(Debug, sqlx::FromRow)]");
88        let _ = writeln!(out, "pub struct {} {{", struct_name);
89
90        for col in columns {
91            let _ = writeln!(out, "    pub {}: {},", col.field_name, col.full_type);
92        }
93
94        let _ = write!(out, "}}");
95        Ok(out)
96    }
97
98    fn generate_model_struct(
99        &self,
100        table_name: &str,
101        columns: &[ResolvedColumn],
102    ) -> Result<String, ScytheError> {
103        let singular = singularize(table_name);
104        let struct_name = to_pascal_case(&singular).into_owned();
105        let mut out = String::new();
106
107        let _ = writeln!(out, "#[derive(Debug, sqlx::FromRow)]");
108        let _ = writeln!(out, "pub struct {} {{", struct_name);
109
110        for col in columns {
111            let _ = writeln!(out, "    pub {}: {},", col.field_name, col.full_type);
112        }
113
114        let _ = write!(out, "}}");
115        Ok(out)
116    }
117
118    fn generate_query_fn(
119        &self,
120        analyzed: &AnalyzedQuery,
121        struct_name: &str,
122        _columns: &[ResolvedColumn],
123        params: &[ResolvedParam],
124    ) -> Result<String, ScytheError> {
125        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
126        let mut out = String::new();
127
128        // Deprecated annotation
129        if let Some(ref msg) = analyzed.deprecated {
130            let _ = writeln!(out, "#[deprecated(note = \"{}\")]", msg);
131        }
132
133        // Build parameter list
134        let mut param_parts: Vec<String> = vec!["pool: &sqlx::PgPool".to_string()];
135        for param in params {
136            param_parts.push(format!("{}: {}", param.field_name, param.borrowed_type));
137        }
138
139        // Return type
140        let return_type = match &analyzed.command {
141            QueryCommand::One => struct_name.to_string(),
142            QueryCommand::Many => format!("Vec<{}>", struct_name),
143            QueryCommand::Exec => "()".to_string(),
144            QueryCommand::ExecResult => "sqlx::postgres::PgQueryResult".to_string(),
145            QueryCommand::ExecRows => "u64".to_string(),
146            QueryCommand::Batch => format!("Vec<{}>", struct_name),
147        };
148
149        // Function signature
150        let _ = writeln!(
151            out,
152            "pub async fn {}({}) -> Result<{}, sqlx::Error> {{",
153            func_name,
154            param_parts.join(", "),
155            return_type
156        );
157
158        // Clean SQL
159        let sql_raw = super::clean_sql(&analyzed.sql);
160        let sql = rewrite_sql_for_enums(&sql_raw, &analyzed.columns, &self.manifest);
161
162        // Query body
163        let has_row_struct = matches!(
164            analyzed.command,
165            QueryCommand::One | QueryCommand::Many | QueryCommand::Batch
166        );
167
168        // Build bind params string
169        let bind_params: String = analyzed
170            .params
171            .iter()
172            .map(|p| {
173                let param_name = to_snake_case(&p.name);
174                if p.neutral_type.starts_with("enum::") {
175                    let enum_name = p.neutral_type.strip_prefix("enum::").unwrap();
176                    let rust_type = enum_type_name(enum_name, &self.manifest.naming);
177                    format!(", {} as &{}", param_name, rust_type)
178                } else {
179                    format!(", {}", param_name)
180                }
181            })
182            .collect();
183
184        let is_exec_rows = matches!(analyzed.command, QueryCommand::ExecRows);
185
186        if is_exec_rows {
187            if has_row_struct && !analyzed.columns.is_empty() {
188                let _ = write!(
189                    out,
190                    "    let result = sqlx::query_as!({}, \"{}\"{})",
191                    struct_name, sql, bind_params
192                );
193            } else {
194                let _ = write!(
195                    out,
196                    "    let result = sqlx::query!(\"{}\"{})",
197                    sql, bind_params
198                );
199            }
200        } else if has_row_struct && !analyzed.columns.is_empty() {
201            let _ = write!(
202                out,
203                "    sqlx::query_as!({}, \"{}\"{})",
204                struct_name, sql, bind_params
205            );
206        } else {
207            let _ = write!(out, "    sqlx::query!(\"{}\"{})", sql, bind_params);
208        }
209
210        let _ = writeln!(out);
211
212        // Fetch method
213        let fetch_method = match &analyzed.command {
214            QueryCommand::One => ".fetch_one(pool)",
215            QueryCommand::Many => ".fetch_all(pool)",
216            QueryCommand::Exec => ".execute(pool)",
217            QueryCommand::ExecResult => ".execute(pool)",
218            QueryCommand::ExecRows => ".execute(pool)",
219            QueryCommand::Batch => ".fetch_all(pool)",
220        };
221
222        let _ = write!(out, "        {}", fetch_method);
223        let _ = writeln!(out);
224
225        // Post-processing for exec variants
226        match &analyzed.command {
227            QueryCommand::Exec => {
228                let _ = writeln!(out, "        .await?;");
229                let _ = writeln!(out, "    Ok(())");
230            }
231            QueryCommand::ExecRows => {
232                let _ = writeln!(out, "        .await?;");
233                let _ = writeln!(out, "    Ok(result.rows_affected())");
234            }
235            _ => {
236                let _ = writeln!(out, "        .await");
237            }
238        }
239
240        let _ = write!(out, "}}");
241        Ok(out)
242    }
243
244    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
245        let mut out = String::with_capacity(256);
246        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
247
248        let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq, sqlx::Type)]");
249        let _ = writeln!(
250            out,
251            "#[sqlx(type_name = \"{}\", rename_all = \"snake_case\")]",
252            enum_info.sql_name
253        );
254        let _ = writeln!(out, "pub enum {type_name} {{");
255
256        for value in &enum_info.values {
257            let variant = enum_variant_name(value, &self.manifest.naming);
258            let _ = writeln!(out, "    {variant},");
259        }
260
261        let _ = write!(out, "}}");
262        Ok(out)
263    }
264
265    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
266        use scythe_backend::types::resolve_type;
267
268        let struct_name = to_pascal_case(&composite.sql_name).into_owned();
269        let mut out = String::new();
270
271        let _ = writeln!(out, "#[derive(Debug, Clone, sqlx::Type)]");
272        let _ = writeln!(out, "#[sqlx(type_name = \"{}\")]", composite.sql_name);
273        let _ = writeln!(out, "pub struct {} {{", struct_name);
274        for field in &composite.fields {
275            let rust_type = resolve_type(&field.neutral_type, &self.manifest, false)
276                .map(|t| t.into_owned())
277                .map_err(|e| {
278                    ScytheError::new(
279                        ErrorCode::InternalError,
280                        format!("composite field type error: {}", e),
281                    )
282                })?;
283            let _ = writeln!(
284                out,
285                "    pub {}: {},",
286                to_snake_case(&field.name),
287                rust_type
288            );
289        }
290        let _ = write!(out, "}}");
291        Ok(out)
292    }
293}
294
295// ---------------------------------------------------------------------------
296// Internal helpers (moved from old modules)
297// ---------------------------------------------------------------------------
298
299/// Rewrite SQL to add enum type annotations for sqlx.
300fn rewrite_sql_for_enums(
301    sql: &str,
302    columns: &[AnalyzedColumn],
303    manifest: &BackendManifest,
304) -> String {
305    let enum_cols: Vec<(&str, String)> = columns
306        .iter()
307        .filter_map(|col| {
308            if let Some(enum_name) = col.neutral_type.strip_prefix("enum::") {
309                let rust_type = enum_type_name(enum_name, &manifest.naming);
310                let annotation = if col.nullable {
311                    format!("Option<{}>", rust_type)
312                } else {
313                    rust_type
314                };
315                Some((col.name.as_str(), annotation))
316            } else {
317                None
318            }
319        })
320        .collect();
321
322    if enum_cols.is_empty() {
323        return sql.to_string();
324    }
325
326    let mut result = sql.to_string();
327    for (col_name, annotation) in &enum_cols {
328        let alias = format!("{} AS \\\"{}: {}\\\"", col_name, col_name, annotation);
329        if let Some(from_pos) = result.to_uppercase().find(" FROM ") {
330            let select_part = &result[..from_pos];
331            let rest = &result[from_pos..];
332            let new_select = replace_column_in_select(select_part, col_name, &alias);
333            result = format!("{}{}", new_select, rest);
334        }
335    }
336    result
337}
338
/// Replace one standalone reference to `col_name` in a SELECT list with
/// `replacement`.
///
/// A candidate must be preceded by `", "` (tried first) or `" "`, and followed by
/// end of input, a comma, or ASCII whitespace, so `status` does not match inside
/// `status_code`. For each prefix form, occurrences are examined right-to-left
/// and the first valid one is replaced. If no valid occurrence exists, `select`
/// is returned unchanged.
fn replace_column_in_select(select: &str, col_name: &str, replacement: &str) -> String {
    let patterns = [format!(", {}", col_name), format!(" {}", col_name)];
    for pattern in &patterns {
        // Check every occurrence, not just the last: the right-most hit may be a
        // longer identifier (`status_code`) while an earlier one is the real column.
        for (pos, _) in select.rmatch_indices(pattern.as_str()) {
            let after = pos + pattern.len();
            let at_boundary = match select[after..].chars().next() {
                None => true,
                Some(c) => c == ',' || c.is_ascii_whitespace(),
            };
            if at_boundary {
                let name_start = after - col_name.len();
                return format!(
                    "{}{}{}",
                    &select[..name_start],
                    replacement,
                    &select[after..]
                );
            }
        }
    }
    select.to_string()
}