//! Code-generation backends for `scythe_codegen` (`scythe_codegen/backends/mod.rs`).
1pub(crate) mod csharp_microsoft_sqlite;
2pub(crate) mod csharp_mysqlconnector;
3pub(crate) mod csharp_npgsql;
4pub(crate) mod csharp_oracle;
5pub(crate) mod csharp_snowflake;
6pub(crate) mod csharp_sqlclient;
7pub(crate) mod elixir_ecto;
8pub(crate) mod elixir_exqlite;
9pub(crate) mod elixir_jamdb;
10pub(crate) mod elixir_myxql;
11pub(crate) mod elixir_postgrex;
12pub(crate) mod elixir_tds;
13pub(crate) mod go_database_sql;
14pub(crate) mod go_godror;
15pub(crate) mod go_gosnowflake;
16pub(crate) mod go_pgx;
17pub(crate) mod java_jdbc;
18pub(crate) mod java_r2dbc;
19pub(crate) mod kotlin_exposed;
20pub(crate) mod kotlin_jdbc;
21pub(crate) mod kotlin_r2dbc;
22pub(crate) mod php_amphp;
23pub(crate) mod php_pdo;
24pub(crate) mod python_aiomysql;
25pub(crate) mod python_aiosqlite;
26pub(crate) mod python_asyncpg;
27pub(crate) mod python_common;
28pub(crate) mod python_duckdb;
29pub(crate) mod python_oracledb;
30pub(crate) mod python_psycopg3;
31pub(crate) mod python_pyodbc;
32pub(crate) mod python_snowflake;
33pub(crate) mod ruby_mysql2;
34pub(crate) mod ruby_oci8;
35pub(crate) mod ruby_pg;
36pub(crate) mod ruby_rbs;
37pub(crate) mod ruby_sqlite3;
38pub(crate) mod ruby_tiny_tds;
39pub(crate) mod ruby_trilogy;
40pub(crate) mod rust_sibyl;
41pub(crate) mod rust_tiberius;
42pub(crate) mod sqlx;
43pub(crate) mod tokio_postgres;
44pub(crate) mod typescript_better_sqlite3;
45pub(crate) mod typescript_common;
46pub(crate) mod typescript_duckdb;
47pub(crate) mod typescript_mssql;
48pub(crate) mod typescript_mysql2;
49pub(crate) mod typescript_oracledb;
50pub(crate) mod typescript_pg;
51pub(crate) mod typescript_postgres;
52pub(crate) mod typescript_snowflake;
53
54use scythe_backend::manifest::BackendManifest;
55use scythe_core::analyzer::AnalyzedParam;
56use scythe_core::errors::{ErrorCode, ScytheError};
57
58use crate::backend_trait::CodegenBackend;
59
60/// Load a backend manifest, preferring a user-provided file at `override_path`
61/// and falling back to the embedded `default_toml` string.
62pub(crate) fn load_or_default_manifest(
63    override_path: &str,
64    default_toml: &str,
65) -> Result<BackendManifest, ScytheError> {
66    let path = std::path::Path::new(override_path);
67    if path.exists() {
68        scythe_backend::manifest::load_manifest(path)
69            .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))
70    } else {
71        toml::from_str(default_toml)
72            .map_err(|e| ScytheError::new(ErrorCode::InternalError, format!("manifest: {e}")))
73    }
74}
75
/// Strip SQL comments, trailing semicolons, and excess whitespace.
/// Preserves newlines between lines.
///
/// Whole-line `--` comments are dropped together with their line. A `--`
/// comment that follows code on the same line is cut off, along with the
/// spaces/tabs before it. This only happens outside single-quoted strings,
/// double-quoted or backtick-quoted identifiers, and `/* ... */` block
/// comments. `--` always starts a comment here (PostgreSQL/SQLite lexing).
pub(crate) fn clean_sql(sql: &str) -> String {
    trim_statement(&strip_sql_comments(sql))
}

