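//! scythe_codegen/backends/mod.rs — backend registry (`get_backend`) and the
//! shared SQL-text helpers used by the individual code generation backends.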

1pub(crate) mod csharp_microsoft_sqlite;
2pub(crate) mod csharp_mysqlconnector;
3pub(crate) mod csharp_npgsql;
4pub(crate) mod csharp_oracle;
5pub(crate) mod csharp_snowflake;
6pub(crate) mod csharp_sqlclient;
7pub(crate) mod elixir_ecto;
8pub(crate) mod elixir_exqlite;
9pub(crate) mod elixir_jamdb;
10pub(crate) mod elixir_myxql;
11pub(crate) mod elixir_postgrex;
12pub(crate) mod elixir_tds;
13pub(crate) mod go_database_sql;
14pub(crate) mod go_godror;
15pub(crate) mod go_gosnowflake;
16pub(crate) mod go_pgx;
17pub(crate) mod java_jdbc;
18pub(crate) mod java_r2dbc;
19pub(crate) mod kotlin_exposed;
20pub(crate) mod kotlin_jdbc;
21pub(crate) mod kotlin_r2dbc;
22pub(crate) mod php_amphp;
23pub(crate) mod php_pdo;
24pub(crate) mod python_aiomysql;
25pub(crate) mod python_aiosqlite;
26pub(crate) mod python_asyncpg;
27pub(crate) mod python_common;
28pub(crate) mod python_duckdb;
29pub(crate) mod python_oracledb;
30pub(crate) mod python_psycopg3;
31pub(crate) mod python_pyodbc;
32pub(crate) mod python_snowflake;
33pub(crate) mod ruby_mysql2;
34pub(crate) mod ruby_oci8;
35pub(crate) mod ruby_pg;
36pub(crate) mod ruby_rbs;
37pub(crate) mod ruby_sqlite3;
38pub(crate) mod ruby_tiny_tds;
39pub(crate) mod ruby_trilogy;
40pub(crate) mod rust_sibyl;
41pub(crate) mod rust_tiberius;
42pub(crate) mod sqlx;
43pub(crate) mod tokio_postgres;
44pub(crate) mod typescript_better_sqlite3;
45pub(crate) mod typescript_common;
46pub(crate) mod typescript_duckdb;
47pub(crate) mod typescript_mssql;
48pub(crate) mod typescript_mysql2;
49pub(crate) mod typescript_oracledb;
50pub(crate) mod typescript_pg;
51pub(crate) mod typescript_postgres;
52pub(crate) mod typescript_snowflake;
53
54use scythe_backend::manifest::BackendManifest;
55use scythe_core::analyzer::AnalyzedParam;
56use scythe_core::errors::{ErrorCode, ScytheError};
57
58use crate::backend_trait::CodegenBackend;
59
60/// Load a backend manifest, preferring a user-provided file at `override_path`
61/// and falling back to the embedded `default_toml` string.
62pub(crate) fn load_or_default_manifest(
63    override_path: &str,
64    default_toml: &str,
65) -> Result<BackendManifest, ScytheError> {
66    let path = std::path::Path::new(override_path);
67    if path.exists() {
68        scythe_backend::manifest::load_manifest(path)
69            .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))
70    } else {
71        toml::from_str(default_toml)
72            .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))
73    }
74}
75
/// Drop whole-line `--` comments, trailing semicolons and surrounding
/// whitespace. Newlines between the remaining lines are preserved, and
/// interior whitespace is left untouched.
///
/// `--` comments that start partway through a line are kept: the newline
/// after them still terminates the comment, so the SQL stays valid.
pub(crate) fn clean_sql(sql: &str) -> String {
    finish_clean(&drop_comment_lines(sql))
}

/// Like clean_sql but joins lines with spaces (for languages that embed SQL inline).
///
/// Flattening removes the newlines that terminate `--` comments, so any `--`
/// comment that starts mid-line (outside quotes and `/* */` comments) is
/// removed before joining. Otherwise it would swallow every line after it.
pub(crate) fn clean_sql_oneline(sql: &str) -> String {
    let stripped = strip_inline_line_comments(&drop_comment_lines(sql));
    finish_clean(&stripped.replace('\n', " "))
}

/// Join the lines of `sql` with `\n`, omitting lines that are entirely a `--` comment.
fn drop_comment_lines(sql: &str) -> String {
    sql.lines()
        .filter(|line| !line.trim_start().starts_with("--"))
        .collect::<Vec<_>>()
        .join("\n")
}

/// Trim surrounding whitespace plus the whole trailing run of semicolons,
/// including semicolons separated by whitespace (e.g. `SELECT 1; ;`).
fn finish_clean(sql: &str) -> String {
    sql.trim_end_matches(|c: char| c == ';' || c.is_whitespace())
        .trim_start()
        .to_string()
}

/// Remove `--` comments that begin partway through a line, together with the
/// spaces/tabs just before them. The terminating newline is kept.
///
/// Text inside single-quoted strings, double-quoted or backtick identifiers,
/// and `/* ... */` block comments is copied verbatim. A doubled quote (`''`)
/// works naturally with this toggling. Backslash escapes and PostgreSQL
/// dollar-quoted strings are not recognized, which matches the quoting model
/// of `rewrite_pg_placeholders`.
fn strip_inline_line_comments(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();
    let mut quote: Option<char> = None;
    let mut in_block_comment = false;

    while let Some(c) = chars.next() {
        if let Some(q) = quote {
            out.push(c);
            if c == q {
                quote = None;
            }
        } else if in_block_comment {
            out.push(c);
            if c == '*' && chars.next_if_eq(&'/').is_some() {
                out.push('/');
                in_block_comment = false;
            }
        } else if c == '-' && chars.next_if_eq(&'-').is_some() {
            // Drop the padding before the comment, then skip to end of line.
            let kept = out.trim_end_matches(|w: char| w == ' ' || w == '\t').len();
            out.truncate(kept);
            while chars.next_if(|&n| n != '\n').is_some() {}
        } else {
            out.push(c);
            if c == '/' && chars.next_if_eq(&'*').is_some() {
                out.push('*');
                in_block_comment = true;
            } else if matches!(c, '\'' | '"' | '`') {
                quote = Some(c);
            }
        }
    }

    out
}
100
101/// Rewrite SQL for optional parameters.
102///
103/// For each optional param, finds `column = $N` (or `column <> $N`, `column != $N`)
104/// and rewrites to `($N IS NULL OR column = $N)`. This allows callers to pass NULL
105/// to skip a filter condition at runtime.
106///
107/// This operates on the raw SQL before any backend-specific placeholder rewriting.
108pub(crate) fn rewrite_optional_params(
109    sql: &str,
110    optional_params: &[String],
111    params: &[AnalyzedParam],
112) -> String {
113    if optional_params.is_empty() {
114        return sql.to_string();
115    }
116
117    let mut result = sql.to_string();
118
119    for opt_name in optional_params {
120        let Some(param) = params.iter().find(|p| p.name == *opt_name) else {
121            continue;
122        };
123        let placeholder = format!("${}", param.position);
124
125        // Try each comparison operator
126        for op in &[
127            ">=", "<=", "<>", "!=", ">", "<", "=", "ILIKE", "ilike", "LIKE", "like",
128        ] {
129            result = rewrite_comparison(&result, &placeholder, op);
130        }
131    }
132
133    result
134}
135
136/// Rewrite a single `column <op> $N` pattern to `($N IS NULL OR column <op> $N)`.
137/// Handles both `column <op> $N` and `$N <op> column` orderings.
138fn rewrite_comparison(sql: &str, placeholder: &str, op: &str) -> String {
139    let mut result = String::with_capacity(sql.len() + 32);
140    let chars: Vec<char> = sql.chars().collect();
141    let len = chars.len();
142    let mut i = 0;
143
144    while i < len {
145        // Try to match `identifier <op> $N` at this position
146        if let Some((_start, col, end)) = try_match_col_op_ph(&chars, i, op, placeholder) {
147            result.push_str(&format!(
148                "({placeholder} IS NULL OR {col} {op} {placeholder})"
149            ));
150            i = end;
151            continue;
152        }
153
154        // Try to match `$N <op> identifier` at this position
155        if let Some((end, col)) = try_match_ph_op_col(&chars, i, op, placeholder) {
156            result.push_str(&format!(
157                "({placeholder} IS NULL OR {col} {op} {placeholder})"
158            ));
159            i = end;
160            continue;
161        }
162
163        result.push(chars[i]);
164        i += 1;
165    }
166
167    result
168}
169
170/// Try to match `identifier <ws>* <op> <ws>* placeholder` starting at position `i`.
171/// Returns `(match_start, column_name, match_end)` if found.
172fn try_match_col_op_ph(
173    chars: &[char],
174    i: usize,
175    op: &str,
176    placeholder: &str,
177) -> Option<(usize, String, usize)> {
178    // Must start with an identifier character (word char)
179    if !is_ident_char(chars[i]) {
180        return None;
181    }
182    // Must not be preceded by another ident char (whole-word boundary)
183    if i > 0 && is_ident_char(chars[i - 1]) {
184        return None;
185    }
186
187    // Read the identifier
188    let ident_start = i;
189    let mut j = i;
190    while j < chars.len() && is_ident_char(chars[j]) {
191        j += 1;
192    }
193    let ident: String = chars[ident_start..j].iter().collect();
194
195    // Skip whitespace
196    while j < chars.len() && chars[j].is_whitespace() {
197        j += 1;
198    }
199
200    // Match operator
201    let op_chars: Vec<char> = op.chars().collect();
202    if j + op_chars.len() > chars.len() {
203        return None;
204    }
205    for (k, oc) in op_chars.iter().enumerate() {
206        if chars[j + k] != *oc {
207            return None;
208        }
209    }
210    j += op_chars.len();
211
212    // Skip whitespace
213    while j < chars.len() && chars[j].is_whitespace() {
214        j += 1;
215    }
216
217    // Match placeholder
218    let ph_chars: Vec<char> = placeholder.chars().collect();
219    if j + ph_chars.len() > chars.len() {
220        return None;
221    }
222    for (k, pc) in ph_chars.iter().enumerate() {
223        if chars[j + k] != *pc {
224            return None;
225        }
226    }
227    j += ph_chars.len();
228
229    // Ensure placeholder is not followed by a digit (e.g., $1 vs $10)
230    if j < chars.len() && chars[j].is_ascii_digit() {
231        return None;
232    }
233
234    Some((i, ident, j))
235}
236
237/// Try to match `placeholder <ws>* <op> <ws>* identifier` starting at position `i`.
238/// Returns `(match_end, column_name)` if found.
239fn try_match_ph_op_col(
240    chars: &[char],
241    i: usize,
242    op: &str,
243    placeholder: &str,
244) -> Option<(usize, String)> {
245    let ph_chars: Vec<char> = placeholder.chars().collect();
246    if i + ph_chars.len() > chars.len() {
247        return None;
248    }
249
250    // Must not be preceded by $ or digit (boundary check)
251    if i > 0 && (chars[i - 1] == '$' || chars[i - 1].is_ascii_digit()) {
252        return None;
253    }
254
255    // Match placeholder
256    for (k, pc) in ph_chars.iter().enumerate() {
257        if chars[i + k] != *pc {
258            return None;
259        }
260    }
261    let mut j = i + ph_chars.len();
262
263    // Ensure placeholder is not followed by a digit
264    if j < chars.len() && chars[j].is_ascii_digit() {
265        return None;
266    }
267
268    // Skip whitespace
269    while j < chars.len() && chars[j].is_whitespace() {
270        j += 1;
271    }
272
273    // Match operator
274    let op_chars: Vec<char> = op.chars().collect();
275    if j + op_chars.len() > chars.len() {
276        return None;
277    }
278    for (k, oc) in op_chars.iter().enumerate() {
279        if chars[j + k] != *oc {
280            return None;
281        }
282    }
283    j += op_chars.len();
284
285    // Skip whitespace
286    while j < chars.len() && chars[j].is_whitespace() {
287        j += 1;
288    }
289
290    // Read the identifier
291    if j >= chars.len() || !is_ident_char(chars[j]) {
292        return None;
293    }
294    let ident_start = j;
295    while j < chars.len() && is_ident_char(chars[j]) {
296        j += 1;
297    }
298    let ident: String = chars[ident_start..j].iter().collect();
299
300    // Avoid matching "NULL" (from already-rewritten text)
301    if ident == "NULL" {
302        return None;
303    }
304
305    Some((j, ident))
306}
307
308/// Clean SQL and apply optional parameter rewriting.
309pub(crate) fn clean_sql_with_optional(
310    sql: &str,
311    optional_params: &[String],
312    params: &[AnalyzedParam],
313) -> String {
314    let cleaned = clean_sql(sql);
315    rewrite_optional_params(&cleaned, optional_params, params)
316}
317
318/// Clean SQL (oneline) and apply optional parameter rewriting.
319pub(crate) fn clean_sql_oneline_with_optional(
320    sql: &str,
321    optional_params: &[String],
322    params: &[AnalyzedParam],
323) -> String {
324    let cleaned = clean_sql_oneline(sql);
325    rewrite_optional_params(&cleaned, optional_params, params)
326}
327
/// Whether `c` can be part of a (possibly qualified) identifier as seen by the
/// optional-parameter rewriter: any Unicode alphanumeric, `_`, or the `.`
/// that separates qualifiers such as `u.status`.
fn is_ident_char(c: char) -> bool {
    matches!(c, '_' | '.') || c.is_alphanumeric()
}
331
/// Rewrite PostgreSQL `$1, $2, ...` positional placeholders to a target format.
///
/// Text inside single-quoted string literals is copied untouched. A doubled
/// quote (`''`) is treated as an escaped quote that keeps the literal open. A
/// `$` not followed by a digit is copied as-is. `formatter` receives the
/// parameter number (0 if the digits do not fit in a `u32`) and returns the
/// replacement text.
pub(crate) fn rewrite_pg_placeholders(sql: &str, formatter: impl Fn(u32) -> String) -> String {
    let mut out = String::with_capacity(sql.len());
    let mut rest = sql.chars().peekable();
    let mut in_literal = false;

    while let Some(c) = rest.next() {
        if in_literal {
            out.push(c);
            if c == '\'' {
                match rest.next_if_eq(&'\'') {
                    Some(escaped) => out.push(escaped),
                    None => in_literal = false,
                }
            }
            continue;
        }

        match c {
            '\'' => {
                out.push(c);
                in_literal = true;
            }
            '$' if rest.peek().is_some_and(char::is_ascii_digit) => {
                let mut digits = String::new();
                while let Some(d) = rest.next_if(char::is_ascii_digit) {
                    digits.push(d);
                }
                out.push_str(&formatter(digits.parse::<u32>().unwrap_or(0)));
            }
            _ => out.push(c),
        }
    }

    out
}
368
369/// Get a backend by name and database engine.
370///
371/// The `engine` parameter (e.g., "postgresql", "mysql", "sqlite") determines
372/// which manifest is loaded for type mappings. PG-only backends reject non-PG engines.
373pub fn get_backend(name: &str, engine: &str) -> Result<Box<dyn CodegenBackend>, ScytheError> {
374    // Normalize engine aliases (e.g., "cockroachdb" -> "postgresql") before
375    // passing to backend constructors so each backend only needs to match
376    // canonical engine names.
377    let canonical_engine = normalize_engine(engine);
378    let backend: Box<dyn CodegenBackend> = match name {
379        "rust-sqlx" | "sqlx" | "rust" => Box::new(sqlx::SqlxBackend::new(canonical_engine)?),
380        "rust-tokio-postgres" | "tokio-postgres" => {
381            Box::new(tokio_postgres::TokioPostgresBackend::new(canonical_engine)?)
382        }
383        "python-psycopg3" | "python" => Box::new(python_psycopg3::PythonPsycopg3Backend::new(
384            canonical_engine,
385        )?),
386        "python-asyncpg" => Box::new(python_asyncpg::PythonAsyncpgBackend::new(canonical_engine)?),
387        "python-aiomysql" => Box::new(python_aiomysql::PythonAiomysqlBackend::new(
388            canonical_engine,
389        )?),
390        "python-aiosqlite" => Box::new(python_aiosqlite::PythonAiosqliteBackend::new(
391            canonical_engine,
392        )?),
393        "python-duckdb" => Box::new(python_duckdb::PythonDuckdbBackend::new(canonical_engine)?),
394        "typescript-postgres" | "ts" | "typescript" => Box::new(
395            typescript_postgres::TypescriptPostgresBackend::new(canonical_engine)?,
396        ),
397        "typescript-pg" => Box::new(typescript_pg::TypescriptPgBackend::new(canonical_engine)?),
398        "typescript-mysql2" => Box::new(typescript_mysql2::TypescriptMysql2Backend::new(
399            canonical_engine,
400        )?),
401        "typescript-better-sqlite3" => Box::new(
402            typescript_better_sqlite3::TypescriptBetterSqlite3Backend::new(canonical_engine)?,
403        ),
404        "typescript-duckdb" => Box::new(typescript_duckdb::TypescriptDuckdbBackend::new(
405            canonical_engine,
406        )?),
407        "go-database-sql" => Box::new(go_database_sql::GoDatabaseSqlBackend::new(
408            canonical_engine,
409        )?),
410        "go-pgx" | "go" => Box::new(go_pgx::GoPgxBackend::new(canonical_engine)?),
411        "java-jdbc" | "java" => Box::new(java_jdbc::JavaJdbcBackend::new(canonical_engine)?),
412        "java-r2dbc" | "r2dbc-java" => {
413            Box::new(java_r2dbc::JavaR2dbcBackend::new(canonical_engine)?)
414        }
415        "kotlin-exposed" | "exposed" => {
416            Box::new(kotlin_exposed::KotlinExposedBackend::new(canonical_engine)?)
417        }
418        "kotlin-jdbc" | "kotlin" | "kt" => {
419            Box::new(kotlin_jdbc::KotlinJdbcBackend::new(canonical_engine)?)
420        }
421        "kotlin-r2dbc" | "r2dbc-kotlin" => {
422            Box::new(kotlin_r2dbc::KotlinR2dbcBackend::new(canonical_engine)?)
423        }
424        "csharp-npgsql" | "csharp" | "c#" | "dotnet" => {
425            Box::new(csharp_npgsql::CsharpNpgsqlBackend::new(canonical_engine)?)
426        }
427        "csharp-mysqlconnector" => Box::new(
428            csharp_mysqlconnector::CsharpMysqlConnectorBackend::new(canonical_engine)?,
429        ),
430        "csharp-microsoft-sqlite" => Box::new(
431            csharp_microsoft_sqlite::CsharpMicrosoftSqliteBackend::new(canonical_engine)?,
432        ),
433        "elixir-postgrex" | "elixir" | "ex" => Box::new(
434            elixir_postgrex::ElixirPostgrexBackend::new(canonical_engine)?,
435        ),
436        "elixir-ecto" | "ecto" => Box::new(elixir_ecto::ElixirEctoBackend::new(canonical_engine)?),
437        "elixir-myxql" => Box::new(elixir_myxql::ElixirMyxqlBackend::new(canonical_engine)?),
438        "elixir-exqlite" => Box::new(elixir_exqlite::ElixirExqliteBackend::new(canonical_engine)?),
439        "ruby-pg" | "ruby" | "rb" => Box::new(ruby_pg::RubyPgBackend::new(canonical_engine)?),
440        "ruby-mysql2" => Box::new(ruby_mysql2::RubyMysql2Backend::new(canonical_engine)?),
441        "ruby-sqlite3" => Box::new(ruby_sqlite3::RubySqlite3Backend::new(canonical_engine)?),
442        "ruby-trilogy" | "trilogy" => {
443            Box::new(ruby_trilogy::RubyTrilogyBackend::new(canonical_engine)?)
444        }
445        "php-pdo" | "php" => Box::new(php_pdo::PhpPdoBackend::new(canonical_engine)?),
446        "php-amphp" | "amphp" => Box::new(php_amphp::PhpAmphpBackend::new(canonical_engine)?),
447        // MSSQL backends
448        "rust-tiberius" | "tiberius" => {
449            Box::new(rust_tiberius::RustTiberiusBackend::new(canonical_engine)?)
450        }
451        "python-pyodbc" | "pyodbc" => {
452            Box::new(python_pyodbc::PythonPyodbcBackend::new(canonical_engine)?)
453        }
454        "typescript-mssql" | "tedious" => Box::new(typescript_mssql::TypescriptMssqlBackend::new(
455            canonical_engine,
456        )?),
457        "csharp-sqlclient" => Box::new(csharp_sqlclient::CsharpSqlClientBackend::new(
458            canonical_engine,
459        )?),
460        "ruby-tiny-tds" | "tiny-tds" | "tiny_tds" => {
461            Box::new(ruby_tiny_tds::RubyTinyTdsBackend::new(canonical_engine)?)
462        }
463        "elixir-tds" | "tds" => Box::new(elixir_tds::ElixirTdsBackend::new(canonical_engine)?),
464        // Oracle backends
465        "rust-sibyl" | "sibyl" => Box::new(rust_sibyl::RustSibylBackend::new(canonical_engine)?),
466        "python-oracledb" | "oracledb" => Box::new(python_oracledb::PythonOracledbBackend::new(
467            canonical_engine,
468        )?),
469        "typescript-oracledb" => Box::new(typescript_oracledb::TypescriptOracledbBackend::new(
470            canonical_engine,
471        )?),
472        "go-godror" | "godror" => Box::new(go_godror::GoGodrorBackend::new(canonical_engine)?),
473        "csharp-oracle" => Box::new(csharp_oracle::CsharpOracleBackend::new(canonical_engine)?),
474        "ruby-oci8" | "oci8" => Box::new(ruby_oci8::RubyOci8Backend::new(canonical_engine)?),
475        "elixir-jamdb" | "jamdb" => {
476            Box::new(elixir_jamdb::ElixirJamdbBackend::new(canonical_engine)?)
477        }
478        // Snowflake backends
479        "python-snowflake" => Box::new(python_snowflake::PythonSnowflakeBackend::new(
480            canonical_engine,
481        )?),
482        "typescript-snowflake" => Box::new(typescript_snowflake::TypescriptSnowflakeBackend::new(
483            canonical_engine,
484        )?),
485        "go-gosnowflake" | "gosnowflake" => {
486            Box::new(go_gosnowflake::GoGosnowflakeBackend::new(canonical_engine)?)
487        }
488        "csharp-snowflake" => Box::new(csharp_snowflake::CsharpSnowflakeBackend::new(
489            canonical_engine,
490        )?),
491        _ => {
492            return Err(ScytheError::new(
493                ErrorCode::InternalError,
494                format!("unknown backend: {}", name),
495            ));
496        }
497    };
498
499    // Validate engine is supported by this backend
500    if !backend
501        .supported_engines()
502        .iter()
503        .any(|e| normalize_engine(e) == canonical_engine)
504    {
505        return Err(ScytheError::new(
506            ErrorCode::InternalError,
507            format!(
508                "backend '{}' does not support engine '{}'. Supported: {:?}",
509                name,
510                engine,
511                backend.supported_engines()
512            ),
513        ));
514    }
515
516    Ok(backend)
517}
518
/// Normalize an engine name to its canonical form.
///
/// Known aliases are folded onto one spelling:
/// - `postgres`, `pg`, `cockroachdb`, `crdb` → `postgresql`
/// - `sqlite3` → `sqlite`
/// - `sqlserver`, `tsql` → `mssql`
///
/// Every other input is returned unchanged. This includes the canonical names
/// themselves (`postgresql`, `mysql`, `mariadb`, `sqlite`, `duckdb`, `mssql`,
/// `oracle`, `snowflake`, `redshift`). Matching is exact and case-sensitive.
fn normalize_engine(engine: &str) -> &str {
    match engine {
        "postgres" | "pg" | "cockroachdb" | "crdb" => "postgresql",
        "sqlite3" => "sqlite",
        "sqlserver" | "tsql" => "mssql",
        unchanged => unchanged,
    }
}
534
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a nullable string parameter whose `$N` placeholder number is `position`.
    fn param(name: &str, position: i64) -> AnalyzedParam {
        AnalyzedParam {
            name: name.to_string(),
            neutral_type: "string".to_string(),
            nullable: true,
            position,
        }
    }

    #[test]
    fn test_normalize_engine_cockroachdb() {
        assert_eq!(normalize_engine("cockroachdb"), "postgresql");
        assert_eq!(normalize_engine("crdb"), "postgresql");
    }

    #[test]
    fn test_get_backend_cockroachdb_with_pg_backends() {
        // CockroachDB should work with all PostgreSQL-compatible backends
        let pg_backends = [
            "rust-sqlx",
            "rust-tokio-postgres",
            "python-psycopg3",
            "python-asyncpg",
            "typescript-postgres",
            "typescript-pg",
            "go-pgx",
            "ruby-pg",
            "elixir-postgrex",
            "csharp-npgsql",
            "php-pdo",
            "php-amphp",
        ];
        for backend_name in &pg_backends {
            let result = get_backend(backend_name, "cockroachdb");
            assert!(
                result.is_ok(),
                "backend '{}' should accept cockroachdb engine, got: {:?}",
                backend_name,
                result.err()
            );
        }
    }

    #[test]
    fn test_get_backend_crdb_alias() {
        let result = get_backend("rust-sqlx", "crdb");
        assert!(
            result.is_ok(),
            "rust-sqlx should accept 'crdb' engine alias"
        );
    }

    #[test]
    fn test_normalize_engine_duckdb() {
        assert_eq!(normalize_engine("duckdb"), "duckdb");
    }

    #[test]
    fn test_get_backend_duckdb_with_compatible_backends() {
        let duckdb_backends = [
            "python-duckdb",
            "typescript-duckdb",
            "go-database-sql",
            "java-jdbc",
            "kotlin-jdbc",
        ];
        for backend_name in &duckdb_backends {
            let result = get_backend(backend_name, "duckdb");
            assert!(
                result.is_ok(),
                "backend '{}' should accept duckdb engine, got: {:?}",
                backend_name,
                result.err()
            );
        }
    }

    #[test]
    fn test_get_backend_duckdb_rejected_by_pg_only() {
        let result = get_backend("rust-sqlx", "duckdb");
        assert!(result.is_err(), "rust-sqlx should reject duckdb engine");
    }

    #[test]
    fn test_rewrite_simple_equality() {
        let sql = "SELECT * FROM users WHERE status = $1";
        let params = vec![param("status", 1)];
        let result = rewrite_optional_params(sql, &["status".to_string()], &params);
        assert_eq!(
            result,
            "SELECT * FROM users WHERE ($1 IS NULL OR status = $1)"
        );
    }

    #[test]
    fn test_rewrite_qualified_column() {
        let sql = "SELECT * FROM users u WHERE u.status = $1";
        let params = vec![param("status", 1)];
        let result = rewrite_optional_params(sql, &["status".to_string()], &params);
        assert_eq!(
            result,
            "SELECT * FROM users u WHERE ($1 IS NULL OR u.status = $1)"
        );
    }

    #[test]
    fn test_rewrite_multiple_optional() {
        let sql = "SELECT * FROM users WHERE status = $1 AND name = $2";
        let params = vec![param("status", 1), param("name", 2)];
        let result =
            rewrite_optional_params(sql, &["status".to_string(), "name".to_string()], &params);
        assert_eq!(
            result,
            "SELECT * FROM users WHERE ($1 IS NULL OR status = $1) AND ($2 IS NULL OR name = $2)"
        );
    }

    #[test]
    fn test_rewrite_mixed_optional_required() {
        let sql = "SELECT * FROM users WHERE id = $1 AND status = $2";
        let params = vec![param("id", 1), param("status", 2)];
        let result = rewrite_optional_params(sql, &["status".to_string()], &params);
        assert_eq!(
            result,
            "SELECT * FROM users WHERE id = $1 AND ($2 IS NULL OR status = $2)"
        );
    }

    #[test]
    fn test_rewrite_like_operator() {
        let sql = "SELECT * FROM users WHERE name LIKE $1";
        let params = vec![param("name", 1)];
        let result = rewrite_optional_params(sql, &["name".to_string()], &params);
        assert_eq!(
            result,
            "SELECT * FROM users WHERE ($1 IS NULL OR name LIKE $1)"
        );
    }

    #[test]
    fn test_rewrite_ilike_operator() {
        let sql = "SELECT * FROM users WHERE name ILIKE $1";
        let params = vec![param("name", 1)];
        let result = rewrite_optional_params(sql, &["name".to_string()], &params);
        assert_eq!(
            result,
            "SELECT * FROM users WHERE ($1 IS NULL OR name ILIKE $1)"
        );
    }

    #[test]
    fn test_rewrite_comparison_operators() {
        let sql = "SELECT * FROM users WHERE age >= $1";
        let params = vec![param("age", 1)];
        let result = rewrite_optional_params(sql, &["age".to_string()], &params);
        assert_eq!(
            result,
            "SELECT * FROM users WHERE ($1 IS NULL OR age >= $1)"
        );
    }

    #[test]
    fn test_rewrite_less_than() {
        let sql = "SELECT * FROM users WHERE age < $1";
        let params = vec![param("age", 1)];
        let result = rewrite_optional_params(sql, &["age".to_string()], &params);
        assert_eq!(result, "SELECT * FROM users WHERE ($1 IS NULL OR age < $1)");
    }

    #[test]
    fn test_no_rewrite_without_optional() {
        let sql = "SELECT * FROM users WHERE status = $1";
        let params = vec![param("status", 1)];
        let result = rewrite_optional_params(sql, &[], &params);
        assert_eq!(result, sql);
    }

    #[test]
    fn test_rewrite_not_equal() {
        let sql = "SELECT * FROM users WHERE status <> $1";
        let params = vec![param("status", 1)];
        let result = rewrite_optional_params(sql, &["status".to_string()], &params);
        assert_eq!(
            result,
            "SELECT * FROM users WHERE ($1 IS NULL OR status <> $1)"
        );
    }

    #[test]
    fn test_rewrite_does_not_match_similar_placeholder() {
        // $1 should not match $10
        let sql = "SELECT * FROM users WHERE status = $10";
        let params = vec![param("status", 1)];
        let result = rewrite_optional_params(sql, &["status".to_string()], &params);
        // $1 placeholder doesn't appear, so no rewrite
        assert_eq!(result, sql);
    }

    #[test]
    fn test_normalize_engine_mariadb() {
        assert_eq!(normalize_engine("mariadb"), "mariadb");
    }

    #[test]
    fn test_get_backend_mariadb_with_mysql_backends() {
        let mariadb_backends = [
            "rust-sqlx",
            "python-aiomysql",
            "typescript-mysql2",
            "go-database-sql",
            "java-jdbc",
            "java-r2dbc",
            "kotlin-jdbc",
            "kotlin-r2dbc",
            "csharp-mysqlconnector",
            "elixir-myxql",
            "ruby-mysql2",
            "ruby-trilogy",
            "php-pdo",
            "php-amphp",
        ];
        for backend_name in &mariadb_backends {
            let result = get_backend(backend_name, "mariadb");
            assert!(
                result.is_ok(),
                "backend '{}' should accept mariadb engine, got: {:?}",
                backend_name,
                result.err()
            );
        }
    }

    #[test]
    fn test_normalize_engine_aliases() {
        assert_eq!(normalize_engine("pg"), "postgresql");
        assert_eq!(normalize_engine("postgres"), "postgresql");
        assert_eq!(normalize_engine("sqlite3"), "sqlite");
        assert_eq!(normalize_engine("sqlserver"), "mssql");
        assert_eq!(normalize_engine("tsql"), "mssql");
        // Unknown engines pass through unchanged.
        assert_eq!(normalize_engine("something-else"), "something-else");
    }

    #[test]
    fn test_get_backend_unknown_name() {
        assert!(get_backend("cobol-isam", "postgresql").is_err());
    }

    #[test]
    fn test_rewrite_pg_placeholders_skips_string_literals() {
        let sql = "SELECT * FROM t WHERE a = $1 AND b = 'cost: $2' AND c = $10";
        let result = rewrite_pg_placeholders(sql, |n| format!("@p{n}"));
        assert_eq!(
            result,
            "SELECT * FROM t WHERE a = @p1 AND b = 'cost: $2' AND c = @p10"
        );
    }

    #[test]
    fn test_rewrite_pg_placeholders_doubled_quote() {
        // `''` is an escaped quote, so `$1` is still inside the literal.
        let sql = "SELECT 'it''s $1' || $2";
        let result = rewrite_pg_placeholders(sql, |_| "?".to_string());
        assert_eq!(result, "SELECT 'it''s $1' || ?");
    }

    #[test]
    fn test_clean_sql_drops_comment_lines_and_semicolon() {
        let sql = "-- name: GetUser :one\nSELECT id\nFROM users;\n";
        assert_eq!(clean_sql(sql), "SELECT id\nFROM users");
        assert_eq!(clean_sql_oneline(sql), "SELECT id FROM users");
    }

    #[test]
    fn test_clean_sql_with_optional_rewrites_cleaned_sql() {
        let params = vec![param("status", 1)];
        let result = clean_sql_with_optional(
            "SELECT * FROM users\nWHERE status = $1;",
            &["status".to_string()],
            &params,
        );
        assert_eq!(result, "SELECT * FROM users\nWHERE ($1 IS NULL OR status = $1)");
    }
}