Skip to main content

scythe_codegen/backends/
mod.rs

1pub mod csharp_microsoft_sqlite;
2pub mod csharp_mysqlconnector;
3pub mod csharp_npgsql;
4pub mod elixir_ecto;
5pub mod elixir_exqlite;
6pub mod elixir_myxql;
7pub mod elixir_postgrex;
8pub mod go_database_sql;
9pub mod go_pgx;
10pub mod java_jdbc;
11pub mod java_r2dbc;
12pub mod kotlin_exposed;
13pub mod kotlin_jdbc;
14pub mod kotlin_r2dbc;
15pub mod php_amphp;
16pub mod php_pdo;
17pub mod python_aiomysql;
18pub mod python_aiosqlite;
19pub mod python_asyncpg;
20pub mod python_common;
21pub mod python_duckdb;
22pub mod python_psycopg3;
23pub mod ruby_mysql2;
24pub mod ruby_pg;
25pub(crate) mod ruby_rbs;
26pub mod ruby_sqlite3;
27pub mod ruby_trilogy;
28pub mod sqlx;
29pub mod tokio_postgres;
30pub mod typescript_better_sqlite3;
31pub mod typescript_common;
32pub mod typescript_duckdb;
33pub mod typescript_mysql2;
34pub mod typescript_pg;
35pub mod typescript_postgres;
36
37use scythe_core::analyzer::AnalyzedParam;
38use scythe_core::errors::{ErrorCode, ScytheError};
39
40use crate::backend_trait::CodegenBackend;
41
/// Normalize a SQL query for embedding in generated code.
///
/// Lines whose first non-whitespace characters are `--` are dropped, the
/// surviving lines are joined back together with `\n`, and the result is
/// trimmed of surrounding whitespace and of any run of trailing `;`.
/// A `--` comment that follows SQL on the same line is left in place.
pub(crate) fn clean_sql(sql: &str) -> String {
    let mut kept: Vec<&str> = Vec::new();
    for line in sql.lines() {
        let is_comment_line = line.trim_start().starts_with("--");
        if !is_comment_line {
            kept.push(line);
        }
    }

    let joined = kept.join("\n");
    // Trim twice: once to expose the trailing semicolons, once more to drop
    // whitespace that sat between the statement and those semicolons.
    let without_semicolons = joined.trim().trim_end_matches(';');
    without_semicolons.trim().to_string()
}
54
/// Single-line variant of the SQL cleanup, for targets that embed SQL inline.
///
/// Lines whose first non-whitespace characters are `--` are dropped and the
/// remaining lines are joined with a single space instead of a newline. The
/// result is then trimmed of surrounding whitespace and trailing `;`.
pub(crate) fn clean_sql_oneline(sql: &str) -> String {
    let mut flattened = String::with_capacity(sql.len());
    let mut needs_separator = false;

    for line in sql.lines() {
        if line.trim_start().starts_with("--") {
            continue;
        }
        // Separate every kept line (including empty ones) from the previous
        // kept line with exactly one space.
        if needs_separator {
            flattened.push(' ');
        }
        flattened.push_str(line);
        needs_separator = true;
    }

    let without_semicolons = flattened.trim().trim_end_matches(';');
    without_semicolons.trim().to_string()
}
66
67/// Rewrite SQL for optional parameters.
68///
69/// For each optional param, finds `column = $N` (or `column <> $N`, `column != $N`)
70/// and rewrites to `($N IS NULL OR column = $N)`. This allows callers to pass NULL
71/// to skip a filter condition at runtime.
72///
73/// This operates on the raw SQL before any backend-specific placeholder rewriting.
74pub(crate) fn rewrite_optional_params(
75    sql: &str,
76    optional_params: &[String],
77    params: &[AnalyzedParam],
78) -> String {
79    if optional_params.is_empty() {
80        return sql.to_string();
81    }
82
83    let mut result = sql.to_string();
84
85    for opt_name in optional_params {
86        let Some(param) = params.iter().find(|p| p.name == *opt_name) else {
87            continue;
88        };
89        let placeholder = format!("${}", param.position);
90
91        // Try each comparison operator
92        for op in &[
93            ">=", "<=", "<>", "!=", ">", "<", "=", "ILIKE", "ilike", "LIKE", "like",
94        ] {
95            result = rewrite_comparison(&result, &placeholder, op);
96        }
97    }
98
99    result
100}
101
102/// Rewrite a single `column <op> $N` pattern to `($N IS NULL OR column <op> $N)`.
103/// Handles both `column <op> $N` and `$N <op> column` orderings.
104fn rewrite_comparison(sql: &str, placeholder: &str, op: &str) -> String {
105    let mut result = String::with_capacity(sql.len() + 32);
106    let chars: Vec<char> = sql.chars().collect();
107    let len = chars.len();
108    let mut i = 0;
109
110    while i < len {
111        // Try to match `identifier <op> $N` at this position
112        if let Some((start, col, end)) = try_match_col_op_ph(&chars, i, op, placeholder) {
113            // Write everything before the match start
114            if start > i {
115                // This shouldn't happen since we iterate char by char
116            }
117            result.push_str(&format!(
118                "({placeholder} IS NULL OR {col} {op} {placeholder})"
119            ));
120            i = end;
121            continue;
122        }
123
124        // Try to match `$N <op> identifier` at this position
125        if let Some((end, col)) = try_match_ph_op_col(&chars, i, op, placeholder) {
126            result.push_str(&format!(
127                "({placeholder} IS NULL OR {col} {op} {placeholder})"
128            ));
129            i = end;
130            continue;
131        }
132
133        result.push(chars[i]);
134        i += 1;
135    }
136
137    result
138}
139
140/// Try to match `identifier <ws>* <op> <ws>* placeholder` starting at position `i`.
141/// Returns `(match_start, column_name, match_end)` if found.
142fn try_match_col_op_ph(
143    chars: &[char],
144    i: usize,
145    op: &str,
146    placeholder: &str,
147) -> Option<(usize, String, usize)> {
148    // Must start with an identifier character (word char)
149    if !is_ident_char(chars[i]) {
150        return None;
151    }
152    // Must not be preceded by another ident char (whole-word boundary)
153    if i > 0 && is_ident_char(chars[i - 1]) {
154        return None;
155    }
156
157    // Read the identifier
158    let ident_start = i;
159    let mut j = i;
160    while j < chars.len() && is_ident_char(chars[j]) {
161        j += 1;
162    }
163    let ident: String = chars[ident_start..j].iter().collect();
164
165    // Skip whitespace
166    while j < chars.len() && chars[j].is_whitespace() {
167        j += 1;
168    }
169
170    // Match operator
171    let op_chars: Vec<char> = op.chars().collect();
172    if j + op_chars.len() > chars.len() {
173        return None;
174    }
175    for (k, oc) in op_chars.iter().enumerate() {
176        if chars[j + k] != *oc {
177            return None;
178        }
179    }
180    j += op_chars.len();
181
182    // Skip whitespace
183    while j < chars.len() && chars[j].is_whitespace() {
184        j += 1;
185    }
186
187    // Match placeholder
188    let ph_chars: Vec<char> = placeholder.chars().collect();
189    if j + ph_chars.len() > chars.len() {
190        return None;
191    }
192    for (k, pc) in ph_chars.iter().enumerate() {
193        if chars[j + k] != *pc {
194            return None;
195        }
196    }
197    j += ph_chars.len();
198
199    // Ensure placeholder is not followed by a digit (e.g., $1 vs $10)
200    if j < chars.len() && chars[j].is_ascii_digit() {
201        return None;
202    }
203
204    Some((i, ident, j))
205}
206
207/// Try to match `placeholder <ws>* <op> <ws>* identifier` starting at position `i`.
208/// Returns `(match_end, column_name)` if found.
209fn try_match_ph_op_col(
210    chars: &[char],
211    i: usize,
212    op: &str,
213    placeholder: &str,
214) -> Option<(usize, String)> {
215    let ph_chars: Vec<char> = placeholder.chars().collect();
216    if i + ph_chars.len() > chars.len() {
217        return None;
218    }
219
220    // Must not be preceded by $ or digit (boundary check)
221    if i > 0 && (chars[i - 1] == '$' || chars[i - 1].is_ascii_digit()) {
222        return None;
223    }
224
225    // Match placeholder
226    for (k, pc) in ph_chars.iter().enumerate() {
227        if chars[i + k] != *pc {
228            return None;
229        }
230    }
231    let mut j = i + ph_chars.len();
232
233    // Ensure placeholder is not followed by a digit
234    if j < chars.len() && chars[j].is_ascii_digit() {
235        return None;
236    }
237
238    // Skip whitespace
239    while j < chars.len() && chars[j].is_whitespace() {
240        j += 1;
241    }
242
243    // Match operator
244    let op_chars: Vec<char> = op.chars().collect();
245    if j + op_chars.len() > chars.len() {
246        return None;
247    }
248    for (k, oc) in op_chars.iter().enumerate() {
249        if chars[j + k] != *oc {
250            return None;
251        }
252    }
253    j += op_chars.len();
254
255    // Skip whitespace
256    while j < chars.len() && chars[j].is_whitespace() {
257        j += 1;
258    }
259
260    // Read the identifier
261    if j >= chars.len() || !is_ident_char(chars[j]) {
262        return None;
263    }
264    let ident_start = j;
265    while j < chars.len() && is_ident_char(chars[j]) {
266        j += 1;
267    }
268    let ident: String = chars[ident_start..j].iter().collect();
269
270    // Avoid matching "NULL" (from already-rewritten text)
271    if ident == "NULL" {
272        return None;
273    }
274
275    Some((j, ident))
276}
277
278/// Clean SQL and apply optional parameter rewriting.
279pub(crate) fn clean_sql_with_optional(
280    sql: &str,
281    optional_params: &[String],
282    params: &[AnalyzedParam],
283) -> String {
284    let cleaned = clean_sql(sql);
285    rewrite_optional_params(&cleaned, optional_params, params)
286}
287
288/// Clean SQL (oneline) and apply optional parameter rewriting.
289pub(crate) fn clean_sql_oneline_with_optional(
290    sql: &str,
291    optional_params: &[String],
292    params: &[AnalyzedParam],
293) -> String {
294    let cleaned = clean_sql_oneline(sql);
295    rewrite_optional_params(&cleaned, optional_params, params)
296}
297
/// Whether `c` may appear in a (possibly qualified) SQL identifier as far as
/// the comparison matcher is concerned: any alphanumeric character, `_`, or
/// the `.` used in qualified names such as `u.status`.
fn is_ident_char(c: char) -> bool {
    matches!(c, '_' | '.') || c.is_alphanumeric()
}
301
302/// Get a backend by name and database engine.
303///
304/// The `engine` parameter (e.g., "postgresql", "mysql", "sqlite") determines
305/// which manifest is loaded for type mappings. PG-only backends reject non-PG engines.
306pub fn get_backend(name: &str, engine: &str) -> Result<Box<dyn CodegenBackend>, ScytheError> {
307    // Normalize engine aliases (e.g., "cockroachdb" -> "postgresql") before
308    // passing to backend constructors so each backend only needs to match
309    // canonical engine names.
310    let canonical_engine = normalize_engine(engine);
311    let backend: Box<dyn CodegenBackend> = match name {
312        "rust-sqlx" | "sqlx" | "rust" => Box::new(sqlx::SqlxBackend::new(canonical_engine)?),
313        "rust-tokio-postgres" | "tokio-postgres" => {
314            Box::new(tokio_postgres::TokioPostgresBackend::new(canonical_engine)?)
315        }
316        "python-psycopg3" | "python" => Box::new(python_psycopg3::PythonPsycopg3Backend::new(
317            canonical_engine,
318        )?),
319        "python-asyncpg" => Box::new(python_asyncpg::PythonAsyncpgBackend::new(canonical_engine)?),
320        "python-aiomysql" => Box::new(python_aiomysql::PythonAiomysqlBackend::new(
321            canonical_engine,
322        )?),
323        "python-aiosqlite" => Box::new(python_aiosqlite::PythonAiosqliteBackend::new(
324            canonical_engine,
325        )?),
326        "python-duckdb" => Box::new(python_duckdb::PythonDuckdbBackend::new(canonical_engine)?),
327        "typescript-postgres" | "ts" | "typescript" => Box::new(
328            typescript_postgres::TypescriptPostgresBackend::new(canonical_engine)?,
329        ),
330        "typescript-pg" => Box::new(typescript_pg::TypescriptPgBackend::new(canonical_engine)?),
331        "typescript-mysql2" => Box::new(typescript_mysql2::TypescriptMysql2Backend::new(
332            canonical_engine,
333        )?),
334        "typescript-better-sqlite3" => Box::new(
335            typescript_better_sqlite3::TypescriptBetterSqlite3Backend::new(canonical_engine)?,
336        ),
337        "typescript-duckdb" => Box::new(typescript_duckdb::TypescriptDuckdbBackend::new(
338            canonical_engine,
339        )?),
340        "go-database-sql" => Box::new(go_database_sql::GoDatabaseSqlBackend::new(
341            canonical_engine,
342        )?),
343        "go-pgx" | "go" => Box::new(go_pgx::GoPgxBackend::new(canonical_engine)?),
344        "java-jdbc" | "java" => Box::new(java_jdbc::JavaJdbcBackend::new(canonical_engine)?),
345        "java-r2dbc" | "r2dbc-java" => {
346            Box::new(java_r2dbc::JavaR2dbcBackend::new(canonical_engine)?)
347        }
348        "kotlin-exposed" | "exposed" => {
349            Box::new(kotlin_exposed::KotlinExposedBackend::new(canonical_engine)?)
350        }
351        "kotlin-jdbc" | "kotlin" | "kt" => {
352            Box::new(kotlin_jdbc::KotlinJdbcBackend::new(canonical_engine)?)
353        }
354        "kotlin-r2dbc" | "r2dbc-kotlin" => {
355            Box::new(kotlin_r2dbc::KotlinR2dbcBackend::new(canonical_engine)?)
356        }
357        "csharp-npgsql" | "csharp" | "c#" | "dotnet" => {
358            Box::new(csharp_npgsql::CsharpNpgsqlBackend::new(canonical_engine)?)
359        }
360        "csharp-mysqlconnector" => Box::new(
361            csharp_mysqlconnector::CsharpMysqlConnectorBackend::new(canonical_engine)?,
362        ),
363        "csharp-microsoft-sqlite" => Box::new(
364            csharp_microsoft_sqlite::CsharpMicrosoftSqliteBackend::new(canonical_engine)?,
365        ),
366        "elixir-postgrex" | "elixir" | "ex" => Box::new(
367            elixir_postgrex::ElixirPostgrexBackend::new(canonical_engine)?,
368        ),
369        "elixir-ecto" | "ecto" => Box::new(elixir_ecto::ElixirEctoBackend::new(canonical_engine)?),
370        "elixir-myxql" => Box::new(elixir_myxql::ElixirMyxqlBackend::new(canonical_engine)?),
371        "elixir-exqlite" => Box::new(elixir_exqlite::ElixirExqliteBackend::new(canonical_engine)?),
372        "ruby-pg" | "ruby" | "rb" => Box::new(ruby_pg::RubyPgBackend::new(canonical_engine)?),
373        "ruby-mysql2" => Box::new(ruby_mysql2::RubyMysql2Backend::new(canonical_engine)?),
374        "ruby-sqlite3" => Box::new(ruby_sqlite3::RubySqlite3Backend::new(canonical_engine)?),
375        "ruby-trilogy" | "trilogy" => {
376            Box::new(ruby_trilogy::RubyTrilogyBackend::new(canonical_engine)?)
377        }
378        "php-pdo" | "php" => Box::new(php_pdo::PhpPdoBackend::new(canonical_engine)?),
379        "php-amphp" | "amphp" => Box::new(php_amphp::PhpAmphpBackend::new(canonical_engine)?),
380        _ => {
381            return Err(ScytheError::new(
382                ErrorCode::InternalError,
383                format!("unknown backend: {}", name),
384            ));
385        }
386    };
387
388    // Validate engine is supported by this backend
389    if !backend
390        .supported_engines()
391        .iter()
392        .any(|e| normalize_engine(e) == canonical_engine)
393    {
394        return Err(ScytheError::new(
395            ErrorCode::InternalError,
396            format!(
397                "backend '{}' does not support engine '{}'. Supported: {:?}",
398                name,
399                engine,
400                backend.supported_engines()
401            ),
402        ));
403    }
404
405    Ok(backend)
406}
407
/// Map a database engine name or alias to its canonical name.
///
/// PostgreSQL-compatible aliases (including CockroachDB) become
/// "postgresql", "mariadb" becomes "mysql", and "sqlite3" becomes "sqlite".
/// Matching is exact and case-sensitive. Any other name — including
/// "duckdb", which is already canonical — is returned unchanged.
fn normalize_engine(engine: &str) -> &str {
    if matches!(
        engine,
        "postgresql" | "postgres" | "pg" | "cockroachdb" | "crdb"
    ) {
        "postgresql"
    } else if matches!(engine, "mysql" | "mariadb") {
        "mysql"
    } else if matches!(engine, "sqlite" | "sqlite3") {
        "sqlite"
    } else {
        engine
    }
}
418
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a nullable, string-typed parameter bound to placeholder `$position`.
    fn param(name: &str, position: i64) -> AnalyzedParam {
        AnalyzedParam {
            position,
            nullable: true,
            neutral_type: String::from("string"),
            name: String::from(name),
        }
    }

    /// Run `rewrite_optional_params` with the optional names given as `&str`s.
    fn rewrite(sql: &str, optional: &[&str], params: &[AnalyzedParam]) -> String {
        let optional: Vec<String> = optional.iter().map(|name| name.to_string()).collect();
        rewrite_optional_params(sql, &optional, params)
    }

    // --- engine normalization and backend lookup ---------------------------

    #[test]
    fn test_normalize_engine_cockroachdb() {
        for alias in ["cockroachdb", "crdb"] {
            assert_eq!(normalize_engine(alias), "postgresql");
        }
    }

    #[test]
    fn test_get_backend_cockroachdb_with_pg_backends() {
        // Every PostgreSQL-compatible backend must also accept CockroachDB.
        let pg_backends: &[&str] = &[
            "rust-sqlx",
            "rust-tokio-postgres",
            "python-psycopg3",
            "python-asyncpg",
            "typescript-postgres",
            "typescript-pg",
            "go-pgx",
            "ruby-pg",
            "elixir-postgrex",
            "csharp-npgsql",
            "php-pdo",
            "php-amphp",
        ];
        for backend_name in pg_backends {
            let outcome = get_backend(backend_name, "cockroachdb");
            assert!(
                outcome.is_ok(),
                "backend '{}' should accept cockroachdb engine, got: {:?}",
                backend_name,
                outcome.err()
            );
        }
    }

    #[test]
    fn test_get_backend_crdb_alias() {
        assert!(
            get_backend("rust-sqlx", "crdb").is_ok(),
            "rust-sqlx should accept 'crdb' engine alias"
        );
    }

    #[test]
    fn test_normalize_engine_duckdb() {
        assert_eq!(normalize_engine("duckdb"), "duckdb");
    }

    #[test]
    fn test_get_backend_duckdb_with_compatible_backends() {
        let duckdb_backends: &[&str] = &[
            "python-duckdb",
            "typescript-duckdb",
            "go-database-sql",
            "java-jdbc",
            "kotlin-jdbc",
        ];
        for backend_name in duckdb_backends {
            let outcome = get_backend(backend_name, "duckdb");
            assert!(
                outcome.is_ok(),
                "backend '{}' should accept duckdb engine, got: {:?}",
                backend_name,
                outcome.err()
            );
        }
    }

    #[test]
    fn test_get_backend_duckdb_rejected_by_pg_only() {
        assert!(
            get_backend("rust-sqlx", "duckdb").is_err(),
            "rust-sqlx should reject duckdb engine"
        );
    }

    // --- optional parameter rewriting --------------------------------------

    #[test]
    fn test_rewrite_simple_equality() {
        let rewritten = rewrite(
            "SELECT * FROM users WHERE status = $1",
            &["status"],
            &[param("status", 1)],
        );
        assert_eq!(
            rewritten,
            "SELECT * FROM users WHERE ($1 IS NULL OR status = $1)"
        );
    }

    #[test]
    fn test_rewrite_qualified_column() {
        let rewritten = rewrite(
            "SELECT * FROM users u WHERE u.status = $1",
            &["status"],
            &[param("status", 1)],
        );
        assert_eq!(
            rewritten,
            "SELECT * FROM users u WHERE ($1 IS NULL OR u.status = $1)"
        );
    }

    #[test]
    fn test_rewrite_multiple_optional() {
        let rewritten = rewrite(
            "SELECT * FROM users WHERE status = $1 AND name = $2",
            &["status", "name"],
            &[param("status", 1), param("name", 2)],
        );
        assert_eq!(
            rewritten,
            "SELECT * FROM users WHERE ($1 IS NULL OR status = $1) AND ($2 IS NULL OR name = $2)"
        );
    }

    #[test]
    fn test_rewrite_mixed_optional_required() {
        // Only `status` is optional; the `id` filter must stay untouched.
        let rewritten = rewrite(
            "SELECT * FROM users WHERE id = $1 AND status = $2",
            &["status"],
            &[param("id", 1), param("status", 2)],
        );
        assert_eq!(
            rewritten,
            "SELECT * FROM users WHERE id = $1 AND ($2 IS NULL OR status = $2)"
        );
    }

    #[test]
    fn test_rewrite_like_operator() {
        let rewritten = rewrite(
            "SELECT * FROM users WHERE name LIKE $1",
            &["name"],
            &[param("name", 1)],
        );
        assert_eq!(
            rewritten,
            "SELECT * FROM users WHERE ($1 IS NULL OR name LIKE $1)"
        );
    }

    #[test]
    fn test_rewrite_ilike_operator() {
        let rewritten = rewrite(
            "SELECT * FROM users WHERE name ILIKE $1",
            &["name"],
            &[param("name", 1)],
        );
        assert_eq!(
            rewritten,
            "SELECT * FROM users WHERE ($1 IS NULL OR name ILIKE $1)"
        );
    }

    #[test]
    fn test_rewrite_comparison_operators() {
        let rewritten = rewrite(
            "SELECT * FROM users WHERE age >= $1",
            &["age"],
            &[param("age", 1)],
        );
        assert_eq!(
            rewritten,
            "SELECT * FROM users WHERE ($1 IS NULL OR age >= $1)"
        );
    }

    #[test]
    fn test_rewrite_less_than() {
        let rewritten = rewrite(
            "SELECT * FROM users WHERE age < $1",
            &["age"],
            &[param("age", 1)],
        );
        assert_eq!(
            rewritten,
            "SELECT * FROM users WHERE ($1 IS NULL OR age < $1)"
        );
    }

    #[test]
    fn test_no_rewrite_without_optional() {
        let original = "SELECT * FROM users WHERE status = $1";
        let rewritten = rewrite(original, &[], &[param("status", 1)]);
        assert_eq!(rewritten, original);
    }

    #[test]
    fn test_rewrite_not_equal() {
        let rewritten = rewrite(
            "SELECT * FROM users WHERE status <> $1",
            &["status"],
            &[param("status", 1)],
        );
        assert_eq!(
            rewritten,
            "SELECT * FROM users WHERE ($1 IS NULL OR status <> $1)"
        );
    }

    #[test]
    fn test_rewrite_does_not_match_similar_placeholder() {
        // The parameter is bound to `$1`, which must not match the front of
        // `$10`; since `$1` itself never appears, the SQL is left as-is.
        let original = "SELECT * FROM users WHERE status = $10";
        let rewritten = rewrite(original, &["status"], &[param("status", 1)]);
        assert_eq!(rewritten, original);
    }
}
620}