/// Like clean_sql but joins lines with spaces (for languages that embed SQL inline).
///
/// Comments are removed *before* the lines are joined. Once everything is on
/// one line, a trailing `-- comment` would otherwise swallow the rest of the
/// statement: `... FROM t -- note\nWHERE id = $1` would lose its `WHERE`.
pub(crate) fn clean_sql_oneline(sql: &str) -> String {
    let stripped = strip_sql_comments(sql);
    trim_statement(&stripped.lines().collect::<Vec<_>>().join(" "))
}

/// Drop whole-line `--` comments, then cut trailing `--` comments after code.
/// The remaining lines are joined with `\n`.
fn strip_sql_comments(sql: &str) -> String {
    let code_lines = sql
        .lines()
        .filter(|line| !line.trim_start().starts_with("--"))
        .collect::<Vec<_>>()
        .join("\n");
    strip_trailing_line_comments(&code_lines)
}

/// Remove every `--` comment (up to, not including, the newline) that appears
/// outside quoted text and block comments. Whitespace immediately before a
/// removed comment is also dropped. Everything else is copied verbatim.
fn strip_trailing_line_comments(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        match c {
            // Quoted text is copied verbatim up to the matching quote. A doubled
            // quote (`''`) closes and immediately reopens, which has the same effect.
            '\'' | '"' | '`' => {
                out.push(c);
                for inner in chars.by_ref() {
                    out.push(inner);
                    if inner == c {
                        break;
                    }
                }
            }
            // Block comments are kept; a `--` inside one is not a line comment.
            '/' if chars.peek() == Some(&'*') => {
                chars.next();
                out.push_str("/*");
                let mut prev = '\0';
                for inner in chars.by_ref() {
                    out.push(inner);
                    if prev == '*' && inner == '/' {
                        break;
                    }
                    prev = inner;
                }
            }
            '-' if chars.peek() == Some(&'-') => {
                // Skip the comment body but keep the newline that ends it.
                while chars.next_if(|&n| n != '\n').is_some() {}
                let kept = out.trim_end_matches(|ch| ch == ' ' || ch == '\t').len();
                out.truncate(kept);
            }
            _ => out.push(c),
        }
    }
    out
}

