1pub mod csharp_microsoft_sqlite;
2pub mod csharp_mysqlconnector;
3pub mod csharp_npgsql;
4pub mod elixir_ecto;
5pub mod elixir_exqlite;
6pub mod elixir_myxql;
7pub mod elixir_postgrex;
8pub mod go_database_sql;
9pub mod go_pgx;
10pub mod java_jdbc;
11pub mod java_r2dbc;
12pub mod kotlin_exposed;
13pub mod kotlin_jdbc;
14pub mod kotlin_r2dbc;
15pub mod php_amphp;
16pub mod php_pdo;
17pub mod python_aiomysql;
18pub mod python_aiosqlite;
19pub mod python_asyncpg;
20pub mod python_common;
21pub mod python_duckdb;
22pub mod python_psycopg3;
23pub mod ruby_mysql2;
24pub mod ruby_pg;
25pub(crate) mod ruby_rbs;
26pub mod ruby_sqlite3;
27pub mod ruby_trilogy;
28pub mod sqlx;
29pub mod tokio_postgres;
30pub mod typescript_better_sqlite3;
31pub mod typescript_common;
32pub mod typescript_duckdb;
33pub mod typescript_mysql2;
34pub mod typescript_pg;
35pub mod typescript_postgres;
36
37use scythe_core::analyzer::AnalyzedParam;
38use scythe_core::errors::{ErrorCode, ScytheError};
39
40use crate::backend_trait::CodegenBackend;
41
/// Removes full-line `--` comments from `sql`, keeps the remaining lines
/// separated by newlines, and strips surrounding whitespace together with any
/// trailing semicolons.
///
/// Only lines whose first non-whitespace characters are `--` are dropped; a
/// comment that follows code on the same line is kept verbatim.
pub(crate) fn clean_sql(sql: &str) -> String {
    let mut kept: Vec<&str> = Vec::new();
    for line in sql.lines() {
        let is_comment_line = line.trim_start().starts_with("--");
        if !is_comment_line {
            kept.push(line);
        }
    }
    let joined = kept.join("\n");
    let without_semicolons = joined.trim().trim_end_matches(';');
    without_semicolons.trim().to_owned()
}
54
/// Collapses `sql` onto a single line.
///
/// Full-line `--` comments are dropped, trailing `--` comments are cut from
/// the remaining lines, the lines are joined with single spaces, and
/// surrounding whitespace plus any trailing semicolons are removed.
///
/// Trailing comments must be removed here (unlike in `clean_sql`). Once the
/// lines are joined, a `--` comment would otherwise swallow every line after
/// it: `FROM t -- note\nWHERE id = $1` would silently lose its `WHERE` clause.
pub(crate) fn clean_sql_oneline(sql: &str) -> String {
    sql.lines()
        .filter(|line| !line.trim_start().starts_with("--"))
        .map(strip_trailing_line_comment)
        .collect::<Vec<_>>()
        .join(" ")
        .trim()
        .trim_end_matches(';')
        .trim()
        .to_string()
}

