sql_middleware/
translation.rs1use std::borrow::Cow;
2
/// Placeholder syntax that [`translate_placeholders`] rewrites SQL into.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PlaceholderStyle {
    /// PostgreSQL-style numbered parameters: `$1`, `$2`, ...
    Postgres,
    /// SQLite-style numbered parameters: `?1`, `?2`, ...
    Sqlite,
}

/// Per-query choice of whether placeholder translation should run.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TranslationMode {
    /// Use whatever default the pool was configured with.
    PoolDefault,
    /// Translate regardless of the pool default.
    ForceOn,
    /// Do not translate regardless of the pool default.
    ForceOff,
}

impl TranslationMode {
    /// Turns this mode into a concrete on/off decision.
    ///
    /// `pool_default` is returned as-is for [`TranslationMode::PoolDefault`];
    /// the two forced variants ignore it.
    #[must_use]
    pub fn resolve(self, pool_default: bool) -> bool {
        match self {
            Self::ForceOff => false,
            Self::ForceOn => true,
            Self::PoolDefault => pool_default,
        }
    }
}

34#[derive(Debug, Clone, Copy, PartialEq, Eq)]
36pub struct QueryOptions {
37 pub translation: TranslationMode,
38}
39
40impl Default for QueryOptions {
41 fn default() -> Self {
42 Self {
43 translation: TranslationMode::PoolDefault,
44 }
45 }
46}
47
48impl QueryOptions {
49 #[must_use]
50 pub fn with_translation(mut self, translation: TranslationMode) -> Self {
51 self.translation = translation;
52 self
53 }
54}
55
56#[must_use]
80pub fn translate_placeholders(
81 sql: &str,
82 target: PlaceholderStyle,
83 enabled: bool,
84) -> Cow<'_, str> {
85 if !enabled {
86 return Cow::Borrowed(sql);
87 }
88
89 let mut out: Option<String> = None;
90 let mut state = State::Normal;
91 let mut idx = 0;
92 let bytes = sql.as_bytes();
93
94 while idx < bytes.len() {
95 let b = bytes[idx];
96 let mut replaced = false;
97 match state {
98 State::Normal => match b {
99 b'\'' => state = State::SingleQuoted,
100 b'"' => state = State::DoubleQuoted,
101 b'-' if bytes.get(idx + 1) == Some(&b'-') => {
102 state = State::LineComment;
103 }
104 b'/' if bytes.get(idx + 1) == Some(&b'*') => {
105 state = State::BlockComment(1);
106 }
107 b'$' => {
108 if let Some((tag, advance)) = try_start_dollar_quote(bytes, idx) {
109 state = State::DollarQuoted(tag);
110 idx = advance;
111 } else if matches!(target, PlaceholderStyle::Sqlite)
112 && let Some((digits_end, digits)) = scan_digits(bytes, idx + 1)
113 {
114 let buf = out.get_or_insert_with(|| sql[..idx].to_string());
115 buf.push('?');
116 buf.push_str(digits);
117 idx = digits_end - 1;
118 replaced = true;
119 }
120 }
121 b'?' if matches!(target, PlaceholderStyle::Postgres) => {
122 if let Some((digits_end, digits)) = scan_digits(bytes, idx + 1) {
123 let buf = out.get_or_insert_with(|| sql[..idx].to_string());
124 buf.push('$');
125 buf.push_str(digits);
126 idx = digits_end - 1;
127 replaced = true;
128 }
129 }
130 _ => {}
131 },
132 State::SingleQuoted => {
133 if b == b'\'' {
134 if bytes.get(idx + 1) == Some(&b'\'') {
135 idx += 1; } else {
137 state = State::Normal;
138 }
139 }
140 }
141 State::DoubleQuoted => {
142 if b == b'"' {
143 if bytes.get(idx + 1) == Some(&b'"') {
144 idx += 1; } else {
146 state = State::Normal;
147 }
148 }
149 }
150 State::LineComment => {
151 if b == b'\n' {
152 state = State::Normal;
153 }
154 }
155 State::BlockComment(depth) => {
156 if b == b'/' && bytes.get(idx + 1) == Some(&b'*') {
157 state = State::BlockComment(depth + 1);
158 } else if b == b'*' && bytes.get(idx + 1) == Some(&b'/') {
159 if depth == 1 {
160 state = State::Normal;
161 } else {
162 state = State::BlockComment(depth - 1);
163 }
164 }
165 }
166 State::DollarQuoted(ref tag) => {
167 if b == b'$' && matches_tag(bytes, idx, tag) {
168 let tag_len = tag.len();
169 state = State::Normal;
170 idx += tag_len;
171 }
172 }
173 }
174
175 if let Some(ref mut buf) = out && !replaced {
176 buf.push(b as char);
177 }
178
179 idx += 1;
180 }
181
182 match out {
183 Some(buf) => Cow::Owned(buf),
184 None => Cow::Borrowed(sql),
185 }
186}
187
188#[derive(Clone)]
189enum State {
190 Normal,
191 SingleQuoted,
192 DoubleQuoted,
193 LineComment,
194 BlockComment(u32),
195 DollarQuoted(String),
196}
197
198fn scan_digits(bytes: &[u8], start: usize) -> Option<(usize, &str)> {
199 let mut idx = start;
200 while idx < bytes.len() && bytes[idx].is_ascii_digit() {
201 idx += 1;
202 }
203 if idx == start {
204 None
205 } else {
206 std::str::from_utf8(&bytes[start..idx])
207 .ok()
208 .map(|digits| (idx, digits))
209 }
210}
211
212fn try_start_dollar_quote(bytes: &[u8], start: usize) -> Option<(String, usize)> {
213 let mut idx = start + 1;
214 while idx < bytes.len() && bytes[idx] != b'$' {
215 let b = bytes[idx];
216 if !(b.is_ascii_alphanumeric() || b == b'_') {
217 return None;
218 }
219 idx += 1;
220 }
221
222 if idx < bytes.len() && bytes[idx] == b'$' {
223 let tag = String::from_utf8(bytes[start + 1..idx].to_vec()).ok()?;
224 Some((tag, idx))
225 } else {
226 None
227 }
228}
229
230fn matches_tag(bytes: &[u8], idx: usize, tag: &str) -> bool {
231 let end = idx + 1 + tag.len();
232 end < bytes.len()
233 && bytes[idx + 1..=end].starts_with(tag.as_bytes())
234 && bytes.get(end) == Some(&b'$')
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn translates_sqlite_to_postgres() {
        let translated = translate_placeholders(
            "select * from t where a = ?1 and b = ?2",
            PlaceholderStyle::Postgres,
            true,
        );
        assert_eq!(translated, "select * from t where a = $1 and b = $2");
    }

    #[test]
    fn translates_postgres_to_sqlite() {
        let translated = translate_placeholders(
            "insert into t values($1, $2)",
            PlaceholderStyle::Sqlite,
            true,
        );
        assert_eq!(translated, "insert into t values(?1, ?2)");
    }

    #[test]
    fn skips_inside_literals_and_comments() {
        let translated = translate_placeholders(
            "select '?1', $1 -- $2\n/* ?3 */ from t where a = $1",
            PlaceholderStyle::Sqlite,
            true,
        );
        assert_eq!(
            translated,
            "select '?1', ?1 -- $2\n/* ?3 */ from t where a = ?1"
        );
    }

    #[test]
    fn skips_dollar_quoted_blocks() {
        let translated = translate_placeholders(
            "$foo$ select $1 from t $foo$ where a = $1",
            PlaceholderStyle::Sqlite,
            true,
        );
        assert_eq!(translated, "$foo$ select $1 from t $foo$ where a = ?1");
    }

    #[test]
    fn respects_disabled_flag() {
        let original = "select * from t where a = ?1";
        let untouched = translate_placeholders(original, PlaceholderStyle::Postgres, false);
        assert!(matches!(untouched, Cow::Borrowed(_)));
        assert_eq!(untouched, original);
    }

    #[test]
    fn translation_mode_resolution() {
        let cases = [
            (TranslationMode::ForceOn, false, true),
            (TranslationMode::ForceOff, true, false),
            (TranslationMode::PoolDefault, true, true),
            (TranslationMode::PoolDefault, false, false),
        ];
        for &(mode, pool_default, expected) in cases.iter() {
            assert_eq!(mode.resolve(pool_default), expected);
        }
    }
}