/// Trim surrounding whitespace and any trailing `;` terminators.
fn trim_statement(sql: &str) -> String {
    sql.trim().trim_end_matches(';').trim().to_string()
}
100
101/// Rewrite SQL for optional parameters.
102///
103/// For each optional param, finds `column = $N` (or `column <> $N`, `column != $N`)
104/// and rewrites to `($N IS NULL OR column = $N)`. This allows callers to pass NULL
105/// to skip a filter condition at runtime.
106///
107/// This operates on the raw SQL before any backend-specific placeholder rewriting.
108pub(crate) fn rewrite_optional_params(
109    sql: &str,
110    optional_params: &[String],
111    params: &[AnalyzedParam],
112) -> String {
113    if optional_params.is_empty() {
114        return sql.to_string();
115    }
116
117    let mut result = sql.to_string();
118
119    for opt_name in optional_params {
120        let Some(param) = params.iter().find(|p| p.name == *opt_name) else {
121            continue;
122        };
123        let placeholder = format!("${}", param.position);
124
125        // Try each comparison operator
126        for op in &[
127            ">=", "<=", "<>", "!=", ">", "<", "=", "ILIKE", "ilike", "LIKE", "like",
128        ] {
129            result = rewrite_comparison(&result, &placeholder, op);
130        }
131    }
132
133    result
134}
135
136/// Rewrite a single `column <op> $N` pattern to `($N IS NULL OR column <op> $N)`.
137/// Handles both `column <op> $N` and `$N <op> column` orderings.
138fn rewrite_comparison(sql: &str, placeholder: &str, op: &str) -> String {
139    let mut result = String::with_capacity(sql.len() + 32);
140    let chars: Vec<char> = sql.chars().collect();
141    let len = chars.len();
142    let mut i = 0;
143
144    while i < len {
145        // Try to match `identifier <op> $N` at this position
146        if let Some((_start, col, end)) = try_match_col_op_ph(&chars, i, op, placeholder) {
147            result.push_str(&format!(
148                "({placeholder} IS NULL OR {col} {op} {placeholder})"
149            ));
150            i = end;
151            continue;
152        }
153
154        // Try to match `$N <op> identifier` at this position
155        if let Some((end, col)) = try_match_ph_op_col(&chars, i, op, placeholder) {
156            result.push_str(&format!(
157                "({placeholder} IS NULL OR {col} {op} {placeholder})"
158            ));
159            i = end;
160            continue;
161        }
162
163        result.push(chars[i]);
164        i += 1;
165    }
166
167    result
168}
169
170/// Try to match `identifier <ws>* <op> <ws>* placeholder` starting at position `i`.
171/// Returns `(match_start, column_name, match_end)` if found.
172fn try_match_col_op_ph(
173    chars: &[char],
174    i: usize,
175    op: &str,
176    placeholder: &str,
177) -> Option<(usize, String, usize)> {
178    // Must start with an identifier character (word char)
179    if !is_ident_char(chars[i]) {
180        return None;
181    }
182    // Must not be preceded by another ident char (whole-word boundary)
183    if i > 0 && is_ident_char(chars[i - 1]) {
184        return None;
185    }
186
187    // Read the identifier
188    let ident_start = i;
189    let mut j = i;
190    while j < chars.len() && is_ident_char(chars[j]) {
191        j += 1;
192    }
193    let ident: String = chars[ident_start..j].iter().collect();
194
195    // Skip whitespace
196    while j < chars.len() && chars[j].is_whitespace() {
197        j += 1;
198    }
199
200    // Match operator
201    let op_chars: Vec<char> = op.chars().collect();
202    if j + op_chars.len() > chars.len() {
203        return None;
204    }
205    for (k, oc) in op_chars.iter().enumerate() {
206        if chars[j + k] != *oc {
207            return None;
208        }
209    }
210    j += op_chars.len();
211
212    // Skip whitespace
213    while j < chars.len() && chars[j].is_whitespace() {
214        j += 1;
215    }
216
217    // Match placeholder
218    let ph_chars: Vec<char> = placeholder.chars().collect();
219    if j + ph_chars.len() > chars.len() {
220        return None;
221    }
222    for (k, pc) in ph_chars.iter().enumerate() {
223        if chars[j + k] != *pc {
224            return None;
225        }
226    }
227    j += ph_chars.len();
228
229    // Ensure placeholder is not followed by a digit (e.g., $1 vs $10)
230    if j < chars.len() && chars[j].is_ascii_digit() {
231        return None;
232    }
233
234    Some((i, ident, j))
235}
236
237/// Try to match `placeholder <ws>* <op> <ws>* identifier` starting at position `i`.
238/// Returns `(match_end, column_name)` if found.
239fn try_match_ph_op_col(
240    chars: &[char],
241    i: usize,
242    op: &str,
243    placeholder: &str,
244) -> Option<(usize, String)> {
245    let ph_chars: Vec<char> = placeholder.chars().collect();
246    if i + ph_chars.len() > chars.len() {
247        return None;
248    }
249
250    // Must not be preceded by $ or digit (boundary check)
251    if i > 0 && (chars[i - 1] == '$' || chars[i - 1].is_ascii_digit()) {
252        return None;
253    }
254
255    // Match placeholder
256    for (k, pc) in ph_chars.iter().enumerate() {
257        if chars[i + k] != *pc {
258            return None;
259        }
260    }
261    let mut j = i + ph_chars.len();
262
263    // Ensure placeholder is not followed by a digit
264    if j < chars.len() && chars[j].is_ascii_digit() {
265        return None;
266    }
267
268    // Skip whitespace
269    while j < chars.len() && chars[j].is_whitespace() {
270        j += 1;
271    }
272
273    // Match operator
274    let op_chars: Vec<char> = op.chars().collect();
275    if j + op_chars.len() > chars.len() {
276        return None;
277    }
278    for (k, oc) in op_chars.iter().enumerate() {
279        if chars[j + k] != *oc {
280            return None;
281        }
282    }
283    j += op_chars.len();
284
285    // Skip whitespace
286    while j < chars.len() && chars[j].is_whitespace() {
287        j += 1;
288    }
289
290    // Read the identifier
291    if j >= chars.len() || !is_ident_char(chars[j]) {
292        return None;
293    }
294    let ident_start = j;
295    while j < chars.len() && is_ident_char(chars[j]) {
296        j += 1;
297    }
298    let ident: String = chars[ident_start..j].iter().collect();
299
300    // Avoid matching "NULL" (from already-rewritten text)
301    if ident == "NULL" {
302        return None;
303    }
304
305    Some((j, ident))
306}
307
308/// Clean SQL and apply optional parameter rewriting.
309pub(crate) fn clean_sql_with_optional(
310    sql: &str,
311    optional_params: &[String],
312    params: &[AnalyzedParam],
313) -> String {
314    let cleaned = clean_sql(sql);
315    rewrite_optional_params(&cleaned, optional_params, params)
316}
317
318/// Clean SQL (oneline) and apply optional parameter rewriting.
319pub(crate) fn clean_sql_oneline_with_optional(
320    sql: &str,
321    optional_params: &[String],
322    params: &[AnalyzedParam],
323) -> String {
324    let cleaned = clean_sql_oneline(sql);
325    rewrite_optional_params(&cleaned, optional_params, params)
326}
327
/// Whether `c` may appear in a (possibly qualified) column reference: any
/// Unicode alphanumeric character, `_`, or the `.` qualifier separator.
fn is_ident_char(c: char) -> bool {
    matches!(c, '_' | '.') || c.is_alphanumeric()
}
331
/// Rewrite SQL placeholders (`$N` or `?`) to a target format.
///
/// `formatter` receives a 1-based parameter number and returns its replacement:
/// - `$N` (a `$` followed by digits) is passed its own number `N`. Digit runs
///   that overflow `u32` are passed as `0`.
/// - Each bare `?` (not followed by a digit) is passed the next value of a
///   counter that starts at 1 and is independent of any `$N` placeholders.
///
/// A `$` not followed by a digit (e.g. `$$`) and a `?` followed by a digit
/// (e.g. `?1`) are copied unchanged. Single-quoted string literals, including
/// `''` escapes, are copied verbatim. Double-quoted identifiers and comments
/// are not special-cased.
pub(crate) fn rewrite_pg_placeholders(sql: &str, formatter: impl Fn(u32) -> String) -> String {
    let mut out = String::with_capacity(sql.len());
    let mut input = sql.chars().peekable();
    let mut question_marks_seen: u32 = 0;

    while let Some(c) = input.next() {
        match c {
            '\'' => {
                out.push(c);
                while let Some(lit) = input.next() {
                    out.push(lit);
                    if lit != '\'' {
                        continue;
                    }
                    // `''` is an escaped quote; anything else ends the literal.
                    match input.next_if_eq(&'\'') {
                        Some(escaped) => out.push(escaped),
                        None => break,
                    }
                }
            }
            '$' if input.peek().is_some_and(char::is_ascii_digit) => {
                let mut digits = String::new();
                while let Some(d) = input.next_if(char::is_ascii_digit) {
                    digits.push(d);
                }
                out.push_str(&formatter(digits.parse::<u32>().unwrap_or(0)));
            }
            '?' if !input.peek().is_some_and(char::is_ascii_digit) => {
                question_marks_seen += 1;
                out.push_str(&formatter(question_marks_seen));
            }
            other => out.push(other),
        }
    }
    out
}
374
/// Get a backend by name and database engine.
///
/// The `engine` parameter (e.g., "postgresql", "mysql", "sqlite") determines
/// which manifest is loaded for type mappings. PG-only backends reject non-PG engines.
///
/// `name` is matched exactly (case-sensitively) against the backend identifiers
/// and short aliases listed in the match below. `engine` is first canonicalized
/// with `normalize_engine`, so aliases such as "cockroachdb" or "pg" are accepted.
///
/// # Errors
///
/// - Propagates any error returned by the selected backend's constructor.
/// - Returns `ErrorCode::InternalError` with `unknown backend: <name>` when `name`
///   matches no arm.
/// - Returns `ErrorCode::InternalError` when the constructed backend's
///   `supported_engines()` does not contain the normalized engine. That message
///   reports the caller's original `engine` string.
pub fn get_backend(name: &str, engine: &str) -> Result<Box<dyn CodegenBackend>, ScytheError> {
    // Normalize engine aliases (e.g., "cockroachdb" -> "postgresql") before
    // passing to backend constructors so each backend only needs to match
    // canonical engine names.
    let canonical_engine = normalize_engine(engine);
    // Each arm maps a backend identifier plus its aliases to that backend's constructor.
    let backend: Box<dyn CodegenBackend> = match name {
        "rust-sqlx" | "sqlx" | "rust" => Box::new(sqlx::SqlxBackend::new(canonical_engine)?),
        "rust-tokio-postgres" | "tokio-postgres" => {
            Box::new(tokio_postgres::TokioPostgresBackend::new(canonical_engine)?)
        }
        "python-psycopg3" | "python" => Box::new(python_psycopg3::PythonPsycopg3Backend::new(
            canonical_engine,
        )?),
        "python-asyncpg" => Box::new(python_asyncpg::PythonAsyncpgBackend::new(canonical_engine)?),
        "python-aiomysql" => Box::new(python_aiomysql::PythonAiomysqlBackend::new(
            canonical_engine,
        )?),
        "python-aiosqlite" => Box::new(python_aiosqlite::PythonAiosqliteBackend::new(
            canonical_engine,
        )?),
        "python-duckdb" => Box::new(python_duckdb::PythonDuckdbBackend::new(canonical_engine)?),
        "typescript-postgres" | "ts" | "typescript" => Box::new(
            typescript_postgres::TypescriptPostgresBackend::new(canonical_engine)?,
        ),
        "typescript-pg" => Box::new(typescript_pg::TypescriptPgBackend::new(canonical_engine)?),
        "typescript-mysql2" => Box::new(typescript_mysql2::TypescriptMysql2Backend::new(
            canonical_engine,
        )?),
        "typescript-better-sqlite3" => Box::new(
            typescript_better_sqlite3::TypescriptBetterSqlite3Backend::new(canonical_engine)?,
        ),
        "typescript-duckdb" => Box::new(typescript_duckdb::TypescriptDuckdbBackend::new(
            canonical_engine,
        )?),
        "go-database-sql" => Box::new(go_database_sql::GoDatabaseSqlBackend::new(
            canonical_engine,
        )?),
        "go-pgx" | "go" => Box::new(go_pgx::GoPgxBackend::new(canonical_engine)?),
        "java-jdbc" | "java" => Box::new(java_jdbc::JavaJdbcBackend::new(canonical_engine)?),
        "java-r2dbc" | "r2dbc-java" => {
            Box::new(java_r2dbc::JavaR2dbcBackend::new(canonical_engine)?)
        }
        "kotlin-exposed" | "exposed" => {
            Box::new(kotlin_exposed::KotlinExposedBackend::new(canonical_engine)?)
        }
        "kotlin-jdbc" | "kotlin" | "kt" => {
            Box::new(kotlin_jdbc::KotlinJdbcBackend::new(canonical_engine)?)
        }
        "kotlin-r2dbc" | "r2dbc-kotlin" => {
            Box::new(kotlin_r2dbc::KotlinR2dbcBackend::new(canonical_engine)?)
        }
        "csharp-npgsql" | "csharp" | "c#" | "dotnet" => {
            Box::new(csharp_npgsql::CsharpNpgsqlBackend::new(canonical_engine)?)
        }
        "csharp-mysqlconnector" => Box::new(
            csharp_mysqlconnector::CsharpMysqlConnectorBackend::new(canonical_engine)?,
        ),
        "csharp-microsoft-sqlite" => Box::new(
            csharp_microsoft_sqlite::CsharpMicrosoftSqliteBackend::new(canonical_engine)?,
        ),
        "elixir-postgrex" | "elixir" | "ex" => Box::new(
            elixir_postgrex::ElixirPostgrexBackend::new(canonical_engine)?,
        ),
        "elixir-ecto" | "ecto" => Box::new(elixir_ecto::ElixirEctoBackend::new(canonical_engine)?),
        "elixir-myxql" => Box::new(elixir_myxql::ElixirMyxqlBackend::new(canonical_engine)?),
        "elixir-exqlite" => Box::new(elixir_exqlite::ElixirExqliteBackend::new(canonical_engine)?),
        "ruby-pg" | "ruby" | "rb" => Box::new(ruby_pg::RubyPgBackend::new(canonical_engine)?),
        "ruby-mysql2" => Box::new(ruby_mysql2::RubyMysql2Backend::new(canonical_engine)?),
        "ruby-sqlite3" => Box::new(ruby_sqlite3::RubySqlite3Backend::new(canonical_engine)?),
        "ruby-trilogy" | "trilogy" => {
            Box::new(ruby_trilogy::RubyTrilogyBackend::new(canonical_engine)?)
        }
        "php-pdo" | "php" => Box::new(php_pdo::PhpPdoBackend::new(canonical_engine)?),
        "php-amphp" | "amphp" => Box::new(php_amphp::PhpAmphpBackend::new(canonical_engine)?),
        // MSSQL backends
        "rust-tiberius" | "tiberius" => {
            Box::new(rust_tiberius::RustTiberiusBackend::new(canonical_engine)?)
        }
        "python-pyodbc" | "pyodbc" => {
            Box::new(python_pyodbc::PythonPyodbcBackend::new(canonical_engine)?)
        }
        "typescript-mssql" | "tedious" => Box::new(typescript_mssql::TypescriptMssqlBackend::new(
            canonical_engine,
        )?),
        "csharp-sqlclient" => Box::new(csharp_sqlclient::CsharpSqlClientBackend::new(
            canonical_engine,
        )?),
        "ruby-tiny-tds" | "tiny-tds" | "tiny_tds" => {
            Box::new(ruby_tiny_tds::RubyTinyTdsBackend::new(canonical_engine)?)
        }
        "elixir-tds" | "tds" => Box::new(elixir_tds::ElixirTdsBackend::new(canonical_engine)?),
        // Oracle backends
        "rust-sibyl" | "sibyl" => Box::new(rust_sibyl::RustSibylBackend::new(canonical_engine)?),
        "python-oracledb" | "oracledb" => Box::new(python_oracledb::PythonOracledbBackend::new(
            canonical_engine,
        )?),
        "typescript-oracledb" => Box::new(typescript_oracledb::TypescriptOracledbBackend::new(
            canonical_engine,
        )?),
        "go-godror" | "godror" => Box::new(go_godror::GoGodrorBackend::new(canonical_engine)?),
        "csharp-oracle" => Box::new(csharp_oracle::CsharpOracleBackend::new(canonical_engine)?),
        "ruby-oci8" | "oci8" => Box::new(ruby_oci8::RubyOci8Backend::new(canonical_engine)?),
        "elixir-jamdb" | "jamdb" => {
            Box::new(elixir_jamdb::ElixirJamdbBackend::new(canonical_engine)?)
        }
        // Snowflake backends
        "python-snowflake" => Box::new(python_snowflake::PythonSnowflakeBackend::new(
            canonical_engine,
        )?),
        "typescript-snowflake" => Box::new(typescript_snowflake::TypescriptSnowflakeBackend::new(
            canonical_engine,
        )?),
        "go-gosnowflake" | "gosnowflake" => {
            Box::new(go_gosnowflake::GoGosnowflakeBackend::new(canonical_engine)?)
        }
        "csharp-snowflake" => Box::new(csharp_snowflake::CsharpSnowflakeBackend::new(
            canonical_engine,
        )?),
        _ => {
            return Err(ScytheError::new(
                ErrorCode::InternalError,
                format!("unknown backend: {}", name),
            ));
        }
    };

    // Validate engine is supported by this backend.
    // The backend's own engine names are normalized too, so a backend may list
    // either canonical names or aliases in `supported_engines()`.
    if !backend
        .supported_engines()
        .iter()
        .any(|e| normalize_engine(e) == canonical_engine)
    {
        return Err(ScytheError::new(
            ErrorCode::InternalError,
            format!(
                "backend '{}' does not support engine '{}'. Supported: {:?}",
                name,
                engine,
                backend.supported_engines()
            ),
        ));
    }

    Ok(backend)
}
524
/// Normalize engine name to canonical form.
///
/// The canonical names are `postgresql`, `mysql`, `mariadb`, `sqlite`, `duckdb`,
/// `mssql`, `oracle`, `snowflake` and `redshift`. The aliases in the table below
/// map onto them. Any other input, including a canonical name, is returned
/// unchanged. Matching is exact and case-sensitive.
fn normalize_engine(engine: &str) -> &str {
    const ALIASES: &[(&str, &str)] = &[
        ("postgres", "postgresql"),
        ("pg", "postgresql"),
        ("cockroachdb", "postgresql"),
        ("crdb", "postgresql"),
        ("sqlite3", "sqlite"),
        ("sqlserver", "mssql"),
        ("tsql", "mssql"),
    ];

    ALIASES
        .iter()
        .find(|&&(alias, _)| alias == engine)
        .map_or(engine, |&(_, canonical)| canonical)
}
540
#[cfg(test)]
mod tests {
    use super::*;

