Skip to main content

scythe_codegen/backends/
sqlx.rs

1use scythe_backend::manifest::BackendManifest;
2use scythe_backend::naming::{
3    enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
4};
5use std::fmt::Write;
6
7use scythe_core::analyzer::{AnalyzedColumn, AnalyzedQuery, CompositeInfo, EnumInfo};
8use scythe_core::errors::{ErrorCode, ScytheError};
9use scythe_core::parser::QueryCommand;
10
11use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
12use crate::singularize;
13
14/// Default embedded manifest TOML for rust-sqlx, used as fallback.
15const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/rust-sqlx.toml");
16const DEFAULT_MANIFEST_MARIADB: &str = include_str!("../../manifests/rust-sqlx.mariadb.toml");
17const DEFAULT_MANIFEST_REDSHIFT: &str = include_str!("../../manifests/rust-sqlx.redshift.toml");
18
19/// SqlxBackend generates Rust code targeting the sqlx crate.
20pub struct SqlxBackend {
21    manifest: BackendManifest,
22    engine: String,
23    /// When true, only emit struct/enum definitions (no query functions).
24    /// This avoids the `sqlx::query_as!()` macro which requires `DATABASE_URL` at compile time.
25    structs_only: bool,
26}
27
28impl SqlxBackend {
29    pub fn new(engine: &str) -> Result<Self, ScytheError> {
30        // Multi-DB backend — accept all engines, load PG manifest as default
31        // TODO: Load engine-specific manifests once they exist
32        match engine {
33            "postgresql" | "postgres" | "pg" | "mysql" | "mariadb" | "sqlite" | "sqlite3"
34            | "redshift" => {}
35            _ => {
36                return Err(ScytheError::new(
37                    ErrorCode::InternalError,
38                    format!("unsupported engine '{}' for rust-sqlx backend", engine),
39                ));
40            }
41        }
42        let manifest = match engine {
43            "mariadb" => super::load_or_default_manifest(
44                "backends/rust-sqlx/manifest.toml",
45                DEFAULT_MANIFEST_MARIADB,
46            )?,
47            "redshift" => super::load_or_default_manifest(
48                "backends/rust-sqlx/manifest.toml",
49                DEFAULT_MANIFEST_REDSHIFT,
50            )?,
51            _ => super::load_or_default_manifest(
52                "backends/rust-sqlx/manifest.toml",
53                DEFAULT_MANIFEST_TOML,
54            )?,
55        };
56        Ok(Self {
57            manifest,
58            engine: engine.to_string(),
59            structs_only: false,
60        })
61    }
62}
63
64impl SqlxBackend {
65    /// Return true if this engine uses inline ENUMs (not named custom types).
66    ///
67    /// MySQL, MariaDB, and SQLite represent ENUMs as plain strings at the wire
68    /// level. sqlx's `#[derive(sqlx::Type)]` generates `type_info()` returning
69    /// `MySqlTypeInfo::__enum()` (ColumnType::String + ENUM flag), but the
70    /// server sends `ColumnType::Enum`. The PartialEq check in MySqlTypeInfo
71    /// fails because the r#type fields differ, producing a runtime
72    /// "mismatched types" ColumnDecode error.
73    ///
74    /// For these engines, row struct fields must use `String` (or `Option<String>`)
75    /// instead of the generated Rust enum type.
76    fn uses_inline_enums(&self) -> bool {
77        matches!(
78            self.engine.as_str(),
79            "mysql" | "mariadb" | "sqlite" | "sqlite3"
80        )
81    }
82
83    /// Resolve the field type for a row struct column.
84    ///
85    /// For engines that use inline ENUMs, enum-typed columns are mapped to
86    /// `String` / `Option<String>` because sqlx cannot type-check them against
87    /// the generated Rust enum at runtime (see `uses_inline_enums`).
88    fn row_field_type<'a>(&self, col: &'a ResolvedColumn) -> &'a str {
89        if self.uses_inline_enums() && col.neutral_type.starts_with("enum::") {
90            if col.nullable {
91                "Option<String>"
92            } else {
93                "String"
94            }
95        } else {
96            &col.full_type
97        }
98    }
99
100    /// Return the sqlx pool type for the configured engine.
101    fn pool_type(&self) -> &str {
102        match self.engine.as_str() {
103            "mysql" | "mariadb" => "sqlx::MySqlPool",
104            "sqlite" | "sqlite3" => "sqlx::SqlitePool",
105            _ => "sqlx::PgPool",
106        }
107    }
108
109    /// Return the sqlx query-result type for the configured engine.
110    fn query_result_type(&self) -> &str {
111        match self.engine.as_str() {
112            "mysql" | "mariadb" => "sqlx::mysql::MySqlQueryResult",
113            "sqlite" | "sqlite3" => "sqlx::sqlite::SqliteQueryResult",
114            _ => "sqlx::postgres::PgQueryResult",
115        }
116    }
117}
118
119impl CodegenBackend for SqlxBackend {
120    fn name(&self) -> &str {
121        "rust-sqlx"
122    }
123
124    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
125        &self.manifest
126    }
127
128    fn supported_engines(&self) -> &[&str] {
129        &["postgresql", "mysql", "mariadb", "sqlite", "redshift"]
130    }
131
132    fn file_header(&self) -> String {
133        "// Auto-generated by scythe. Do not edit.\n#![allow(dead_code, unused_imports, clippy::needless_question_mark, clippy::redundant_closure)]"
134            .to_string()
135    }
136
137    fn apply_options(
138        &mut self,
139        options: &std::collections::HashMap<String, String>,
140    ) -> Result<(), ScytheError> {
141        if options.get("structs_only").is_some_and(|v| v == "true") {
142            self.structs_only = true;
143        }
144        Ok(())
145    }
146
147    fn generate_row_struct(
148        &self,
149        query_name: &str,
150        columns: &[ResolvedColumn],
151    ) -> Result<String, ScytheError> {
152        let struct_name = row_struct_name(query_name, &self.manifest.naming);
153        let mut out = String::new();
154
155        let _ = writeln!(out, "#[derive(Debug, Clone, sqlx::FromRow)]");
156        let _ = writeln!(out, "pub struct {} {{", struct_name);
157
158        for col in columns {
159            let field_type = self.row_field_type(col);
160            let _ = writeln!(out, "    pub {}: {},", col.field_name, field_type);
161        }
162
163        let _ = write!(out, "}}");
164        Ok(out)
165    }
166
167    fn generate_model_struct(
168        &self,
169        table_name: &str,
170        columns: &[ResolvedColumn],
171    ) -> Result<String, ScytheError> {
172        let singular = singularize(table_name);
173        let struct_name = to_pascal_case(&singular).into_owned();
174        let mut out = String::new();
175
176        let _ = writeln!(out, "#[derive(Debug, Clone, sqlx::FromRow)]");
177        let _ = writeln!(out, "pub struct {} {{", struct_name);
178
179        for col in columns {
180            let field_type = self.row_field_type(col);
181            let _ = writeln!(out, "    pub {}: {},", col.field_name, field_type);
182        }
183
184        let _ = write!(out, "}}");
185        Ok(out)
186    }
187
188    fn generate_query_fn(
189        &self,
190        analyzed: &AnalyzedQuery,
191        struct_name: &str,
192        _columns: &[ResolvedColumn],
193        params: &[ResolvedParam],
194    ) -> Result<String, ScytheError> {
195        // In structs_only mode, skip all function generation (avoids sqlx::query!() macros
196        // which require DATABASE_URL at compile time).
197        if self.structs_only {
198            return Ok(String::new());
199        }
200
201        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
202        let mut out = String::new();
203
204        // Deprecated annotation
205        if let Some(ref msg) = analyzed.deprecated {
206            let _ = writeln!(out, "#[deprecated(note = \"{}\")]", msg);
207        }
208
209        // Build parameter list
210        let pool_type = self.pool_type();
211        let mut param_parts: Vec<String> = vec![format!("pool: &{}", pool_type)];
212        for param in params {
213            param_parts.push(format!("{}: {}", param.field_name, param.borrowed_type));
214        }
215
216        // Clean SQL
217        let sql_raw = super::clean_sql_with_optional(
218            &analyzed.sql,
219            &analyzed.optional_params,
220            &analyzed.params,
221        );
222        let sql = rewrite_sql_for_enums(&sql_raw, &analyzed.columns, &self.manifest);
223
224        // Build bind params string
225        let bind_params: String = analyzed
226            .params
227            .iter()
228            .map(|p| {
229                let param_name = to_snake_case(&p.name);
230                if p.neutral_type.starts_with("enum::") {
231                    let enum_name = p.neutral_type.strip_prefix("enum::").unwrap();
232                    let rust_type = enum_type_name(enum_name, &self.manifest.naming);
233                    format!(", {} as &{}", param_name, rust_type)
234                } else {
235                    format!(", {}", param_name)
236                }
237            })
238            .collect();
239
240        // Handle :batch separately — generates a different function signature
241        if matches!(analyzed.command, QueryCommand::Batch) {
242            let batch_fn_name = format!("{}_batch", func_name);
243
244            // Generate params struct if >1 param
245            if params.len() > 1 {
246                let params_struct_name = format!("{}BatchParams", struct_name);
247                let _ = writeln!(out, "#[derive(Debug, Clone)]");
248                let _ = writeln!(out, "pub struct {} {{", params_struct_name);
249                for param in params {
250                    let _ = writeln!(out, "    pub {}: {},", param.field_name, param.full_type);
251                }
252                let _ = writeln!(out, "}}");
253                let _ = writeln!(out);
254
255                // Batch function takes &[ParamsStruct]
256                let _ = writeln!(
257                    out,
258                    "pub async fn {}(pool: &{}, items: &[{}]) -> Result<(), sqlx::Error> {{",
259                    batch_fn_name, pool_type, params_struct_name
260                );
261                let _ = writeln!(out, "    let mut tx = pool.begin().await?;");
262                let _ = writeln!(out, "    for item in items {{");
263
264                // Build bind params from struct fields
265                let struct_bind_params: String = params
266                    .iter()
267                    .map(|p| {
268                        if p.neutral_type.starts_with("enum::") {
269                            let enum_name = p.neutral_type.strip_prefix("enum::").unwrap();
270                            let rust_type = enum_type_name(enum_name, &self.manifest.naming);
271                            format!(", item.{} as &{}", p.field_name, rust_type)
272                        } else {
273                            format!(", item.{}", p.field_name)
274                        }
275                    })
276                    .collect();
277
278                let _ = writeln!(
279                    out,
280                    "        sqlx::query!(\"{}\"{})",
281                    sql, struct_bind_params
282                );
283                let _ = writeln!(out, "            .execute(&mut *tx)");
284                let _ = writeln!(out, "            .await?;");
285                let _ = writeln!(out, "    }}");
286                let _ = writeln!(out, "    tx.commit().await?;");
287                let _ = writeln!(out, "    Ok(())");
288            } else if params.len() == 1 {
289                // Single param — takes a slice of that type
290                let param = &params[0];
291                let _ = writeln!(
292                    out,
293                    "pub async fn {}(pool: &{}, items: &[{}]) -> Result<(), sqlx::Error> {{",
294                    batch_fn_name, pool_type, param.full_type
295                );
296                let _ = writeln!(out, "    let mut tx = pool.begin().await?;");
297                let _ = writeln!(out, "    for item in items {{");
298                let _ = writeln!(out, "        sqlx::query!(\"{}\", item)", sql);
299                let _ = writeln!(out, "            .execute(&mut *tx)");
300                let _ = writeln!(out, "            .await?;");
301                let _ = writeln!(out, "    }}");
302                let _ = writeln!(out, "    tx.commit().await?;");
303                let _ = writeln!(out, "    Ok(())");
304            } else {
305                // No params — just execute N times (unusual but valid)
306                let _ = writeln!(
307                    out,
308                    "pub async fn {}(pool: &{}, count: usize) -> Result<(), sqlx::Error> {{",
309                    batch_fn_name, pool_type
310                );
311                let _ = writeln!(out, "    let mut tx = pool.begin().await?;");
312                let _ = writeln!(out, "    for _ in 0..count {{");
313                let _ = writeln!(out, "        sqlx::query!(\"{}\")", sql);
314                let _ = writeln!(out, "            .execute(&mut *tx)");
315                let _ = writeln!(out, "            .await?;");
316                let _ = writeln!(out, "    }}");
317                let _ = writeln!(out, "    tx.commit().await?;");
318                let _ = writeln!(out, "    Ok(())");
319            }
320
321            let _ = write!(out, "}}");
322            return Ok(out);
323        }
324
325        // Return type for non-batch commands
326        let return_type = match &analyzed.command {
327            QueryCommand::One | QueryCommand::Opt => struct_name.to_string(),
328            QueryCommand::Many => format!("Vec<{}>", struct_name),
329            QueryCommand::Exec => "()".to_string(),
330            QueryCommand::ExecResult => self.query_result_type().to_string(),
331            QueryCommand::ExecRows => "u64".to_string(),
332            QueryCommand::Batch => unreachable!(),
333            QueryCommand::Grouped => {
334                return Err(ScytheError::new(
335                    ErrorCode::InternalError,
336                    "Grouped queries should be rewritten before codegen".to_string(),
337                ));
338            }
339        };
340
341        // Function signature
342        let _ = writeln!(
343            out,
344            "pub async fn {}({}) -> Result<{}, sqlx::Error> {{",
345            func_name,
346            param_parts.join(", "),
347            return_type
348        );
349
350        // Query body
351        let has_row_struct = matches!(analyzed.command, QueryCommand::One | QueryCommand::Many);
352
353        let is_exec_rows = matches!(analyzed.command, QueryCommand::ExecRows);
354
355        if is_exec_rows {
356            if has_row_struct && !analyzed.columns.is_empty() {
357                let _ = write!(
358                    out,
359                    "    let result = sqlx::query_as!({}, \"{}\"{})",
360                    struct_name, sql, bind_params
361                );
362            } else {
363                let _ = write!(
364                    out,
365                    "    let result = sqlx::query!(\"{}\"{})",
366                    sql, bind_params
367                );
368            }
369        } else if has_row_struct && !analyzed.columns.is_empty() {
370            let _ = write!(
371                out,
372                "    sqlx::query_as!({}, \"{}\"{})",
373                struct_name, sql, bind_params
374            );
375        } else {
376            let _ = write!(out, "    sqlx::query!(\"{}\"{})", sql, bind_params);
377        }
378
379        let _ = writeln!(out);
380
381        // Fetch method
382        let fetch_method = match &analyzed.command {
383            QueryCommand::One | QueryCommand::Opt => ".fetch_one(pool)",
384            QueryCommand::Many => ".fetch_all(pool)",
385            QueryCommand::Exec => ".execute(pool)",
386            QueryCommand::ExecResult => ".execute(pool)",
387            QueryCommand::ExecRows => ".execute(pool)",
388            QueryCommand::Batch => unreachable!(),
389            QueryCommand::Grouped => {
390                return Err(ScytheError::new(
391                    ErrorCode::InternalError,
392                    "Grouped queries should be rewritten before codegen".to_string(),
393                ));
394            }
395        };
396
397        let _ = write!(out, "        {}", fetch_method);
398        let _ = writeln!(out);
399
400        // Post-processing for exec variants
401        match &analyzed.command {
402            QueryCommand::Exec => {
403                let _ = writeln!(out, "        .await?;");
404                let _ = writeln!(out, "    Ok(())");
405            }
406            QueryCommand::ExecRows => {
407                let _ = writeln!(out, "        .await?;");
408                let _ = writeln!(out, "    Ok(result.rows_affected())");
409            }
410            _ => {
411                let _ = writeln!(out, "        .await");
412            }
413        }
414
415        let _ = write!(out, "}}");
416        Ok(out)
417    }
418
419    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
420        let mut out = String::with_capacity(256);
421        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
422
423        let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq, sqlx::Type)]");
424        // MySQL/MariaDB/SQLite use inline ENUMs — sqlx decodes them by value matching only.
425        // The `type_name` annotation is PostgreSQL-specific (for named custom types) and
426        // causes a "mismatched types" error on MySQL at runtime.
427        match self.engine.as_str() {
428            "mysql" | "mariadb" | "sqlite" | "sqlite3" => {
429                let _ = writeln!(out, "#[sqlx(rename_all = \"snake_case\")]");
430            }
431            _ => {
432                let _ = writeln!(
433                    out,
434                    "#[sqlx(type_name = \"{}\", rename_all = \"snake_case\")]",
435                    enum_info.sql_name
436                );
437            }
438        }
439        let _ = writeln!(out, "pub enum {type_name} {{");
440
441        for value in &enum_info.values {
442            let variant = enum_variant_name(value, &self.manifest.naming);
443            let _ = writeln!(out, "    {variant},");
444        }
445
446        let _ = write!(out, "}}");
447        Ok(out)
448    }
449
450    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
451        use scythe_backend::types::resolve_type;
452
453        let struct_name = to_pascal_case(&composite.sql_name).into_owned();
454        let mut out = String::new();
455
456        let _ = writeln!(out, "#[derive(Debug, Clone, sqlx::Type)]");
457        let _ = writeln!(out, "#[sqlx(type_name = \"{}\")]", composite.sql_name);
458        let _ = writeln!(out, "pub struct {} {{", struct_name);
459        for field in &composite.fields {
460            let rust_type = resolve_type(&field.neutral_type, &self.manifest, false)
461                .map(|t| t.into_owned())
462                .map_err(|e| {
463                    ScytheError::new(
464                        ErrorCode::InternalError,
465                        format!("composite field type error: {}", e),
466                    )
467                })?;
468            let _ = writeln!(
469                out,
470                "    pub {}: {},",
471                to_snake_case(&field.name),
472                rust_type
473            );
474        }
475        let _ = write!(out, "}}");
476        Ok(out)
477    }
478}
479
480// ---------------------------------------------------------------------------
// Internal helpers: SQL rewriting that adds sqlx enum type overrides
482// ---------------------------------------------------------------------------
483
484/// Rewrite SQL to add enum type annotations for sqlx.
485fn rewrite_sql_for_enums(
486    sql: &str,
487    columns: &[AnalyzedColumn],
488    manifest: &BackendManifest,
489) -> String {
490    let enum_cols: Vec<(&str, String)> = columns
491        .iter()
492        .filter_map(|col| {
493            if let Some(enum_name) = col.neutral_type.strip_prefix("enum::") {
494                let rust_type = enum_type_name(enum_name, &manifest.naming);
495                let annotation = if col.nullable {
496                    format!("Option<{}>", rust_type)
497                } else {
498                    rust_type
499                };
500                Some((col.name.as_str(), annotation))
501            } else {
502                None
503            }
504        })
505        .collect();
506
507    if enum_cols.is_empty() {
508        return sql.to_string();
509    }
510
511    let mut result = sql.to_string();
512    for (col_name, annotation) in &enum_cols {
513        let alias = format!("{} AS \\\"{}: {}\\\"", col_name, col_name, annotation);
514        if let Some(from_pos) = result.to_uppercase().find(" FROM ") {
515            let select_part = &result[..from_pos];
516            let rest = &result[from_pos..];
517            let new_select = replace_column_in_select(select_part, col_name, &alias);
518            result = format!("{}{}", new_select, rest);
519        }
520    }
521    result
522}
523
/// Replace one bare occurrence of `col_name` in `select` with `replacement`.
///
/// An occurrence counts only when it is preceded by `", "` or `" "` and followed
/// by the end of the string, a comma, or whitespace. This skips prefixes of
/// longer identifiers (e.g. `status` inside `status_code`). Matches after `", "`
/// are tried first; within each form the right-most valid match wins. Returns
/// `select` unchanged when nothing matches.
fn replace_column_in_select(select: &str, col_name: &str, replacement: &str) -> String {
    let patterns = [format!(", {}", col_name), format!(" {}", col_name)];
    for pattern in &patterns {
        // Walk matches right-to-left and take the first valid one. Checking only
        // the right-most textual match (a plain `rfind`) missed a valid earlier
        // occurrence whenever a longer identifier with the same prefix came later.
        let hit = select
            .rmatch_indices(pattern.as_str())
            .map(|(pos, _)| pos + pattern.len())
            .find(|&name_end| match select[name_end..].chars().next() {
                None => true,
                Some(c) => c == ',' || c.is_whitespace(),
            });
        if let Some(name_end) = hit {
            let name_start = name_end - col_name.len();
            return format!(
                "{}{}{}",
                &select[..name_start],
                replacement,
                &select[name_end..]
            );
        }
    }
    select.to_string()
}