sql_middleware/translation/
mod.rs1use std::borrow::Cow;
2
3mod parsers;
4mod scanner;
5
6use parsers::{
7 is_block_comment_end, is_block_comment_start, is_line_comment_start, matches_tag,
8 try_start_dollar_quote,
9};
10use scanner::{State, scan_digits};
11
/// Bind-placeholder syntax expected by a database backend.
///
/// Used as the `target` argument of [`translate_placeholders`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PlaceholderStyle {
    /// Numbered `$N` placeholders.
    Postgres,
    /// Numbered `?N` placeholders.
    Sqlite,
}
20
/// Whether placeholder translation should run, relative to a pool-level default.
///
/// Turn it into a concrete decision with [`TranslationMode::resolve`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum TranslationMode {
    /// Use whatever the pool default says. This is the [`Default`] variant,
    /// the same one `QueryOptions::default()` selects.
    #[default]
    PoolDefault,
    /// Translate regardless of the pool default.
    ForceOn,
    /// Do not translate, regardless of the pool default.
    ForceOff,
}
40
41impl TranslationMode {
42 #[must_use]
43 pub fn resolve(self, pool_default: bool) -> bool {
44 match self {
45 TranslationMode::PoolDefault => pool_default,
46 TranslationMode::ForceOn => true,
47 TranslationMode::ForceOff => false,
48 }
49 }
50}
51
52#[derive(Debug, Clone, Copy, PartialEq, Eq)]
54pub struct QueryOptions {
55 pub translation: TranslationMode,
56}
57
58impl Default for QueryOptions {
59 fn default() -> Self {
60 Self {
61 translation: TranslationMode::PoolDefault,
62 }
63 }
64}
65
66impl QueryOptions {
67 #[must_use]
68 pub fn with_translation(mut self, translation: TranslationMode) -> Self {
69 self.translation = translation;
70 self
71 }
72}
73
74#[must_use]
98pub fn translate_placeholders(sql: &str, target: PlaceholderStyle, enabled: bool) -> Cow<'_, str> {
99 if !enabled {
100 return Cow::Borrowed(sql);
101 }
102
103 let mut out: Option<String> = None;
104 let mut state = State::Normal;
105 let mut idx = 0;
106 let bytes = sql.as_bytes();
107
108 while idx < bytes.len() {
109 let b = bytes[idx];
110 let mut replaced = false;
111 match state {
112 State::Normal => match b {
113 b'\'' => state = State::SingleQuoted,
114 b'"' => state = State::DoubleQuoted,
115 _ if is_line_comment_start(bytes, idx) => state = State::LineComment,
116 _ if is_block_comment_start(bytes, idx) => state = State::BlockComment(1),
117 b'$' => {
118 if let Some((tag, advance)) = try_start_dollar_quote(bytes, idx) {
119 state = State::DollarQuoted(tag);
120 idx = advance;
121 } else if matches!(target, PlaceholderStyle::Sqlite)
122 && let Some((digits_end, digits)) = scan_digits(bytes, idx + 1)
123 {
124 let buf = out.get_or_insert_with(|| sql[..idx].to_string());
125 buf.push('?');
126 buf.push_str(digits);
127 idx = digits_end - 1;
128 replaced = true;
129 }
130 }
131 b'?' if matches!(target, PlaceholderStyle::Postgres) => {
132 if let Some((digits_end, digits)) = scan_digits(bytes, idx + 1) {
133 let buf = out.get_or_insert_with(|| sql[..idx].to_string());
134 buf.push('$');
135 buf.push_str(digits);
136 idx = digits_end - 1;
137 replaced = true;
138 }
139 }
140 _ => {}
141 },
142 State::SingleQuoted => {
143 if b == b'\'' {
144 if bytes.get(idx + 1) == Some(&b'\'') {
145 idx += 1; } else {
147 state = State::Normal;
148 }
149 }
150 }
151 State::DoubleQuoted => {
152 if b == b'"' {
153 if bytes.get(idx + 1) == Some(&b'"') {
154 idx += 1; } else {
156 state = State::Normal;
157 }
158 }
159 }
160 State::LineComment => {
161 if b == b'\n' {
162 state = State::Normal;
163 }
164 }
165 State::BlockComment(depth) => {
166 if is_block_comment_start(bytes, idx) {
167 state = State::BlockComment(depth + 1);
168 } else if is_block_comment_end(bytes, idx) {
169 if depth == 1 {
170 state = State::Normal;
171 } else {
172 state = State::BlockComment(depth - 1);
173 }
174 }
175 }
176 State::DollarQuoted(ref tag) => {
177 if b == b'$' && matches_tag(bytes, idx, tag) {
178 let tag_len = tag.len();
179 state = State::Normal;
180 idx += tag_len;
181 }
182 }
183 }
184
185 if let Some(ref mut buf) = out
186 && !replaced
187 {
188 buf.push(b as char);
189 }
190
191 idx += 1;
192 }
193
194 match out {
195 Some(buf) => Cow::Owned(buf),
196 None => Cow::Borrowed(sql),
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand: run the translator with translation switched on.
    fn enabled(sql: &str, target: PlaceholderStyle) -> Cow<'_, str> {
        translate_placeholders(sql, target, true)
    }

    #[test]
    fn translates_sqlite_to_postgres() {
        let translated = enabled(
            "select * from t where a = ?1 and b = ?2",
            PlaceholderStyle::Postgres,
        );
        assert_eq!(translated, "select * from t where a = $1 and b = $2");
    }

    #[test]
    fn translates_postgres_to_sqlite() {
        let translated = enabled("insert into t values($1, $2)", PlaceholderStyle::Sqlite);
        assert_eq!(translated, "insert into t values(?1, ?2)");
    }

    #[test]
    fn skips_inside_literals_and_comments() {
        let translated = enabled(
            "select '?1', $1 -- $2\n/* ?3 */ from t where a = $1",
            PlaceholderStyle::Sqlite,
        );
        assert_eq!(
            translated,
            "select '?1', ?1 -- $2\n/* ?3 */ from t where a = ?1"
        );
    }

    #[test]
    fn skips_dollar_quoted_blocks() {
        let translated = enabled(
            "$foo$ select $1 from t $foo$ where a = $1",
            PlaceholderStyle::Sqlite,
        );
        assert_eq!(translated, "$foo$ select $1 from t $foo$ where a = ?1");
    }

    #[test]
    fn respects_disabled_flag() {
        let sql = "select * from t where a = ?1";
        let untouched = translate_placeholders(sql, PlaceholderStyle::Postgres, false);
        assert!(matches!(untouched, Cow::Borrowed(_)));
        assert_eq!(untouched, sql);
    }

    #[test]
    fn translation_mode_resolution() {
        let cases = [
            (TranslationMode::ForceOn, false, true),
            (TranslationMode::ForceOff, true, false),
            (TranslationMode::PoolDefault, true, true),
            (TranslationMode::PoolDefault, false, false),
        ];
        for (mode, pool_default, expected) in cases {
            assert_eq!(
                mode.resolve(pool_default),
                expected,
                "{mode:?} with pool_default={pool_default}"
            );
        }
    }
}