    /// A nullable string parameter at the given 1-based `position`.
    fn param(name: &str, position: i64) -> AnalyzedParam {
        AnalyzedParam {
            name: name.to_string(),
            neutral_type: "string".to_string(),
            nullable: true,
            position,
        }
    }

    /// Assert that every backend in `backends` can be constructed for `engine`.
    fn assert_backends_accept(engine: &str, backends: &[&str]) {
        for &backend_name in backends {
            let result = get_backend(backend_name, engine);
            assert!(
                result.is_ok(),
                "backend '{backend_name}' should accept {engine} engine, got: {:?}",
                result.err()
            );
        }
    }

    /// Run `rewrite_optional_params` with the names in `optional` marked optional.
    fn rewrite(sql: &str, optional: &[&str], params: &[AnalyzedParam]) -> String {
        let optional: Vec<String> = optional.iter().map(|name| name.to_string()).collect();
        rewrite_optional_params(sql, &optional, params)
    }

    #[test]
    fn test_normalize_engine_cockroachdb() {
        for alias in ["cockroachdb", "crdb"] {
            assert_eq!(normalize_engine(alias), "postgresql");
        }
    }

    #[test]
    fn test_get_backend_cockroachdb_with_pg_backends() {
        // CockroachDB should work with all PostgreSQL-compatible backends
        assert_backends_accept(
            "cockroachdb",
            &[
                "rust-sqlx",
                "rust-tokio-postgres",
                "python-psycopg3",
                "python-asyncpg",
                "typescript-postgres",
                "typescript-pg",
                "go-pgx",
                "ruby-pg",
                "elixir-postgrex",
                "csharp-npgsql",
                "php-pdo",
                "php-amphp",
            ],
        );
    }