/// Returns `line` without a trailing `--` comment and without the whitespace
/// before that comment.
///
/// A `--` inside a single-quoted string literal or a double-quoted identifier
/// is not a comment. Quote state is tracked per line. Dollar-quoted strings
/// and `/* */` block comments are not recognised. Lines without a comment are
/// returned unchanged.
fn strip_trailing_line_comment(line: &str) -> &str {
    let bytes = line.as_bytes();
    let mut open_quote: Option<u8> = None;
    for (idx, &byte) in bytes.iter().enumerate() {
        match open_quote {
            // A doubled quote (`''`) closes and immediately reopens the
            // literal, so simple toggling handles SQL quote escaping.
            Some(quote) if byte == quote => open_quote = None,
            Some(_) => {}
            None if byte == b'\'' || byte == b'"' => open_quote = Some(byte),
            None if byte == b'-' && bytes.get(idx + 1) == Some(&b'-') => {
                // `-` is ASCII, so `idx` is always a valid char boundary.
                return line[..idx].trim_end();
            }
            None => {}
        }
    }
    line
}
66
67pub(crate) fn rewrite_optional_params(
75 sql: &str,
76 optional_params: &[String],
77 params: &[AnalyzedParam],
78) -> String {
79 if optional_params.is_empty() {
80 return sql.to_string();
81 }
82
83 let mut result = sql.to_string();
84
85 for opt_name in optional_params {
86 let Some(param) = params.iter().find(|p| p.name == *opt_name) else {
87 continue;
88 };
89 let placeholder = format!("${}", param.position);
90
91 for op in &[
93 ">=", "<=", "<>", "!=", ">", "<", "=", "ILIKE", "ilike", "LIKE", "like",
94 ] {
95 result = rewrite_comparison(&result, &placeholder, op);
96 }
97 }
98
99 result
100}
101
102fn rewrite_comparison(sql: &str, placeholder: &str, op: &str) -> String {
105 let mut result = String::with_capacity(sql.len() + 32);
106 let chars: Vec<char> = sql.chars().collect();
107 let len = chars.len();
108 let mut i = 0;
109
110 while i < len {
111 if let Some((start, col, end)) = try_match_col_op_ph(&chars, i, op, placeholder) {
113 if start > i {
115 }
117 result.push_str(&format!(
118 "({placeholder} IS NULL OR {col} {op} {placeholder})"
119 ));
120 i = end;
121 continue;
122 }
123
124 if let Some((end, col)) = try_match_ph_op_col(&chars, i, op, placeholder) {
126 result.push_str(&format!(
127 "({placeholder} IS NULL OR {col} {op} {placeholder})"
128 ));
129 i = end;
130 continue;
131 }
132
133 result.push(chars[i]);
134 i += 1;
135 }
136
137 result
138}
139
140fn try_match_col_op_ph(
143 chars: &[char],
144 i: usize,
145 op: &str,
146 placeholder: &str,
147) -> Option<(usize, String, usize)> {
148 if !is_ident_char(chars[i]) {
150 return None;
151 }
152 if i > 0 && is_ident_char(chars[i - 1]) {
154 return None;
155 }
156
157 let ident_start = i;
159 let mut j = i;
160 while j < chars.len() && is_ident_char(chars[j]) {
161 j += 1;
162 }
163 let ident: String = chars[ident_start..j].iter().collect();
164
165 while j < chars.len() && chars[j].is_whitespace() {
167 j += 1;
168 }
169
170 let op_chars: Vec<char> = op.chars().collect();
172 if j + op_chars.len() > chars.len() {
173 return None;
174 }
175 for (k, oc) in op_chars.iter().enumerate() {
176 if chars[j + k] != *oc {
177 return None;
178 }
179 }
180 j += op_chars.len();
181
182 while j < chars.len() && chars[j].is_whitespace() {
184 j += 1;
185 }
186
187 let ph_chars: Vec<char> = placeholder.chars().collect();
189 if j + ph_chars.len() > chars.len() {
190 return None;
191 }
192 for (k, pc) in ph_chars.iter().enumerate() {
193 if chars[j + k] != *pc {
194 return None;
195 }
196 }
197 j += ph_chars.len();
198
199 if j < chars.len() && chars[j].is_ascii_digit() {
201 return None;
202 }
203
204 Some((i, ident, j))
205}
206
207fn try_match_ph_op_col(
210 chars: &[char],
211 i: usize,
212 op: &str,
213 placeholder: &str,
214) -> Option<(usize, String)> {
215 let ph_chars: Vec<char> = placeholder.chars().collect();
216 if i + ph_chars.len() > chars.len() {
217 return None;
218 }
219
220 if i > 0 && (chars[i - 1] == '$' || chars[i - 1].is_ascii_digit()) {
222 return None;
223 }
224
225 for (k, pc) in ph_chars.iter().enumerate() {
227 if chars[i + k] != *pc {
228 return None;
229 }
230 }
231 let mut j = i + ph_chars.len();
232
233 if j < chars.len() && chars[j].is_ascii_digit() {
235 return None;
236 }
237
238 while j < chars.len() && chars[j].is_whitespace() {
240 j += 1;
241 }
242
243 let op_chars: Vec<char> = op.chars().collect();
245 if j + op_chars.len() > chars.len() {
246 return None;
247 }
248 for (k, oc) in op_chars.iter().enumerate() {
249 if chars[j + k] != *oc {
250 return None;
251 }
252 }
253 j += op_chars.len();
254
255 while j < chars.len() && chars[j].is_whitespace() {
257 j += 1;
258 }
259
260 if j >= chars.len() || !is_ident_char(chars[j]) {
262 return None;
263 }
264 let ident_start = j;
265 while j < chars.len() && is_ident_char(chars[j]) {
266 j += 1;
267 }
268 let ident: String = chars[ident_start..j].iter().collect();
269
270 if ident == "NULL" {
272 return None;
273 }
274
275 Some((j, ident))
276}
277
278pub(crate) fn clean_sql_with_optional(
280 sql: &str,
281 optional_params: &[String],
282 params: &[AnalyzedParam],
283) -> String {
284 let cleaned = clean_sql(sql);
285 rewrite_optional_params(&cleaned, optional_params, params)
286}
287
288pub(crate) fn clean_sql_oneline_with_optional(
290 sql: &str,
291 optional_params: &[String],
292 params: &[AnalyzedParam],
293) -> String {
294 let cleaned = clean_sql_oneline(sql);
295 rewrite_optional_params(&cleaned, optional_params, params)
296}
297
/// Returns true for characters that may appear in a possibly qualified column
/// reference: Unicode letters and digits, `_`, and the `.` separator.
fn is_ident_char(c: char) -> bool {
    matches!(c, '_' | '.') || c.is_alphanumeric()
}
301
302pub fn get_backend(name: &str, engine: &str) -> Result<Box<dyn CodegenBackend>, ScytheError> {
307 let canonical_engine = normalize_engine(engine);
311 let backend: Box<dyn CodegenBackend> = match name {
312 "rust-sqlx" | "sqlx" | "rust" => Box::new(sqlx::SqlxBackend::new(canonical_engine)?),
313 "rust-tokio-postgres" | "tokio-postgres" => {
314 Box::new(tokio_postgres::TokioPostgresBackend::new(canonical_engine)?)
315 }
316 "python-psycopg3" | "python" => Box::new(python_psycopg3::PythonPsycopg3Backend::new(
317 canonical_engine,
318 )?),
319 "python-asyncpg" => Box::new(python_asyncpg::PythonAsyncpgBackend::new(canonical_engine)?),
320 "python-aiomysql" => Box::new(python_aiomysql::PythonAiomysqlBackend::new(
321 canonical_engine,
322 )?),
323 "python-aiosqlite" => Box::new(python_aiosqlite::PythonAiosqliteBackend::new(
324 canonical_engine,
325 )?),
326 "python-duckdb" => Box::new(python_duckdb::PythonDuckdbBackend::new(canonical_engine)?),
327 "typescript-postgres" | "ts" | "typescript" => Box::new(
328 typescript_postgres::TypescriptPostgresBackend::new(canonical_engine)?,
329 ),
330 "typescript-pg" => Box::new(typescript_pg::TypescriptPgBackend::new(canonical_engine)?),
331 "typescript-mysql2" => Box::new(typescript_mysql2::TypescriptMysql2Backend::new(
332 canonical_engine,
333 )?),
334 "typescript-better-sqlite3" => Box::new(
335 typescript_better_sqlite3::TypescriptBetterSqlite3Backend::new(canonical_engine)?,
336 ),
337 "typescript-duckdb" => Box::new(typescript_duckdb::TypescriptDuckdbBackend::new(
338 canonical_engine,
339 )?),
340 "go-database-sql" => Box::new(go_database_sql::GoDatabaseSqlBackend::new(
341 canonical_engine,
342 )?),
343 "go-pgx" | "go" => Box::new(go_pgx::GoPgxBackend::new(canonical_engine)?),
344 "java-jdbc" | "java" => Box::new(java_jdbc::JavaJdbcBackend::new(canonical_engine)?),
345 "java-r2dbc" | "r2dbc-java" => {
346 Box::new(java_r2dbc::JavaR2dbcBackend::new(canonical_engine)?)
347 }
348 "kotlin-exposed" | "exposed" => {
349 Box::new(kotlin_exposed::KotlinExposedBackend::new(canonical_engine)?)
350 }
351 "kotlin-jdbc" | "kotlin" | "kt" => {
352 Box::new(kotlin_jdbc::KotlinJdbcBackend::new(canonical_engine)?)
353 }
354 "kotlin-r2dbc" | "r2dbc-kotlin" => {
355 Box::new(kotlin_r2dbc::KotlinR2dbcBackend::new(canonical_engine)?)
356 }
357 "csharp-npgsql" | "csharp" | "c#" | "dotnet" => {
358 Box::new(csharp_npgsql::CsharpNpgsqlBackend::new(canonical_engine)?)
359 }
360 "csharp-mysqlconnector" => Box::new(
361 csharp_mysqlconnector::CsharpMysqlConnectorBackend::new(canonical_engine)?,
362 ),
363 "csharp-microsoft-sqlite" => Box::new(
364 csharp_microsoft_sqlite::CsharpMicrosoftSqliteBackend::new(canonical_engine)?,
365 ),
366 "elixir-postgrex" | "elixir" | "ex" => Box::new(
367 elixir_postgrex::ElixirPostgrexBackend::new(canonical_engine)?,
368 ),
369 "elixir-ecto" | "ecto" => Box::new(elixir_ecto::ElixirEctoBackend::new(canonical_engine)?),
370 "elixir-myxql" => Box::new(elixir_myxql::ElixirMyxqlBackend::new(canonical_engine)?),
371 "elixir-exqlite" => Box::new(elixir_exqlite::ElixirExqliteBackend::new(canonical_engine)?),
372 "ruby-pg" | "ruby" | "rb" => Box::new(ruby_pg::RubyPgBackend::new(canonical_engine)?),
373 "ruby-mysql2" => Box::new(ruby_mysql2::RubyMysql2Backend::new(canonical_engine)?),
374 "ruby-sqlite3" => Box::new(ruby_sqlite3::RubySqlite3Backend::new(canonical_engine)?),
375 "ruby-trilogy" | "trilogy" => {
376 Box::new(ruby_trilogy::RubyTrilogyBackend::new(canonical_engine)?)
377 }
378 "php-pdo" | "php" => Box::new(php_pdo::PhpPdoBackend::new(canonical_engine)?),
379 "php-amphp" | "amphp" => Box::new(php_amphp::PhpAmphpBackend::new(canonical_engine)?),
380 _ => {
381 return Err(ScytheError::new(
382 ErrorCode::InternalError,
383 format!("unknown backend: {}", name),
384 ));
385 }
386 };
387
388 if !backend
390 .supported_engines()
391 .iter()
392 .any(|e| normalize_engine(e) == canonical_engine)
393 {
394 return Err(ScytheError::new(
395 ErrorCode::InternalError,
396 format!(
397 "backend '{}' does not support engine '{}'. Supported: {:?}",
398 name,
399 engine,
400 backend.supported_engines()
401 ),
402 ));
403 }
404
405 Ok(backend)
406}
407
/// Maps engine aliases onto their canonical names.
///
/// The canonical names are `postgresql`, `mysql`, `sqlite` and `duckdb`.
/// PostgreSQL and CockroachDB aliases map to `postgresql`, and MariaDB maps
/// to `mysql`. Matching is exact and case-sensitive; unrecognised names are
/// returned unchanged.
fn normalize_engine(engine: &str) -> &str {
    const ALIASES: [(&str, &[&str]); 4] = [
        ("postgresql", &["postgresql", "postgres", "pg", "cockroachdb", "crdb"]),
        ("mysql", &["mysql", "mariadb"]),
        ("sqlite", &["sqlite", "sqlite3"]),
        ("duckdb", &["duckdb"]),
    ];

    ALIASES
        .iter()
        .find(|(_, names)| names.contains(&engine))
        .map_or(engine, |&(canonical, _)| canonical)
}
418
// Unit tests for backend lookup, engine normalisation and SQL rewriting.
// The `&params` arguments were previously mis-encoded as `¶ms` (an HTML
// `&para;` entity), which does not compile.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a nullable string parameter bound at `$position`.
    fn param(name: &str, position: i64) -> AnalyzedParam {
        AnalyzedParam {
            name: name.to_string(),
            neutral_type: "string".to_string(),
            nullable: true,
            position,
        }
    }

    #[test]
    fn test_normalize_engine_cockroachdb() {
        assert_eq!(normalize_engine("cockroachdb"), "postgresql");
        assert_eq!(normalize_engine("crdb"), "postgresql");
    }

    #[test]
    fn test_get_backend_cockroachdb_with_pg_backends() {
        // Each of these PostgreSQL backends should also accept CockroachDB.
        let pg_backends = [
            "rust-sqlx",
            "rust-tokio-postgres",
            "python-psycopg3",
            "python-asyncpg",
            "typescript-postgres",
            "typescript-pg",
            "go-pgx",
            "ruby-pg",
            "elixir-postgrex",
            "csharp-npgsql",
            "php-pdo",
            "php-amphp",
        ];
        for backend_name in &pg_backends {
            let result = get_backend(backend_name, "cockroachdb");
            assert!(
                result.is_ok(),
                "backend '{}' should accept cockroachdb engine, got: {:?}",
                backend_name,
                result.err()
            );
        }
    }

    #[test]
    fn test_get_backend_crdb_alias() {
        let result = get_backend("rust-sqlx", "crdb");
        assert!(
            result.is_ok(),
            "rust-sqlx should accept 'crdb' engine alias"
        );
    }

    #[test]
    fn test_normalize_engine_duckdb() {
        assert_eq!(normalize_engine("duckdb"), "duckdb");
    }

    #[test]
    fn test_get_backend_duckdb_with_compatible_backends() {
        let duckdb_backends = [
            "python-duckdb",
            "typescript-duckdb",
            "go-database-sql",
            "java-jdbc",
            "kotlin-jdbc",
        ];
        for backend_name in &duckdb_backends {
            let result = get_backend(backend_name, "duckdb");
            assert!(
                result.is_ok(),
                "backend '{}' should accept duckdb engine, got: {:?}",
                backend_name,
                result.err()
            );
        }
    }

    #[test]
    fn test_get_backend_duckdb_rejected_by_pg_only() {
        let result = get_backend("rust-sqlx", "duckdb");
        assert!(result.is_err(), "rust-sqlx should reject duckdb engine");
    }

    #[test]
    fn test_clean_sql_drops_comment_lines_and_trailing_semicolon() {
        let sql = "-- name: GetUser :one\nSELECT *\nFROM users\nWHERE id = $1;\n";
        assert_eq!(clean_sql(sql), "SELECT *\nFROM users\nWHERE id = $1");
        assert_eq!(clean_sql_oneline(sql), "SELECT * FROM users WHERE id = $1");
    }

    #[test]
    fn test_clean_sql_with_optional_rewrites_after_cleaning() {
        let sql = "-- name: ListUsers :many\nSELECT * FROM users\nWHERE status = $1;";
        let params = vec![param("status", 1)];
        assert_eq!(
            clean_sql_with_optional(sql, &["status".to_string()], &params),
            "SELECT * FROM users\nWHERE ($1 IS NULL OR status = $1)"
        );
    }

    #[test]
    fn test_rewrite_simple_equality() {
        let sql = "SELECT * FROM users WHERE status = $1";
        let params = vec![param("status", 1)];
        let result = rewrite_optional_params(sql, &["status".to_string()], &params);
        assert_eq!(
            result,
            "SELECT * FROM users WHERE ($1 IS NULL OR status = $1)"
        );
    }

    #[test]
    fn test_rewrite_qualified_column() {
        let sql = "SELECT * FROM users u WHERE u.status = $1";
        let params = vec![param("status", 1)];
        let result = rewrite_optional_params(sql, &["status".to_string()], &params);
        assert_eq!(
            result,
            "SELECT * FROM users u WHERE ($1 IS NULL OR u.status = $1)"
        );
    }

    #[test]
    fn test_rewrite_multiple_optional() {
        let sql = "SELECT * FROM users WHERE status = $1 AND name = $2";
        let params = vec![param("status", 1), param("name", 2)];
        let result =
            rewrite_optional_params(sql, &["status".to_string(), "name".to_string()], &params);
        assert_eq!(
            result,
            "SELECT * FROM users WHERE ($1 IS NULL OR status = $1) AND ($2 IS NULL OR name = $2)"
        );
    }

    #[test]
    fn test_rewrite_mixed_optional_required() {
        let sql = "SELECT * FROM users WHERE id = $1 AND status = $2";
        let params = vec![param("id", 1), param("status", 2)];
        let result = rewrite_optional_params(sql, &["status".to_string()], &params);
        assert_eq!(
            result,
            "SELECT * FROM users WHERE id = $1 AND ($2 IS NULL OR status = $2)"
        );
    }

    #[test]
    fn test_rewrite_like_operator() {
        let sql = "SELECT * FROM users WHERE name LIKE $1";
        let params = vec![param("name", 1)];
        let result = rewrite_optional_params(sql, &["name".to_string()], &params);
        assert_eq!(
            result,
            "SELECT * FROM users WHERE ($1 IS NULL OR name LIKE $1)"
        );
    }

    #[test]
    fn test_rewrite_ilike_operator() {
        let sql = "SELECT * FROM users WHERE name ILIKE $1";
        let params = vec![param("name", 1)];
        let result = rewrite_optional_params(sql, &["name".to_string()], &params);
        assert_eq!(
            result,
            "SELECT * FROM users WHERE ($1 IS NULL OR name ILIKE $1)"
        );
    }

    #[test]
    fn test_rewrite_comparison_operators() {
        let sql = "SELECT * FROM users WHERE age >= $1";
        let params = vec![param("age", 1)];
        let result = rewrite_optional_params(sql, &["age".to_string()], &params);
        assert_eq!(
            result,
            "SELECT * FROM users WHERE ($1 IS NULL OR age >= $1)"
        );
    }

    #[test]
    fn test_rewrite_less_than() {
        let sql = "SELECT * FROM users WHERE age < $1";
        let params = vec![param("age", 1)];
        let result = rewrite_optional_params(sql, &["age".to_string()], &params);
        assert_eq!(result, "SELECT * FROM users WHERE ($1 IS NULL OR age < $1)");
    }

    #[test]
    fn test_no_rewrite_without_optional() {
        let sql = "SELECT * FROM users WHERE status = $1";
        let params = vec![param("status", 1)];
        let result = rewrite_optional_params(sql, &[], &params);
        assert_eq!(result, sql);
    }

    #[test]
    fn test_rewrite_not_equal() {
        let sql = "SELECT * FROM users WHERE status <> $1";
        let params = vec![param("status", 1)];
        let result = rewrite_optional_params(sql, &["status".to_string()], &params);
        assert_eq!(
            result,
            "SELECT * FROM users WHERE ($1 IS NULL OR status <> $1)"
        );
    }

    #[test]
    fn test_rewrite_does_not_match_similar_placeholder() {
        // `$1` must not match the prefix of `$10`.
        let sql = "SELECT * FROM users WHERE status = $10";
        let params = vec![param("status", 1)];
        let result = rewrite_optional_params(sql, &["status".to_string()], &params);
        assert_eq!(result, sql);
    }
}