Skip to main content

scythe_codegen/backends/
mod.rs

1pub mod csharp_microsoft_sqlite;
2pub mod csharp_mysqlconnector;
3pub mod csharp_npgsql;
4pub mod elixir_ecto;
5pub mod elixir_exqlite;
6pub mod elixir_myxql;
7pub mod elixir_postgrex;
8pub mod go_database_sql;
9pub mod go_pgx;
10pub mod java_jdbc;
11pub mod kotlin_jdbc;
12pub mod php_amphp;
13pub mod php_pdo;
14pub mod python_aiomysql;
15pub mod python_aiosqlite;
16pub mod python_asyncpg;
17pub mod python_common;
18pub mod python_psycopg3;
19pub mod ruby_mysql2;
20pub mod ruby_pg;
21pub mod ruby_sqlite3;
22pub mod ruby_trilogy;
23pub mod sqlx;
24pub mod tokio_postgres;
25pub mod typescript_better_sqlite3;
26pub mod typescript_common;
27pub mod typescript_mysql2;
28pub mod typescript_pg;
29pub mod typescript_postgres;
30
31use scythe_core::analyzer::AnalyzedParam;
32use scythe_core::errors::{ErrorCode, ScytheError};
33
34use crate::backend_trait::CodegenBackend;
35
/// Normalizes a SQL snippet while keeping its line structure.
///
/// Lines whose first non-whitespace characters are `--` are dropped, the
/// surviving lines are re-joined with `\n`, and the result is trimmed of
/// surrounding whitespace and of any run of trailing `;` characters.
/// Comments that follow code on the same line are left in place.
pub(crate) fn clean_sql(sql: &str) -> String {
    let mut kept = String::with_capacity(sql.len());
    let mut first = true;

    for line in sql.lines() {
        if line.trim_start().starts_with("--") {
            continue;
        }
        if !first {
            kept.push('\n');
        }
        first = false;
        kept.push_str(line);
    }

    // Trim, drop the statement terminator(s), then trim again so whitespace
    // that sat between the statement and the `;` is removed too.
    let without_terminator = kept.trim().trim_end_matches(';');
    without_terminator.trim().to_string()
}
48
/// Like `clean_sql`, but joins the remaining lines with single spaces so the
/// statement fits on one line (for languages that embed SQL inline).
///
/// Full-line `--` comments are dropped. Because the newlines disappear, a
/// trailing `-- comment` after code would otherwise comment out every
/// following line of the statement, so trailing comments are removed as well
/// (together with the whitespace in front of them) unless the `--` sits inside
/// a quoted literal or identifier. Finally the result is trimmed and trailing
/// `;` characters are removed.
pub(crate) fn clean_sql_oneline(sql: &str) -> String {
    sql.lines()
        .filter(|line| !line.trim_start().starts_with("--"))
        .map(|line| {
            let code = strip_trailing_line_comment(line);
            // Only touch lines that actually carried a comment so that the
            // spacing of every other line stays exactly as written.
            if code.len() < line.len() {
                code.trim_end()
            } else {
                code
            }
        })
        .collect::<Vec<_>>()
        .join(" ")
        .trim()
        .trim_end_matches(';')
        .trim()
        .to_string()
}