    #[test]
    fn test_get_backend_crdb_alias() {
        assert!(
            get_backend("rust-sqlx", "crdb").is_ok(),
            "rust-sqlx should accept 'crdb' engine alias"
        );
    }

    #[test]
    fn test_normalize_engine_duckdb() {
        assert_eq!(normalize_engine("duckdb"), "duckdb");
    }

    #[test]
    fn test_get_backend_duckdb_with_compatible_backends() {
        assert_backends_accept(
            "duckdb",
            &[
                "python-duckdb",
                "typescript-duckdb",
                "go-database-sql",
                "java-jdbc",
                "kotlin-jdbc",
            ],
        );
    }

    #[test]
    fn test_get_backend_duckdb_rejected_by_pg_only() {
        assert!(
            get_backend("rust-sqlx", "duckdb").is_err(),
            "rust-sqlx should reject duckdb engine"
        );
    }

    #[test]
    fn test_rewrite_simple_equality() {
        assert_eq!(
            rewrite(
                "SELECT * FROM users WHERE status = $1",
                &["status"],
                &[param("status", 1)]
            ),
            "SELECT * FROM users WHERE ($1 IS NULL OR status = $1)"
        );
    }

    #[test]
    fn test_rewrite_qualified_column() {
        assert_eq!(
            rewrite(
                "SELECT * FROM users u WHERE u.status = $1",
                &["status"],
                &[param("status", 1)]
            ),
            "SELECT * FROM users u WHERE ($1 IS NULL OR u.status = $1)"
        );
    }

