// sql_middleware/translation/mod.rs

1use std::borrow::Cow;
2
3mod parsers;
4mod scanner;
5
6use parsers::{
7    is_block_comment_end, is_block_comment_start, is_line_comment_start, matches_tag,
8    try_start_dollar_quote,
9};
10use scanner::{State, scan_digits};
11
/// Placeholder syntax that translation should produce.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PlaceholderStyle {
    /// Numbered dollar placeholders as used by PostgreSQL, e.g. `$1`.
    Postgres,
    /// Numbered question-mark placeholders as used by SQLite (and Turso), e.g. `?1`.
    Sqlite,
}
20
/// How to resolve translation for a call relative to the pool default.
///
/// Like [`PrepareMode`], this type implements `Default`; the default is
/// [`TranslationMode::PoolDefault`].
///
/// # Examples
/// ```rust
/// use sql_middleware::prelude::*;
///
/// let options = QueryOptions::default()
///     .with_translation(TranslationMode::ForceOn);
/// # let _ = options;
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TranslationMode {
    /// Follow the pool's default setting.
    #[default]
    PoolDefault,
    /// Force translation on, regardless of pool default.
    ForceOn,
    /// Force translation off, regardless of pool default.
    ForceOff,
}

impl TranslationMode {
    /// Resolve this mode to a concrete on/off translation decision.
    ///
    /// `pool_default` is the pool-level translation setting. It is returned unchanged for
    /// [`TranslationMode::PoolDefault`] and ignored by `ForceOn` (always `true`) and
    /// `ForceOff` (always `false`).
    #[must_use]
    pub fn resolve(self, pool_default: bool) -> bool {
        match self {
            TranslationMode::PoolDefault => pool_default,
            TranslationMode::ForceOn => true,
            TranslationMode::ForceOff => false,
        }
    }
}
51
/// Whether a call prepares its statement before executing it.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum PrepareMode {
    /// Run the statement as-is, with no prepare step. This is the default.
    #[default]
    Direct,
    /// Prepare the statement first, then execute it.
    Prepared,
}
61
62/// Per-call options for query/execute paths.
63#[derive(Debug, Clone, Copy, PartialEq, Eq)]
64pub struct QueryOptions {
65    pub translation: TranslationMode,
66    pub prepare: PrepareMode,
67}
68
69impl Default for QueryOptions {
70    fn default() -> Self {
71        Self {
72            translation: TranslationMode::PoolDefault,
73            prepare: PrepareMode::default(),
74        }
75    }
76}
77
78impl QueryOptions {
79    #[must_use]
80    pub fn with_translation(mut self, translation: TranslationMode) -> Self {
81        self.translation = translation;
82        self
83    }
84
85    #[must_use]
86    pub fn with_prepare(mut self, prepare: PrepareMode) -> Self {
87        self.prepare = prepare;
88        self
89    }
90}
91
92/// Translate placeholders between Postgres-style `$N` and SQLite-style `?N`.
93///
94/// Warning: translation skips quoted strings, comments, and dollar-quoted blocks via a lightweight
95/// state machine; it may still miss edge cases in complex SQL. For dialect-specific SQL (e.g.,
96/// PL/pgSQL bodies), prefer backend-specific SQL instead of relying on translation:
97/// ```rust
98/// # use sql_middleware::prelude::*;
99/// # async fn demo(conn: &mut MiddlewarePoolConnection) -> Result<(), SqlMiddlewareDbError> {
100/// let query = match conn {
101///     MiddlewarePoolConnection::Postgres { .. } => r#"$function$
102/// BEGIN
103///     RETURN ($1 ~ $q$[\t\r\n\v\\]$q$);
104/// END;
105/// $function$"#,
106///     MiddlewarePoolConnection::Sqlite { .. } | MiddlewarePoolConnection::Turso { .. } => {
107///         include_str!("../sql/functions/sqlite/03_sp_get_scores.sql")
108///     }
109/// };
110/// # let _ = query;
111/// # Ok(())
112/// # }
113/// ```
114/// Returns a borrowed `Cow` when no changes are needed.
115#[must_use]
116pub fn translate_placeholders(sql: &str, target: PlaceholderStyle, enabled: bool) -> Cow<'_, str> {
117    if !enabled {
118        return Cow::Borrowed(sql);
119    }
120
121    let mut out: Option<String> = None;
122    let mut state = State::Normal;
123    let mut idx = 0;
124    let bytes = sql.as_bytes();
125
126    while idx < bytes.len() {
127        let b = bytes[idx];
128        let mut replaced = false;
129        match state {
130            State::Normal => match b {
131                b'\'' => state = State::SingleQuoted,
132                b'"' => state = State::DoubleQuoted,
133                _ if is_line_comment_start(bytes, idx) => state = State::LineComment,
134                _ if is_block_comment_start(bytes, idx) => state = State::BlockComment(1),
135                b'$' => {
136                    if let Some((tag, advance)) = try_start_dollar_quote(bytes, idx) {
137                        state = State::DollarQuoted(tag);
138                        idx = advance;
139                    } else if matches!(target, PlaceholderStyle::Sqlite)
140                        && let Some((digits_end, digits)) = scan_digits(bytes, idx + 1)
141                    {
142                        let buf = out.get_or_insert_with(|| sql[..idx].to_string());
143                        buf.push('?');
144                        buf.push_str(digits);
145                        idx = digits_end - 1;
146                        replaced = true;
147                    }
148                }
149                b'?' if matches!(target, PlaceholderStyle::Postgres) => {
150                    if let Some((digits_end, digits)) = scan_digits(bytes, idx + 1) {
151                        let buf = out.get_or_insert_with(|| sql[..idx].to_string());
152                        buf.push('$');
153                        buf.push_str(digits);
154                        idx = digits_end - 1;
155                        replaced = true;
156                    }
157                }
158                _ => {}
159            },
160            State::SingleQuoted => {
161                if b == b'\'' {
162                    if bytes.get(idx + 1) == Some(&b'\'') {
163                        idx += 1; // skip escaped quote
164                    } else {
165                        state = State::Normal;
166                    }
167                }
168            }
169            State::DoubleQuoted => {
170                if b == b'"' {
171                    if bytes.get(idx + 1) == Some(&b'"') {
172                        idx += 1; // skip escaped quote
173                    } else {
174                        state = State::Normal;
175                    }
176                }
177            }
178            State::LineComment => {
179                if b == b'\n' {
180                    state = State::Normal;
181                }
182            }
183            State::BlockComment(depth) => {
184                if is_block_comment_start(bytes, idx) {
185                    state = State::BlockComment(depth + 1);
186                } else if is_block_comment_end(bytes, idx) {
187                    if depth == 1 {
188                        state = State::Normal;
189                    } else {
190                        state = State::BlockComment(depth - 1);
191                    }
192                }
193            }
194            State::DollarQuoted(ref tag) => {
195                if b == b'$' && matches_tag(bytes, idx, tag) {
196                    let tag_len = tag.len();
197                    state = State::Normal;
198                    idx += tag_len;
199                }
200            }
201        }
202
203        if let Some(ref mut buf) = out
204            && !replaced
205        {
206            buf.push(b as char);
207        }
208
209        idx += 1;
210    }
211
212    match out {
213        Some(buf) => Cow::Owned(buf),
214        None => Cow::Borrowed(sql),
215    }
216}
217
218#[cfg(test)]
219mod tests {
220    use super::*;
221
222    #[test]
223    fn translates_sqlite_to_postgres() {
224        let sql = "select * from t where a = ?1 and b = ?2";
225        let res = translate_placeholders(sql, PlaceholderStyle::Postgres, true);
226        assert_eq!(res, "select * from t where a = $1 and b = $2");
227    }
228
229    #[test]
230    fn translates_postgres_to_sqlite() {
231        let sql = "insert into t values($1, $2)";
232        let res = translate_placeholders(sql, PlaceholderStyle::Sqlite, true);
233        assert_eq!(res, "insert into t values(?1, ?2)");
234    }
235
236    #[test]
237    fn skips_inside_literals_and_comments() {
238        let sql = "select '?1', $1 -- $2\n/* ?3 */ from t where a = $1";
239        let res = translate_placeholders(sql, PlaceholderStyle::Sqlite, true);
240        assert_eq!(res, "select '?1', ?1 -- $2\n/* ?3 */ from t where a = ?1");
241    }
242
243    #[test]
244    fn skips_dollar_quoted_blocks() {
245        let sql = "$foo$ select $1 from t $foo$ where a = $1";
246        let res = translate_placeholders(sql, PlaceholderStyle::Sqlite, true);
247        assert_eq!(res, "$foo$ select $1 from t $foo$ where a = ?1");
248    }
249
250    #[test]
251    fn respects_disabled_flag() {
252        let sql = "select * from t where a = ?1";
253        let res = translate_placeholders(sql, PlaceholderStyle::Postgres, false);
254        assert!(matches!(res, Cow::Borrowed(_)));
255        assert_eq!(res, sql);
256    }
257
258    #[test]
259    fn translation_mode_resolution() {
260        assert!(TranslationMode::ForceOn.resolve(false));
261        assert!(!TranslationMode::ForceOff.resolve(true));
262        assert!(TranslationMode::PoolDefault.resolve(true));
263        assert!(!TranslationMode::PoolDefault.resolve(false));
264    }
265}