/// Returns `line` without a trailing `-- ...` comment.
///
/// A `--` inside single quotes, double quotes or backticks is not a comment.
/// Inside single quotes a backslash skips the next character so that `\'`
/// does not end the literal. Quote state is tracked per line only; literals
/// spanning several lines and dollar-quoted strings are not recognised.
fn strip_trailing_line_comment(line: &str) -> &str {
    let mut open_quote: Option<char> = None;
    let mut iter = line.char_indices().peekable();

    while let Some((idx, c)) = iter.next() {
        match open_quote {
            Some(q) => {
                if c == '\\' && q == '\'' {
                    // Escaped character inside a string literal.
                    iter.next();
                } else if c == q {
                    open_quote = None;
                }
            }
            None => match c {
                '\'' | '"' | '`' => open_quote = Some(c),
                '-' if matches!(iter.peek(), Some(&(_, '-'))) => return &line[..idx],
                _ => {}
            },
        }
    }

    line
}
60
61/// Rewrite SQL for optional parameters.
62///
63/// For each optional param, finds `column = $N` (or `column <> $N`, `column != $N`)
64/// and rewrites to `($N IS NULL OR column = $N)`. This allows callers to pass NULL
65/// to skip a filter condition at runtime.
66///
67/// This operates on the raw SQL before any backend-specific placeholder rewriting.
68pub(crate) fn rewrite_optional_params(
69    sql: &str,
70    optional_params: &[String],
71    params: &[AnalyzedParam],
72) -> String {
73    if optional_params.is_empty() {
74        return sql.to_string();
75    }
76
77    let mut result = sql.to_string();
78
79    for opt_name in optional_params {
80        let Some(param) = params.iter().find(|p| p.name == *opt_name) else {
81            continue;
82        };
83        let placeholder = format!("${}", param.position);
84
85        // Try each comparison operator
86        for op in &[
87            ">=", "<=", "<>", "!=", ">", "<", "=", "ILIKE", "ilike", "LIKE", "like",
88        ] {
89            result = rewrite_comparison(&result, &placeholder, op);
90        }
91    }
92
93    result
94}
95
96/// Rewrite a single `column <op> $N` pattern to `($N IS NULL OR column <op> $N)`.
97/// Handles both `column <op> $N` and `$N <op> column` orderings.
98fn rewrite_comparison(sql: &str, placeholder: &str, op: &str) -> String {
99    let mut result = String::with_capacity(sql.len() + 32);
100    let chars: Vec<char> = sql.chars().collect();
101    let len = chars.len();
102    let mut i = 0;
103
104    while i < len {
105        // Try to match `identifier <op> $N` at this position
106        if let Some((start, col, end)) = try_match_col_op_ph(&chars, i, op, placeholder) {
107            // Write everything before the match start
108            if start > i {
109                // This shouldn't happen since we iterate char by char
110            }
111            result.push_str(&format!(
112                "({placeholder} IS NULL OR {col} {op} {placeholder})"
113            ));
114            i = end;
115            continue;
116        }
117
118        // Try to match `$N <op> identifier` at this position
119        if let Some((end, col)) = try_match_ph_op_col(&chars, i, op, placeholder) {
120            result.push_str(&format!(
121                "({placeholder} IS NULL OR {col} {op} {placeholder})"
122            ));
123            i = end;
124            continue;
125        }
126
127        result.push(chars[i]);
128        i += 1;
129    }
130
131    result
132}
133
134/// Try to match `identifier <ws>* <op> <ws>* placeholder` starting at position `i`.
135/// Returns `(match_start, column_name, match_end)` if found.
136fn try_match_col_op_ph(
137    chars: &[char],
138    i: usize,
139    op: &str,
140    placeholder: &str,
141) -> Option<(usize, String, usize)> {
142    // Must start with an identifier character (word char)
143    if !is_ident_char(chars[i]) {
144        return None;
145    }
146    // Must not be preceded by another ident char (whole-word boundary)
147    if i > 0 && is_ident_char(chars[i - 1]) {
148        return None;
149    }
150
151    // Read the identifier
152    let ident_start = i;
153    let mut j = i;
154    while j < chars.len() && is_ident_char(chars[j]) {
155        j += 1;
156    }
157    let ident: String = chars[ident_start..j].iter().collect();
158
159    // Skip whitespace
160    while j < chars.len() && chars[j].is_whitespace() {
161        j += 1;
162    }
163
164    // Match operator
165    let op_chars: Vec<char> = op.chars().collect();
166    if j + op_chars.len() > chars.len() {
167        return None;
168    }
169    for (k, oc) in op_chars.iter().enumerate() {
170        if chars[j + k] != *oc {
171            return None;
172        }
173    }
174    j += op_chars.len();
175
176    // Skip whitespace
177    while j < chars.len() && chars[j].is_whitespace() {
178        j += 1;
179    }
180
181    // Match placeholder
182    let ph_chars: Vec<char> = placeholder.chars().collect();
183    if j + ph_chars.len() > chars.len() {
184        return None;
185    }
186    for (k, pc) in ph_chars.iter().enumerate() {
187        if chars[j + k] != *pc {
188            return None;
189        }
190    }
191    j += ph_chars.len();
192
193    // Ensure placeholder is not followed by a digit (e.g., $1 vs $10)
194    if j < chars.len() && chars[j].is_ascii_digit() {
195        return None;
196    }
197
198    Some((i, ident, j))
199}
200
201/// Try to match `placeholder <ws>* <op> <ws>* identifier` starting at position `i`.
202/// Returns `(match_end, column_name)` if found.
203fn try_match_ph_op_col(
204    chars: &[char],
205    i: usize,
206    op: &str,
207    placeholder: &str,
208) -> Option<(usize, String)> {
209    let ph_chars: Vec<char> = placeholder.chars().collect();
210    if i + ph_chars.len() > chars.len() {
211        return None;
212    }
213
214    // Must not be preceded by $ or digit (boundary check)
215    if i > 0 && (chars[i - 1] == '$' || chars[i - 1].is_ascii_digit()) {
216        return None;
217    }
218
219    // Match placeholder
220    for (k, pc) in ph_chars.iter().enumerate() {
221        if chars[i + k] != *pc {
222            return None;
223        }
224    }
225    let mut j = i + ph_chars.len();
226
227    // Ensure placeholder is not followed by a digit
228    if j < chars.len() && chars[j].is_ascii_digit() {
229        return None;
230    }
231
232    // Skip whitespace
233    while j < chars.len() && chars[j].is_whitespace() {
234        j += 1;
235    }
236
237    // Match operator
238    let op_chars: Vec<char> = op.chars().collect();
239    if j + op_chars.len() > chars.len() {
240        return None;
241    }
242    for (k, oc) in op_chars.iter().enumerate() {
243        if chars[j + k] != *oc {
244            return None;
245        }
246    }
247    j += op_chars.len();
248
249    // Skip whitespace
250    while j < chars.len() && chars[j].is_whitespace() {
251        j += 1;
252    }
253
254    // Read the identifier
255    if j >= chars.len() || !is_ident_char(chars[j]) {
256        return None;
257    }
258    let ident_start = j;
259    while j < chars.len() && is_ident_char(chars[j]) {
260        j += 1;
261    }
262    let ident: String = chars[ident_start..j].iter().collect();
263
264    // Avoid matching "NULL" (from already-rewritten text)
265    if ident == "NULL" {
266        return None;
267    }
268
269    Some((j, ident))
270}
271
272/// Clean SQL and apply optional parameter rewriting.
273pub(crate) fn clean_sql_with_optional(
274    sql: &str,
275    optional_params: &[String],
276    params: &[AnalyzedParam],
277) -> String {
278    let cleaned = clean_sql(sql);
279    rewrite_optional_params(&cleaned, optional_params, params)
280}
281
282/// Clean SQL (oneline) and apply optional parameter rewriting.
283pub(crate) fn clean_sql_oneline_with_optional(
284    sql: &str,
285    optional_params: &[String],
286    params: &[AnalyzedParam],
287) -> String {
288    let cleaned = clean_sql_oneline(sql);
289    rewrite_optional_params(&cleaned, optional_params, params)
290}
291
/// Returns `true` for characters this module treats as part of an identifier:
/// any alphanumeric character, `_`, and `.` (so qualified names such as
/// `u.status` are read as a single identifier).
fn is_ident_char(c: char) -> bool {
    matches!(c, '_' | '.') || c.is_alphanumeric()
}
295
296/// Get a backend by name and database engine.
297///
298/// The `engine` parameter (e.g., "postgresql", "mysql", "sqlite") determines
299/// which manifest is loaded for type mappings. PG-only backends reject non-PG engines.
300pub fn get_backend(name: &str, engine: &str) -> Result<Box<dyn CodegenBackend>, ScytheError> {
301    let backend: Box<dyn CodegenBackend> = match name {
302        "rust-sqlx" | "sqlx" | "rust" => Box::new(sqlx::SqlxBackend::new(engine)?),
303        "rust-tokio-postgres" | "tokio-postgres" => {
304            Box::new(tokio_postgres::TokioPostgresBackend::new(engine)?)
305        }
306        "python-psycopg3" | "python" => {
307            Box::new(python_psycopg3::PythonPsycopg3Backend::new(engine)?)
308        }
309        "python-asyncpg" => Box::new(python_asyncpg::PythonAsyncpgBackend::new(engine)?),
310        "python-aiomysql" => Box::new(python_aiomysql::PythonAiomysqlBackend::new(engine)?),
311        "python-aiosqlite" => Box::new(python_aiosqlite::PythonAiosqliteBackend::new(engine)?),
312        "typescript-postgres" | "ts" | "typescript" => {
313            Box::new(typescript_postgres::TypescriptPostgresBackend::new(engine)?)
314        }
315        "typescript-pg" => Box::new(typescript_pg::TypescriptPgBackend::new(engine)?),
316        "typescript-mysql2" => Box::new(typescript_mysql2::TypescriptMysql2Backend::new(engine)?),
317        "typescript-better-sqlite3" => {
318            Box::new(typescript_better_sqlite3::TypescriptBetterSqlite3Backend::new(engine)?)
319        }
320        "go-database-sql" => Box::new(go_database_sql::GoDatabaseSqlBackend::new(engine)?),
321        "go-pgx" | "go" => Box::new(go_pgx::GoPgxBackend::new(engine)?),
322        "java-jdbc" | "java" => Box::new(java_jdbc::JavaJdbcBackend::new(engine)?),
323        "kotlin-jdbc" | "kotlin" | "kt" => Box::new(kotlin_jdbc::KotlinJdbcBackend::new(engine)?),
324        "csharp-npgsql" | "csharp" | "c#" | "dotnet" => {
325            Box::new(csharp_npgsql::CsharpNpgsqlBackend::new(engine)?)
326        }
327        "csharp-mysqlconnector" => Box::new(
328            csharp_mysqlconnector::CsharpMysqlConnectorBackend::new(engine)?,
329        ),
330        "csharp-microsoft-sqlite" => Box::new(
331            csharp_microsoft_sqlite::CsharpMicrosoftSqliteBackend::new(engine)?,
332        ),
333        "elixir-postgrex" | "elixir" | "ex" => {
334            Box::new(elixir_postgrex::ElixirPostgrexBackend::new(engine)?)
335        }
336        "elixir-ecto" | "ecto" => Box::new(elixir_ecto::ElixirEctoBackend::new(engine)?),
337        "elixir-myxql" => Box::new(elixir_myxql::ElixirMyxqlBackend::new(engine)?),
338        "elixir-exqlite" => Box::new(elixir_exqlite::ElixirExqliteBackend::new(engine)?),
339        "ruby-pg" | "ruby" | "rb" => Box::new(ruby_pg::RubyPgBackend::new(engine)?),
340        "ruby-mysql2" => Box::new(ruby_mysql2::RubyMysql2Backend::new(engine)?),
341        "ruby-sqlite3" => Box::new(ruby_sqlite3::RubySqlite3Backend::new(engine)?),
342        "ruby-trilogy" | "trilogy" => Box::new(ruby_trilogy::RubyTrilogyBackend::new(engine)?),
343        "php-pdo" | "php" => Box::new(php_pdo::PhpPdoBackend::new(engine)?),
344        "php-amphp" | "amphp" => Box::new(php_amphp::PhpAmphpBackend::new(engine)?),
345        _ => {
346            return Err(ScytheError::new(
347                ErrorCode::InternalError,
348                format!("unknown backend: {}", name),
349            ));
350        }
351    };
352
353    // Validate engine is supported by this backend
354    let normalized_engine = normalize_engine(engine);
355    if !backend
356        .supported_engines()
357        .iter()
358        .any(|e| normalize_engine(e) == normalized_engine)
359    {
360        return Err(ScytheError::new(
361            ErrorCode::InternalError,
362            format!(
363                "backend '{}' does not support engine '{}'. Supported: {:?}",
364                name,
365                engine,
366                backend.supported_engines()
367            ),
368        ));
369    }
370
371    Ok(backend)
372}
373
/// Maps an engine name or alias to its canonical spelling.
///
/// "postgres" and "pg" become "postgresql", "mariadb" becomes "mysql" and
/// "sqlite3" becomes "sqlite". Matching is exact (case-sensitive); any other
/// input is returned unchanged.
fn normalize_engine(engine: &str) -> &str {
    const ALIASES: [(&str, &str); 7] = [
        ("postgresql", "postgresql"),
        ("postgres", "postgresql"),
        ("pg", "postgresql"),
        ("mysql", "mysql"),
        ("mariadb", "mysql"),
        ("sqlite", "sqlite"),
        ("sqlite3", "sqlite"),
    ];

    ALIASES
        .iter()
        .find(|(alias, _)| *alias == engine)
        .map_or(engine, |(_, canonical)| *canonical)
}
383
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a nullable string parameter named `name` bound at `position`.
    fn param(name: &str, position: i64) -> AnalyzedParam {
        AnalyzedParam {
            name: name.to_string(),
            neutral_type: "string".to_string(),
            nullable: true,
            position,
        }
    }

    /// Rewrites `sql` with the given `optional` parameter names and asserts
    /// that the result equals `expected`.
    fn assert_rewrite(sql: &str, params: &[AnalyzedParam], optional: &[&str], expected: &str) {
        let optional: Vec<String> = optional.iter().map(|name| (*name).to_owned()).collect();
        assert_eq!(rewrite_optional_params(sql, &optional, params), expected);
    }

    #[test]
    fn test_rewrite_simple_equality() {
        assert_rewrite(
            "SELECT * FROM users WHERE status = $1",
            &[param("status", 1)],
            &["status"],
            "SELECT * FROM users WHERE ($1 IS NULL OR status = $1)",
        );
    }

    #[test]
    fn test_rewrite_qualified_column() {
        assert_rewrite(
            "SELECT * FROM users u WHERE u.status = $1",
            &[param("status", 1)],
            &["status"],
            "SELECT * FROM users u WHERE ($1 IS NULL OR u.status = $1)",
        );
    }

    #[test]
    fn test_rewrite_multiple_optional() {
        assert_rewrite(
            "SELECT * FROM users WHERE status = $1 AND name = $2",
            &[param("status", 1), param("name", 2)],
            &["status", "name"],
            "SELECT * FROM users WHERE ($1 IS NULL OR status = $1) AND ($2 IS NULL OR name = $2)",
        );
    }

    #[test]
    fn test_rewrite_mixed_optional_required() {
        assert_rewrite(
            "SELECT * FROM users WHERE id = $1 AND status = $2",
            &[param("id", 1), param("status", 2)],
            &["status"],
            "SELECT * FROM users WHERE id = $1 AND ($2 IS NULL OR status = $2)",
        );
    }

    #[test]
    fn test_rewrite_like_operator() {
        assert_rewrite(
            "SELECT * FROM users WHERE name LIKE $1",
            &[param("name", 1)],
            &["name"],
            "SELECT * FROM users WHERE ($1 IS NULL OR name LIKE $1)",
        );
    }

    #[test]
    fn test_rewrite_ilike_operator() {
        assert_rewrite(
            "SELECT * FROM users WHERE name ILIKE $1",
            &[param("name", 1)],
            &["name"],
            "SELECT * FROM users WHERE ($1 IS NULL OR name ILIKE $1)",
        );
    }

    #[test]
    fn test_rewrite_comparison_operators() {
        assert_rewrite(
            "SELECT * FROM users WHERE age >= $1",
            &[param("age", 1)],
            &["age"],
            "SELECT * FROM users WHERE ($1 IS NULL OR age >= $1)",
        );
    }

    #[test]
    fn test_rewrite_less_than() {
        assert_rewrite(
            "SELECT * FROM users WHERE age < $1",
            &[param("age", 1)],
            &["age"],
            "SELECT * FROM users WHERE ($1 IS NULL OR age < $1)",
        );
    }

    #[test]
    fn test_no_rewrite_without_optional() {
        let sql = "SELECT * FROM users WHERE status = $1";
        assert_rewrite(sql, &[param("status", 1)], &[], sql);
    }

    #[test]
    fn test_rewrite_not_equal() {
        assert_rewrite(
            "SELECT * FROM users WHERE status <> $1",
            &[param("status", 1)],
            &["status"],
            "SELECT * FROM users WHERE ($1 IS NULL OR status <> $1)",
        );
    }

    #[test]
    fn test_rewrite_does_not_match_similar_placeholder() {
        // The parameter is bound at $1, which does not occur in the SQL; the
        // longer placeholder $10 must not be mistaken for it, so nothing changes.
        let sql = "SELECT * FROM users WHERE status = $10";
        assert_rewrite(sql, &[param("status", 1)], &["status"], sql);
    }
}
511}