    #[test]
    fn test_rewrite_multiple_optional() {
        assert_eq!(
            rewrite(
                "SELECT * FROM users WHERE status = $1 AND name = $2",
                &["status", "name"],
                &[param("status", 1), param("name", 2)]
            ),
            "SELECT * FROM users WHERE ($1 IS NULL OR status = $1) AND ($2 IS NULL OR name = $2)"
        );
    }

    #[test]
    fn test_rewrite_mixed_optional_required() {
        assert_eq!(
            rewrite(
                "SELECT * FROM users WHERE id = $1 AND status = $2",
                &["status"],
                &[param("id", 1), param("status", 2)]
            ),
            "SELECT * FROM users WHERE id = $1 AND ($2 IS NULL OR status = $2)"
        );
    }

    #[test]
    fn test_rewrite_like_operator() {
        assert_eq!(
            rewrite(
                "SELECT * FROM users WHERE name LIKE $1",
                &["name"],
                &[param("name", 1)]
            ),
            "SELECT * FROM users WHERE ($1 IS NULL OR name LIKE $1)"
        );
    }

    #[test]
    fn test_rewrite_ilike_operator() {
        assert_eq!(
            rewrite(
                "SELECT * FROM users WHERE name ILIKE $1",
                &["name"],
                &[param("name", 1)]
            ),
            "SELECT * FROM users WHERE ($1 IS NULL OR name ILIKE $1)"
        );
    }

