sql_middleware/translation/
mod.rs

1use std::borrow::Cow;
2
3mod parsers;
4mod scanner;
5
6use parsers::{
7    is_block_comment_end, is_block_comment_start, is_line_comment_start, matches_tag,
8    try_start_dollar_quote,
9};
10use scanner::{State, scan_digits};
11
/// Placeholder syntax that [`translate_placeholders`] rewrites SQL into.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PlaceholderStyle {
    /// Numbered dollar placeholders as written for PostgreSQL, e.g. `$1`.
    Postgres,
    /// Numbered question-mark placeholders as written for SQLite (and LibSQL/Turso), e.g. `?1`.
    Sqlite,
}
20
/// Per-call override that decides whether placeholder translation runs, relative to the
/// pool's configured default.
///
/// # Examples
/// ```rust
/// use sql_middleware::prelude::*;
///
/// let options = QueryOptions::default()
///     .with_translation(TranslationMode::ForceOn);
/// # let _ = options;
/// ```
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TranslationMode {
    /// Defer to whatever the pool is configured to do.
    PoolDefault,
    /// Always translate, even when the pool default is off.
    ForceOn,
    /// Never translate, even when the pool default is on.
    ForceOff,
}

impl TranslationMode {
    /// Collapse this mode into a concrete on/off decision.
    ///
    /// `pool_default` is the pool-level setting. It is returned unchanged for
    /// [`TranslationMode::PoolDefault`] and ignored by the two `Force*` variants.
    #[must_use]
    pub fn resolve(self, pool_default: bool) -> bool {
        match self {
            Self::ForceOn => true,
            Self::ForceOff => false,
            Self::PoolDefault => pool_default,
        }
    }
}
51
52/// Per-call options for query/execute paths.
53#[derive(Debug, Clone, Copy, PartialEq, Eq)]
54pub struct QueryOptions {
55    pub translation: TranslationMode,
56}
57
58impl Default for QueryOptions {
59    fn default() -> Self {
60        Self {
61            translation: TranslationMode::PoolDefault,
62        }
63    }
64}
65
66impl QueryOptions {
67    #[must_use]
68    pub fn with_translation(mut self, translation: TranslationMode) -> Self {
69        self.translation = translation;
70        self
71    }
72}
73
74/// Translate placeholders between Postgres-style `$N` and SQLite-style `?N`.
75///
76/// Warning: translation skips quoted strings, comments, and dollar-quoted blocks via a lightweight
77/// state machine; it may still miss edge cases in complex SQL. For dialect-specific SQL (e.g.,
78/// PL/pgSQL bodies), prefer backend-specific SQL instead of relying on translation:
79/// ```rust
80/// # use sql_middleware::prelude::*;
81/// # async fn demo(conn: &mut MiddlewarePoolConnection) -> Result<(), SqlMiddlewareDbError> {
82/// let query = match conn {
83///     MiddlewarePoolConnection::Postgres { .. } => r#"$function$
84/// BEGIN
85///     RETURN ($1 ~ $q$[\t\r\n\v\\]$q$);
86/// END;
87/// $function$"#,
88///     MiddlewarePoolConnection::Sqlite { .. } | MiddlewarePoolConnection::Turso { .. } => {
89///         include_str!("../sql/functions/sqlite/03_sp_get_scores.sql")
90///     }
91/// };
92/// # let _ = query;
93/// # Ok(())
94/// # }
95/// ```
96/// Returns a borrowed `Cow` when no changes are needed.
97#[must_use]
98pub fn translate_placeholders(sql: &str, target: PlaceholderStyle, enabled: bool) -> Cow<'_, str> {
99    if !enabled {
100        return Cow::Borrowed(sql);
101    }
102
103    let mut out: Option<String> = None;
104    let mut state = State::Normal;
105    let mut idx = 0;
106    let bytes = sql.as_bytes();
107
108    while idx < bytes.len() {
109        let b = bytes[idx];
110        let mut replaced = false;
111        match state {
112            State::Normal => match b {
113                b'\'' => state = State::SingleQuoted,
114                b'"' => state = State::DoubleQuoted,
115                _ if is_line_comment_start(bytes, idx) => state = State::LineComment,
116                _ if is_block_comment_start(bytes, idx) => state = State::BlockComment(1),
117                b'$' => {
118                    if let Some((tag, advance)) = try_start_dollar_quote(bytes, idx) {
119                        state = State::DollarQuoted(tag);
120                        idx = advance;
121                    } else if matches!(target, PlaceholderStyle::Sqlite)
122                        && let Some((digits_end, digits)) = scan_digits(bytes, idx + 1)
123                    {
124                        let buf = out.get_or_insert_with(|| sql[..idx].to_string());
125                        buf.push('?');
126                        buf.push_str(digits);
127                        idx = digits_end - 1;
128                        replaced = true;
129                    }
130                }
131                b'?' if matches!(target, PlaceholderStyle::Postgres) => {
132                    if let Some((digits_end, digits)) = scan_digits(bytes, idx + 1) {
133                        let buf = out.get_or_insert_with(|| sql[..idx].to_string());
134                        buf.push('$');
135                        buf.push_str(digits);
136                        idx = digits_end - 1;
137                        replaced = true;
138                    }
139                }
140                _ => {}
141            },
142            State::SingleQuoted => {
143                if b == b'\'' {
144                    if bytes.get(idx + 1) == Some(&b'\'') {
145                        idx += 1; // skip escaped quote
146                    } else {
147                        state = State::Normal;
148                    }
149                }
150            }
151            State::DoubleQuoted => {
152                if b == b'"' {
153                    if bytes.get(idx + 1) == Some(&b'"') {
154                        idx += 1; // skip escaped quote
155                    } else {
156                        state = State::Normal;
157                    }
158                }
159            }
160            State::LineComment => {
161                if b == b'\n' {
162                    state = State::Normal;
163                }
164            }
165            State::BlockComment(depth) => {
166                if is_block_comment_start(bytes, idx) {
167                    state = State::BlockComment(depth + 1);
168                } else if is_block_comment_end(bytes, idx) {
169                    if depth == 1 {
170                        state = State::Normal;
171                    } else {
172                        state = State::BlockComment(depth - 1);
173                    }
174                }
175            }
176            State::DollarQuoted(ref tag) => {
177                if b == b'$' && matches_tag(bytes, idx, tag) {
178                    let tag_len = tag.len();
179                    state = State::Normal;
180                    idx += tag_len;
181                }
182            }
183        }
184
185        if let Some(ref mut buf) = out
186            && !replaced
187        {
188            buf.push(b as char);
189        }
190
191        idx += 1;
192    }
193
194    match out {
195        Some(buf) => Cow::Owned(buf),
196        None => Cow::Borrowed(sql),
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an enabled translation to `target`.
    fn translate(sql: &str, target: PlaceholderStyle) -> Cow<'_, str> {
        translate_placeholders(sql, target, true)
    }

    #[test]
    fn translates_sqlite_to_postgres() {
        let translated = translate(
            "select * from t where a = ?1 and b = ?2",
            PlaceholderStyle::Postgres,
        );
        assert_eq!(translated, "select * from t where a = $1 and b = $2");
    }

    #[test]
    fn translates_postgres_to_sqlite() {
        let translated = translate("insert into t values($1, $2)", PlaceholderStyle::Sqlite);
        assert_eq!(translated, "insert into t values(?1, ?2)");
    }

    #[test]
    fn skips_inside_literals_and_comments() {
        let translated = translate(
            "select '?1', $1 -- $2\n/* ?3 */ from t where a = $1",
            PlaceholderStyle::Sqlite,
        );
        assert_eq!(
            translated,
            "select '?1', ?1 -- $2\n/* ?3 */ from t where a = ?1"
        );
    }

    #[test]
    fn skips_dollar_quoted_blocks() {
        let translated = translate(
            "$foo$ select $1 from t $foo$ where a = $1",
            PlaceholderStyle::Sqlite,
        );
        assert_eq!(translated, "$foo$ select $1 from t $foo$ where a = ?1");
    }

    #[test]
    fn respects_disabled_flag() {
        let sql = "select * from t where a = ?1";
        let untouched = translate_placeholders(sql, PlaceholderStyle::Postgres, false);
        assert!(matches!(untouched, Cow::Borrowed(_)));
        assert_eq!(untouched, sql);
    }

    #[test]
    fn translation_mode_resolution() {
        let cases = [
            (TranslationMode::ForceOn, false, true),
            (TranslationMode::ForceOff, true, false),
            (TranslationMode::PoolDefault, true, true),
            (TranslationMode::PoolDefault, false, false),
        ];
        for (mode, pool_default, expected) in cases {
            assert_eq!(
                mode.resolve(pool_default),
                expected,
                "{mode:?} with pool default {pool_default}"
            );
        }
    }
}