1pub mod csharp_microsoft_sqlite;
2pub mod csharp_mysqlconnector;
3pub mod csharp_npgsql;
4pub mod elixir_ecto;
5pub mod elixir_exqlite;
6pub mod elixir_myxql;
7pub mod elixir_postgrex;
8pub mod go_database_sql;
9pub mod go_pgx;
10pub mod java_jdbc;
11pub mod kotlin_jdbc;
12pub mod php_amphp;
13pub mod php_pdo;
14pub mod python_aiomysql;
15pub mod python_aiosqlite;
16pub mod python_asyncpg;
17pub mod python_common;
18pub mod python_psycopg3;
19pub mod ruby_mysql2;
20pub mod ruby_pg;
21pub mod ruby_sqlite3;
22pub mod ruby_trilogy;
23pub mod sqlx;
24pub mod tokio_postgres;
25pub mod typescript_better_sqlite3;
26pub mod typescript_common;
27pub mod typescript_mysql2;
28pub mod typescript_pg;
29pub mod typescript_postgres;
30
31use scythe_core::analyzer::AnalyzedParam;
32use scythe_core::errors::{ErrorCode, ScytheError};
33
34use crate::backend_trait::CodegenBackend;
35
/// Strips SQL line comments and the trailing statement terminator from `sql`.
///
/// Lines whose first non-whitespace characters are `--` are dropped; the
/// remaining lines are re-joined with `\n`. Surrounding whitespace and any
/// run of trailing `;` characters are then removed.
pub(crate) fn clean_sql(sql: &str) -> String {
    let mut kept = String::with_capacity(sql.len());
    let mut first = true;

    for line in sql.lines() {
        if line.trim_start().starts_with("--") {
            continue;
        }
        // A separator goes between every pair of kept lines, empty ones included.
        if !first {
            kept.push('\n');
        }
        first = false;
        kept.push_str(line);
    }

    let without_terminator = kept.trim().trim_end_matches(';');
    without_terminator.trim().to_string()
}
48
/// Same clean-up as the multi-line variant, but joins the kept lines with a
/// single space so the statement fits on one line.
///
/// Lines whose first non-whitespace characters are `--` are dropped. The
/// joined text is trimmed, stripped of trailing `;` characters, and trimmed
/// again. Whitespace inside the individual lines is left as written.
pub(crate) fn clean_sql_oneline(sql: &str) -> String {
    let mut kept = String::with_capacity(sql.len());
    let mut first = true;

    for line in sql.lines() {
        if line.trim_start().starts_with("--") {
            continue;
        }
        // A separator goes between every pair of kept lines, empty ones included.
        if !first {
            kept.push(' ');
        }
        first = false;
        kept.push_str(line);
    }

    let without_terminator = kept.trim().trim_end_matches(';');
    without_terminator.trim().to_string()
}
60
61pub(crate) fn rewrite_optional_params(
69 sql: &str,
70 optional_params: &[String],
71 params: &[AnalyzedParam],
72) -> String {
73 if optional_params.is_empty() {
74 return sql.to_string();
75 }
76
77 let mut result = sql.to_string();
78
79 for opt_name in optional_params {
80 let Some(param) = params.iter().find(|p| p.name == *opt_name) else {
81 continue;
82 };
83 let placeholder = format!("${}", param.position);
84
85 for op in &[
87 ">=", "<=", "<>", "!=", ">", "<", "=", "ILIKE", "ilike", "LIKE", "like",
88 ] {
89 result = rewrite_comparison(&result, &placeholder, op);
90 }
91 }
92
93 result
94}
95
96fn rewrite_comparison(sql: &str, placeholder: &str, op: &str) -> String {
99 let mut result = String::with_capacity(sql.len() + 32);
100 let chars: Vec<char> = sql.chars().collect();
101 let len = chars.len();
102 let mut i = 0;
103
104 while i < len {
105 if let Some((start, col, end)) = try_match_col_op_ph(&chars, i, op, placeholder) {
107 if start > i {
109 }
111 result.push_str(&format!(
112 "({placeholder} IS NULL OR {col} {op} {placeholder})"
113 ));
114 i = end;
115 continue;
116 }
117
118 if let Some((end, col)) = try_match_ph_op_col(&chars, i, op, placeholder) {
120 result.push_str(&format!(
121 "({placeholder} IS NULL OR {col} {op} {placeholder})"
122 ));
123 i = end;
124 continue;
125 }
126
127 result.push(chars[i]);
128 i += 1;
129 }
130
131 result
132}
133
134fn try_match_col_op_ph(
137 chars: &[char],
138 i: usize,
139 op: &str,
140 placeholder: &str,
141) -> Option<(usize, String, usize)> {
142 if !is_ident_char(chars[i]) {
144 return None;
145 }
146 if i > 0 && is_ident_char(chars[i - 1]) {
148 return None;
149 }
150
151 let ident_start = i;
153 let mut j = i;
154 while j < chars.len() && is_ident_char(chars[j]) {
155 j += 1;
156 }
157 let ident: String = chars[ident_start..j].iter().collect();
158
159 while j < chars.len() && chars[j].is_whitespace() {
161 j += 1;
162 }
163
164 let op_chars: Vec<char> = op.chars().collect();
166 if j + op_chars.len() > chars.len() {
167 return None;
168 }
169 for (k, oc) in op_chars.iter().enumerate() {
170 if chars[j + k] != *oc {
171 return None;
172 }
173 }
174 j += op_chars.len();
175
176 while j < chars.len() && chars[j].is_whitespace() {
178 j += 1;
179 }
180
181 let ph_chars: Vec<char> = placeholder.chars().collect();
183 if j + ph_chars.len() > chars.len() {
184 return None;
185 }
186 for (k, pc) in ph_chars.iter().enumerate() {
187 if chars[j + k] != *pc {
188 return None;
189 }
190 }
191 j += ph_chars.len();
192
193 if j < chars.len() && chars[j].is_ascii_digit() {
195 return None;
196 }
197
198 Some((i, ident, j))
199}
200
201fn try_match_ph_op_col(
204 chars: &[char],
205 i: usize,
206 op: &str,
207 placeholder: &str,
208) -> Option<(usize, String)> {
209 let ph_chars: Vec<char> = placeholder.chars().collect();
210 if i + ph_chars.len() > chars.len() {
211 return None;
212 }
213
214 if i > 0 && (chars[i - 1] == '$' || chars[i - 1].is_ascii_digit()) {
216 return None;
217 }
218
219 for (k, pc) in ph_chars.iter().enumerate() {
221 if chars[i + k] != *pc {
222 return None;
223 }
224 }
225 let mut j = i + ph_chars.len();
226
227 if j < chars.len() && chars[j].is_ascii_digit() {
229 return None;
230 }
231
232 while j < chars.len() && chars[j].is_whitespace() {
234 j += 1;
235 }
236
237 let op_chars: Vec<char> = op.chars().collect();
239 if j + op_chars.len() > chars.len() {
240 return None;
241 }
242 for (k, oc) in op_chars.iter().enumerate() {
243 if chars[j + k] != *oc {
244 return None;
245 }
246 }
247 j += op_chars.len();
248
249 while j < chars.len() && chars[j].is_whitespace() {
251 j += 1;
252 }
253
254 if j >= chars.len() || !is_ident_char(chars[j]) {
256 return None;
257 }
258 let ident_start = j;
259 while j < chars.len() && is_ident_char(chars[j]) {
260 j += 1;
261 }
262 let ident: String = chars[ident_start..j].iter().collect();
263
264 if ident == "NULL" {
266 return None;
267 }
268
269 Some((j, ident))
270}
271
272pub(crate) fn clean_sql_with_optional(
274 sql: &str,
275 optional_params: &[String],
276 params: &[AnalyzedParam],
277) -> String {
278 let cleaned = clean_sql(sql);
279 rewrite_optional_params(&cleaned, optional_params, params)
280}
281
282pub(crate) fn clean_sql_oneline_with_optional(
284 sql: &str,
285 optional_params: &[String],
286 params: &[AnalyzedParam],
287) -> String {
288 let cleaned = clean_sql_oneline(sql);
289 rewrite_optional_params(&cleaned, optional_params, params)
290}
291
/// Reports whether `c` may appear in a (possibly dotted) identifier:
/// any alphanumeric character, `_`, or `.`.
fn is_ident_char(c: char) -> bool {
    matches!(c, '_' | '.') || c.is_alphanumeric()
}
295
296pub fn get_backend(name: &str, engine: &str) -> Result<Box<dyn CodegenBackend>, ScytheError> {
301 let backend: Box<dyn CodegenBackend> = match name {
302 "rust-sqlx" | "sqlx" | "rust" => Box::new(sqlx::SqlxBackend::new(engine)?),
303 "rust-tokio-postgres" | "tokio-postgres" => {
304 Box::new(tokio_postgres::TokioPostgresBackend::new(engine)?)
305 }
306 "python-psycopg3" | "python" => {
307 Box::new(python_psycopg3::PythonPsycopg3Backend::new(engine)?)
308 }
309 "python-asyncpg" => Box::new(python_asyncpg::PythonAsyncpgBackend::new(engine)?),
310 "python-aiomysql" => Box::new(python_aiomysql::PythonAiomysqlBackend::new(engine)?),
311 "python-aiosqlite" => Box::new(python_aiosqlite::PythonAiosqliteBackend::new(engine)?),
312 "typescript-postgres" | "ts" | "typescript" => {
313 Box::new(typescript_postgres::TypescriptPostgresBackend::new(engine)?)
314 }
315 "typescript-pg" => Box::new(typescript_pg::TypescriptPgBackend::new(engine)?),
316 "typescript-mysql2" => Box::new(typescript_mysql2::TypescriptMysql2Backend::new(engine)?),
317 "typescript-better-sqlite3" => {
318 Box::new(typescript_better_sqlite3::TypescriptBetterSqlite3Backend::new(engine)?)
319 }
320 "go-database-sql" => Box::new(go_database_sql::GoDatabaseSqlBackend::new(engine)?),
321 "go-pgx" | "go" => Box::new(go_pgx::GoPgxBackend::new(engine)?),
322 "java-jdbc" | "java" => Box::new(java_jdbc::JavaJdbcBackend::new(engine)?),
323 "kotlin-jdbc" | "kotlin" | "kt" => Box::new(kotlin_jdbc::KotlinJdbcBackend::new(engine)?),
324 "csharp-npgsql" | "csharp" | "c#" | "dotnet" => {
325 Box::new(csharp_npgsql::CsharpNpgsqlBackend::new(engine)?)
326 }
327 "csharp-mysqlconnector" => Box::new(
328 csharp_mysqlconnector::CsharpMysqlConnectorBackend::new(engine)?,
329 ),
330 "csharp-microsoft-sqlite" => Box::new(
331 csharp_microsoft_sqlite::CsharpMicrosoftSqliteBackend::new(engine)?,
332 ),
333 "elixir-postgrex" | "elixir" | "ex" => {
334 Box::new(elixir_postgrex::ElixirPostgrexBackend::new(engine)?)
335 }
336 "elixir-ecto" | "ecto" => Box::new(elixir_ecto::ElixirEctoBackend::new(engine)?),
337 "elixir-myxql" => Box::new(elixir_myxql::ElixirMyxqlBackend::new(engine)?),
338 "elixir-exqlite" => Box::new(elixir_exqlite::ElixirExqliteBackend::new(engine)?),
339 "ruby-pg" | "ruby" | "rb" => Box::new(ruby_pg::RubyPgBackend::new(engine)?),
340 "ruby-mysql2" => Box::new(ruby_mysql2::RubyMysql2Backend::new(engine)?),
341 "ruby-sqlite3" => Box::new(ruby_sqlite3::RubySqlite3Backend::new(engine)?),
342 "ruby-trilogy" | "trilogy" => Box::new(ruby_trilogy::RubyTrilogyBackend::new(engine)?),
343 "php-pdo" | "php" => Box::new(php_pdo::PhpPdoBackend::new(engine)?),
344 "php-amphp" | "amphp" => Box::new(php_amphp::PhpAmphpBackend::new(engine)?),
345 _ => {
346 return Err(ScytheError::new(
347 ErrorCode::InternalError,
348 format!("unknown backend: {}", name),
349 ));
350 }
351 };
352
353 let normalized_engine = normalize_engine(engine);
355 if !backend
356 .supported_engines()
357 .iter()
358 .any(|e| normalize_engine(e) == normalized_engine)
359 {
360 return Err(ScytheError::new(
361 ErrorCode::InternalError,
362 format!(
363 "backend '{}' does not support engine '{}'. Supported: {:?}",
364 name,
365 engine,
366 backend.supported_engines()
367 ),
368 ));
369 }
370
371 Ok(backend)
372}
373
/// Maps an engine alias to its canonical name.
///
/// `postgres` and `pg` become `postgresql`, `mariadb` becomes `mysql`, and
/// `sqlite3` becomes `sqlite`. Canonical names map to themselves, and any
/// other input is returned unchanged. Matching is case-sensitive.
fn normalize_engine(engine: &str) -> &str {
    if matches!(engine, "postgresql" | "postgres" | "pg") {
        "postgresql"
    } else if matches!(engine, "mysql" | "mariadb") {
        "mysql"
    } else if matches!(engine, "sqlite" | "sqlite3") {
        "sqlite"
    } else {
        engine
    }
}
383
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a nullable string parameter with the given name and position.
    fn param(name: &str, position: i64) -> AnalyzedParam {
        AnalyzedParam {
            name: name.to_string(),
            neutral_type: "string".to_string(),
            nullable: true,
            position,
        }
    }

    #[test]
    fn test_rewrite_simple_equality() {
        let sql = "SELECT * FROM users WHERE status = $1";
        let params = vec![param("status", 1)];
        let result = rewrite_optional_params(sql, &["status".to_string()], &params);
        assert_eq!(
            result,
            "SELECT * FROM users WHERE ($1 IS NULL OR status = $1)"
        );
    }

    #[test]
    fn test_rewrite_qualified_column() {
        let sql = "SELECT * FROM users u WHERE u.status = $1";
        let params = vec![param("status", 1)];
        let result = rewrite_optional_params(sql, &["status".to_string()], &params);
        assert_eq!(
            result,
            "SELECT * FROM users u WHERE ($1 IS NULL OR u.status = $1)"
        );
    }

    #[test]
    fn test_rewrite_multiple_optional() {
        let sql = "SELECT * FROM users WHERE status = $1 AND name = $2";
        let params = vec![param("status", 1), param("name", 2)];
        let result =
            rewrite_optional_params(sql, &["status".to_string(), "name".to_string()], &params);
        assert_eq!(
            result,
            "SELECT * FROM users WHERE ($1 IS NULL OR status = $1) AND ($2 IS NULL OR name = $2)"
        );
    }

    #[test]
    fn test_rewrite_mixed_optional_required() {
        let sql = "SELECT * FROM users WHERE id = $1 AND status = $2";
        let params = vec![param("id", 1), param("status", 2)];
        let result = rewrite_optional_params(sql, &["status".to_string()], &params);
        assert_eq!(
            result,
            "SELECT * FROM users WHERE id = $1 AND ($2 IS NULL OR status = $2)"
        );
    }

    #[test]
    fn test_rewrite_like_operator() {
        let sql = "SELECT * FROM users WHERE name LIKE $1";
        let params = vec![param("name", 1)];
        let result = rewrite_optional_params(sql, &["name".to_string()], &params);
        assert_eq!(
            result,
            "SELECT * FROM users WHERE ($1 IS NULL OR name LIKE $1)"
        );
    }

    #[test]
    fn test_rewrite_ilike_operator() {
        let sql = "SELECT * FROM users WHERE name ILIKE $1";
        let params = vec![param("name", 1)];
        let result = rewrite_optional_params(sql, &["name".to_string()], &params);
        assert_eq!(
            result,
            "SELECT * FROM users WHERE ($1 IS NULL OR name ILIKE $1)"
        );
    }

    #[test]
    fn test_rewrite_comparison_operators() {
        let sql = "SELECT * FROM users WHERE age >= $1";
        let params = vec![param("age", 1)];
        let result = rewrite_optional_params(sql, &["age".to_string()], &params);
        assert_eq!(
            result,
            "SELECT * FROM users WHERE ($1 IS NULL OR age >= $1)"
        );
    }

    #[test]
    fn test_rewrite_less_than() {
        let sql = "SELECT * FROM users WHERE age < $1";
        let params = vec![param("age", 1)];
        let result = rewrite_optional_params(sql, &["age".to_string()], &params);
        assert_eq!(result, "SELECT * FROM users WHERE ($1 IS NULL OR age < $1)");
    }

    #[test]
    fn test_no_rewrite_without_optional() {
        let sql = "SELECT * FROM users WHERE status = $1";
        let params = vec![param("status", 1)];
        let result = rewrite_optional_params(sql, &[], &params);
        assert_eq!(result, sql);
    }

    #[test]
    fn test_rewrite_not_equal() {
        let sql = "SELECT * FROM users WHERE status <> $1";
        let params = vec![param("status", 1)];
        let result = rewrite_optional_params(sql, &["status".to_string()], &params);
        assert_eq!(
            result,
            "SELECT * FROM users WHERE ($1 IS NULL OR status <> $1)"
        );
    }

    #[test]
    fn test_rewrite_does_not_match_similar_placeholder() {
        // `$1` is only a prefix of `$10`, so the comparison must be left alone.
        let sql = "SELECT * FROM users WHERE status = $10";
        let params = vec![param("status", 1)];
        let result = rewrite_optional_params(sql, &["status".to_string()], &params);
        assert_eq!(result, sql);
    }
}