    #[test]
    fn test_rewrite_comparison_operators() {
        assert_eq!(
            rewrite(
                "SELECT * FROM users WHERE age >= $1",
                &["age"],
                &[param("age", 1)]
            ),
            "SELECT * FROM users WHERE ($1 IS NULL OR age >= $1)"
        );
    }

    #[test]
    fn test_rewrite_less_than() {
        assert_eq!(
            rewrite(
                "SELECT * FROM users WHERE age < $1",
                &["age"],
                &[param("age", 1)]
            ),
            "SELECT * FROM users WHERE ($1 IS NULL OR age < $1)"
        );
    }

    #[test]
    fn test_no_rewrite_without_optional() {
        let sql = "SELECT * FROM users WHERE status = $1";
        assert_eq!(rewrite(sql, &[], &[param("status", 1)]), sql);
    }

    #[test]
    fn test_rewrite_not_equal() {
        assert_eq!(
            rewrite(
                "SELECT * FROM users WHERE status <> $1",
                &["status"],
                &[param("status", 1)]
            ),
            "SELECT * FROM users WHERE ($1 IS NULL OR status <> $1)"
        );
    }

    #[test]
    fn test_rewrite_does_not_match_similar_placeholder() {
        // $1 should not match $10
        let sql = "SELECT * FROM users WHERE status = $10";
        // $1 placeholder doesn't appear, so no rewrite
        assert_eq!(rewrite(sql, &["status"], &[param("status", 1)]), sql);
    }

    #[test]
    fn test_normalize_engine_mariadb() {
        assert_eq!(normalize_engine("mariadb"), "mariadb");
    }

    #[test]
    fn test_get_backend_mariadb_with_mysql_backends() {
        assert_backends_accept(
            "mariadb",
            &[
                "rust-sqlx",
                "python-aiomysql",
                "typescript-mysql2",
                "go-database-sql",
                "java-jdbc",
                "java-r2dbc",
                "kotlin-jdbc",
                "kotlin-r2dbc",
                "csharp-mysqlconnector",
                "elixir-myxql",
                "ruby-mysql2",
                "ruby-trilogy",
                "php-pdo",
                "php-amphp",
            ],
        );
    }
}