Skip to main content

scythe_codegen/backends/
sqlx.rs

1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8
9use scythe_core::analyzer::{AnalyzedColumn, AnalyzedQuery, CompositeInfo, EnumInfo};
10use scythe_core::errors::{ErrorCode, ScytheError};
11use scythe_core::parser::QueryCommand;
12
13use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
14use crate::singularize;
15
16/// Default embedded manifest TOML for rust-sqlx, used as fallback.
17const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/rust-sqlx.toml");
18
19/// SqlxBackend generates Rust code targeting the sqlx crate.
20pub struct SqlxBackend {
21    manifest: BackendManifest,
22}
23
24impl SqlxBackend {
25    pub fn new() -> Result<Self, ScytheError> {
26        let manifest = load_sqlx_manifest()?;
27        Ok(Self { manifest })
28    }
29
30    /// Access the internal manifest (for backward-compat callers).
31    pub fn manifest(&self) -> &BackendManifest {
32        &self.manifest
33    }
34}
35
36fn load_sqlx_manifest() -> Result<BackendManifest, ScytheError> {
37    let manifest_path = Path::new("backends/rust-sqlx/manifest.toml");
38    if manifest_path.exists() {
39        load_manifest(manifest_path).map_err(|e| {
40            ScytheError::new(
41                ErrorCode::InternalError,
42                format!("failed to load manifest: {e}"),
43            )
44        })
45    } else {
46        toml::from_str(DEFAULT_MANIFEST_TOML).map_err(|e| {
47            ScytheError::new(
48                ErrorCode::InternalError,
49                format!("failed to parse embedded manifest: {e}"),
50            )
51        })
52    }
53}
54
55impl CodegenBackend for SqlxBackend {
56    fn name(&self) -> &str {
57        "rust-sqlx"
58    }
59
60    fn generate_row_struct(
61        &self,
62        query_name: &str,
63        columns: &[ResolvedColumn],
64    ) -> Result<String, ScytheError> {
65        let struct_name = row_struct_name(query_name, &self.manifest.naming);
66        let mut out = String::new();
67
68        let _ = writeln!(out, "#[derive(Debug, sqlx::FromRow)]");
69        let _ = writeln!(out, "pub struct {} {{", struct_name);
70
71        for col in columns {
72            let _ = writeln!(out, "    pub {}: {},", col.field_name, col.full_type);
73        }
74
75        let _ = write!(out, "}}");
76        Ok(out)
77    }
78
79    fn generate_model_struct(
80        &self,
81        table_name: &str,
82        columns: &[ResolvedColumn],
83    ) -> Result<String, ScytheError> {
84        let singular = singularize(table_name);
85        let struct_name = to_pascal_case(&singular).into_owned();
86        let mut out = String::new();
87
88        let _ = writeln!(out, "#[derive(Debug, sqlx::FromRow)]");
89        let _ = writeln!(out, "pub struct {} {{", struct_name);
90
91        for col in columns {
92            let _ = writeln!(out, "    pub {}: {},", col.field_name, col.full_type);
93        }
94
95        let _ = write!(out, "}}");
96        Ok(out)
97    }
98
99    fn generate_query_fn(
100        &self,
101        analyzed: &AnalyzedQuery,
102        struct_name: &str,
103        _columns: &[ResolvedColumn],
104        params: &[ResolvedParam],
105    ) -> Result<String, ScytheError> {
106        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
107        let mut out = String::new();
108
109        // Deprecated annotation
110        if let Some(ref msg) = analyzed.deprecated {
111            let _ = writeln!(out, "#[deprecated(note = \"{}\")]", msg);
112        }
113
114        // Build parameter list
115        let mut param_parts: Vec<String> = vec!["pool: &sqlx::PgPool".to_string()];
116        for param in params {
117            param_parts.push(format!("{}: {}", param.field_name, param.borrowed_type));
118        }
119
120        // Return type
121        let return_type = match &analyzed.command {
122            QueryCommand::One => struct_name.to_string(),
123            QueryCommand::Many => format!("Vec<{}>", struct_name),
124            QueryCommand::Exec => "()".to_string(),
125            QueryCommand::ExecResult => "sqlx::postgres::PgQueryResult".to_string(),
126            QueryCommand::ExecRows => "u64".to_string(),
127            QueryCommand::Batch => format!("Vec<{}>", struct_name),
128        };
129
130        // Function signature
131        let _ = writeln!(
132            out,
133            "pub async fn {}({}) -> Result<{}, sqlx::Error> {{",
134            func_name,
135            param_parts.join(", "),
136            return_type
137        );
138
139        // Clean SQL
140        let sql_raw = super::clean_sql(&analyzed.sql);
141        let sql = rewrite_sql_for_enums(&sql_raw, &analyzed.columns, &self.manifest);
142
143        // Query body
144        let has_row_struct = matches!(
145            analyzed.command,
146            QueryCommand::One | QueryCommand::Many | QueryCommand::Batch
147        );
148
149        // Build bind params string
150        let bind_params: String = analyzed
151            .params
152            .iter()
153            .map(|p| {
154                let param_name = to_snake_case(&p.name);
155                if p.neutral_type.starts_with("enum::") {
156                    let enum_name = p.neutral_type.strip_prefix("enum::").unwrap();
157                    let rust_type = enum_type_name(enum_name, &self.manifest.naming);
158                    format!(", {} as &{}", param_name, rust_type)
159                } else {
160                    format!(", {}", param_name)
161                }
162            })
163            .collect();
164
165        let is_exec_rows = matches!(analyzed.command, QueryCommand::ExecRows);
166
167        if is_exec_rows {
168            if has_row_struct && !analyzed.columns.is_empty() {
169                let _ = write!(
170                    out,
171                    "    let result = sqlx::query_as!({}, \"{}\"{})",
172                    struct_name, sql, bind_params
173                );
174            } else {
175                let _ = write!(
176                    out,
177                    "    let result = sqlx::query!(\"{}\"{})",
178                    sql, bind_params
179                );
180            }
181        } else if has_row_struct && !analyzed.columns.is_empty() {
182            let _ = write!(
183                out,
184                "    sqlx::query_as!({}, \"{}\"{})",
185                struct_name, sql, bind_params
186            );
187        } else {
188            let _ = write!(out, "    sqlx::query!(\"{}\"{})", sql, bind_params);
189        }
190
191        let _ = writeln!(out);
192
193        // Fetch method
194        let fetch_method = match &analyzed.command {
195            QueryCommand::One => ".fetch_one(pool)",
196            QueryCommand::Many => ".fetch_all(pool)",
197            QueryCommand::Exec => ".execute(pool)",
198            QueryCommand::ExecResult => ".execute(pool)",
199            QueryCommand::ExecRows => ".execute(pool)",
200            QueryCommand::Batch => ".fetch_all(pool)",
201        };
202
203        let _ = write!(out, "        {}", fetch_method);
204        let _ = writeln!(out);
205
206        // Post-processing for exec variants
207        match &analyzed.command {
208            QueryCommand::Exec => {
209                let _ = writeln!(out, "        .await?;");
210                let _ = writeln!(out, "    Ok(())");
211            }
212            QueryCommand::ExecRows => {
213                let _ = writeln!(out, "        .await?;");
214                let _ = writeln!(out, "    Ok(result.rows_affected())");
215            }
216            _ => {
217                let _ = writeln!(out, "        .await");
218            }
219        }
220
221        let _ = write!(out, "}}");
222        Ok(out)
223    }
224
225    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
226        let mut out = String::with_capacity(256);
227        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
228
229        let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq, sqlx::Type)]");
230        let _ = writeln!(
231            out,
232            "#[sqlx(type_name = \"{}\", rename_all = \"snake_case\")]",
233            enum_info.sql_name
234        );
235        let _ = writeln!(out, "pub enum {type_name} {{");
236
237        for value in &enum_info.values {
238            let variant = enum_variant_name(value, &self.manifest.naming);
239            let _ = writeln!(out, "    {variant},");
240        }
241
242        let _ = write!(out, "}}");
243        Ok(out)
244    }
245
246    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
247        use scythe_backend::types::resolve_type;
248
249        let struct_name = to_pascal_case(&composite.sql_name).into_owned();
250        let mut out = String::new();
251
252        let _ = writeln!(out, "#[derive(Debug, Clone, sqlx::Type)]");
253        let _ = writeln!(out, "#[sqlx(type_name = \"{}\")]", composite.sql_name);
254        let _ = writeln!(out, "pub struct {} {{", struct_name);
255        for field in &composite.fields {
256            let rust_type = resolve_type(&field.neutral_type, &self.manifest, false)
257                .map(|t| t.into_owned())
258                .map_err(|e| {
259                    ScytheError::new(
260                        ErrorCode::InternalError,
261                        format!("composite field type error: {}", e),
262                    )
263                })?;
264            let _ = writeln!(
265                out,
266                "    pub {}: {},",
267                to_snake_case(&field.name),
268                rust_type
269            );
270        }
271        let _ = write!(out, "}}");
272        Ok(out)
273    }
274}
275
276// ---------------------------------------------------------------------------
277// Internal helpers (moved from old modules)
278// ---------------------------------------------------------------------------
279
280/// Rewrite SQL to add enum type annotations for sqlx.
281fn rewrite_sql_for_enums(
282    sql: &str,
283    columns: &[AnalyzedColumn],
284    manifest: &BackendManifest,
285) -> String {
286    let enum_cols: Vec<(&str, String)> = columns
287        .iter()
288        .filter_map(|col| {
289            if let Some(enum_name) = col.neutral_type.strip_prefix("enum::") {
290                let rust_type = enum_type_name(enum_name, &manifest.naming);
291                let annotation = if col.nullable {
292                    format!("Option<{}>", rust_type)
293                } else {
294                    rust_type
295                };
296                Some((col.name.as_str(), annotation))
297            } else {
298                None
299            }
300        })
301        .collect();
302
303    if enum_cols.is_empty() {
304        return sql.to_string();
305    }
306
307    let mut result = sql.to_string();
308    for (col_name, annotation) in &enum_cols {
309        let alias = format!("{} AS \\\"{}: {}\\\"", col_name, col_name, annotation);
310        if let Some(from_pos) = result.to_uppercase().find(" FROM ") {
311            let select_part = &result[..from_pos];
312            let rest = &result[from_pos..];
313            let new_select = replace_column_in_select(select_part, col_name, &alias);
314            result = format!("{}{}", new_select, rest);
315        }
316    }
317    result
318}
319
/// Replace the right-most standalone occurrence of `col_name` in a SELECT list
/// with `replacement`.
///
/// An occurrence counts only as a whole token:
/// - it must be preceded by whitespace, a comma, or a `.` (table qualifier, which
///   is kept in front of the replacement);
/// - it must be followed by end of input, whitespace, or a comma.
///
/// Candidates are examined right-to-left. A longer identifier sharing the prefix
/// (e.g. `status_code` when looking for `status`) therefore no longer hides an
/// earlier genuine match. Returns `select` unchanged when nothing qualifies or
/// `col_name` is empty.
fn replace_column_in_select(select: &str, col_name: &str, replacement: &str) -> String {
    if col_name.is_empty() {
        return select.to_string();
    }

    let valid_before = |c: char| c.is_whitespace() || c == ',' || c == '.';
    let valid_after = |c: char| c.is_whitespace() || c == ',';

    let hit = select
        .rmatch_indices(col_name)
        .map(|(pos, _)| pos)
        .find(|&pos| {
            let before = select[..pos].chars().next_back();
            let after = select[pos + col_name.len()..].chars().next();
            matches!(before, Some(c) if valid_before(c)) && after.map_or(true, valid_after)
        });

    match hit {
        Some(pos) => {
            let mut out = String::with_capacity(select.len() + replacement.len());
            out.push_str(&select[..pos]);
            out.push_str(replacement);
            out.push_str(&select[pos + col_name.len()..]);
            out
        }
        None => select.to_string(),
    }
}