sql_middleware/translation/
mod.rs1use std::borrow::Cow;
2
3mod parsers;
4mod scanner;
5
6use parsers::{
7 is_block_comment_end, is_block_comment_start, is_line_comment_start, matches_tag,
8 try_start_dollar_quote,
9};
10use scanner::{State, scan_digits};
11
/// Positional-placeholder syntax spoken by a database backend.
///
/// Used as the *target* style of [`translate_placeholders`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PlaceholderStyle {
    /// Dollar-numbered placeholders: `$1`, `$2`, ...
    Postgres,
    /// Question-mark-numbered placeholders: `?1`, `?2`, ...
    Sqlite,
}
20
/// Per-query override deciding whether placeholder translation runs.
///
/// Turn it into a concrete decision with [`TranslationMode::resolve`].
/// The default is [`TranslationMode::PoolDefault`], the same mode that
/// `QueryOptions::default()` selects.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TranslationMode {
    /// Defer to the pool-level setting.
    #[default]
    PoolDefault,
    /// Always translate, regardless of the pool-level setting.
    ForceOn,
    /// Never translate, regardless of the pool-level setting.
    ForceOff,
}
40
41impl TranslationMode {
42 #[must_use]
43 pub fn resolve(self, pool_default: bool) -> bool {
44 match self {
45 TranslationMode::PoolDefault => pool_default,
46 TranslationMode::ForceOn => true,
47 TranslationMode::ForceOff => false,
48 }
49 }
50}
51
/// How a query is handed to the backend; defaults to [`PrepareMode::Direct`].
///
/// NOTE(review): the meaning of each variant is implemented by the backend
/// executors, not in this module. Presumably `Direct` submits the SQL text
/// as-is and `Prepared` goes through a prepared statement; confirm there.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum PrepareMode {
    /// The default mode.
    #[default]
    Direct,
    /// Prepared-statement mode.
    Prepared,
}
61
62#[derive(Debug, Clone, Copy, PartialEq, Eq)]
64pub struct QueryOptions {
65 pub translation: TranslationMode,
66 pub prepare: PrepareMode,
67}
68
69impl Default for QueryOptions {
70 fn default() -> Self {
71 Self {
72 translation: TranslationMode::PoolDefault,
73 prepare: PrepareMode::default(),
74 }
75 }
76}
77
78impl QueryOptions {
79 #[must_use]
80 pub fn with_translation(mut self, translation: TranslationMode) -> Self {
81 self.translation = translation;
82 self
83 }
84
85 #[must_use]
86 pub fn with_prepare(mut self, prepare: PrepareMode) -> Self {
87 self.prepare = prepare;
88 self
89 }
90}
91
92#[must_use]
116pub fn translate_placeholders(sql: &str, target: PlaceholderStyle, enabled: bool) -> Cow<'_, str> {
117 if !enabled {
118 return Cow::Borrowed(sql);
119 }
120
121 let mut out: Option<String> = None;
122 let mut state = State::Normal;
123 let mut idx = 0;
124 let bytes = sql.as_bytes();
125
126 while idx < bytes.len() {
127 let b = bytes[idx];
128 let mut replaced = false;
129 match state {
130 State::Normal => match b {
131 b'\'' => state = State::SingleQuoted,
132 b'"' => state = State::DoubleQuoted,
133 _ if is_line_comment_start(bytes, idx) => state = State::LineComment,
134 _ if is_block_comment_start(bytes, idx) => state = State::BlockComment(1),
135 b'$' => {
136 if let Some((tag, advance)) = try_start_dollar_quote(bytes, idx) {
137 state = State::DollarQuoted(tag);
138 idx = advance;
139 } else if matches!(target, PlaceholderStyle::Sqlite)
140 && let Some((digits_end, digits)) = scan_digits(bytes, idx + 1)
141 {
142 let buf = out.get_or_insert_with(|| sql[..idx].to_string());
143 buf.push('?');
144 buf.push_str(digits);
145 idx = digits_end - 1;
146 replaced = true;
147 }
148 }
149 b'?' if matches!(target, PlaceholderStyle::Postgres) => {
150 if let Some((digits_end, digits)) = scan_digits(bytes, idx + 1) {
151 let buf = out.get_or_insert_with(|| sql[..idx].to_string());
152 buf.push('$');
153 buf.push_str(digits);
154 idx = digits_end - 1;
155 replaced = true;
156 }
157 }
158 _ => {}
159 },
160 State::SingleQuoted => {
161 if b == b'\'' {
162 if bytes.get(idx + 1) == Some(&b'\'') {
163 idx += 1; } else {
165 state = State::Normal;
166 }
167 }
168 }
169 State::DoubleQuoted => {
170 if b == b'"' {
171 if bytes.get(idx + 1) == Some(&b'"') {
172 idx += 1; } else {
174 state = State::Normal;
175 }
176 }
177 }
178 State::LineComment => {
179 if b == b'\n' {
180 state = State::Normal;
181 }
182 }
183 State::BlockComment(depth) => {
184 if is_block_comment_start(bytes, idx) {
185 state = State::BlockComment(depth + 1);
186 } else if is_block_comment_end(bytes, idx) {
187 if depth == 1 {
188 state = State::Normal;
189 } else {
190 state = State::BlockComment(depth - 1);
191 }
192 }
193 }
194 State::DollarQuoted(ref tag) => {
195 if b == b'$' && matches_tag(bytes, idx, tag) {
196 let tag_len = tag.len();
197 state = State::Normal;
198 idx += tag_len;
199 }
200 }
201 }
202
203 if let Some(ref mut buf) = out
204 && !replaced
205 {
206 buf.push(b as char);
207 }
208
209 idx += 1;
210 }
211
212 match out {
213 Some(buf) => Cow::Owned(buf),
214 None => Cow::Borrowed(sql),
215 }
216}
217
218#[cfg(test)]
219mod tests {
220 use super::*;
221
222 #[test]
223 fn translates_sqlite_to_postgres() {
224 let sql = "select * from t where a = ?1 and b = ?2";
225 let res = translate_placeholders(sql, PlaceholderStyle::Postgres, true);
226 assert_eq!(res, "select * from t where a = $1 and b = $2");
227 }
228
229 #[test]
230 fn translates_postgres_to_sqlite() {
231 let sql = "insert into t values($1, $2)";
232 let res = translate_placeholders(sql, PlaceholderStyle::Sqlite, true);
233 assert_eq!(res, "insert into t values(?1, ?2)");
234 }
235
236 #[test]
237 fn skips_inside_literals_and_comments() {
238 let sql = "select '?1', $1 -- $2\n/* ?3 */ from t where a = $1";
239 let res = translate_placeholders(sql, PlaceholderStyle::Sqlite, true);
240 assert_eq!(res, "select '?1', ?1 -- $2\n/* ?3 */ from t where a = ?1");
241 }
242
243 #[test]
244 fn skips_dollar_quoted_blocks() {
245 let sql = "$foo$ select $1 from t $foo$ where a = $1";
246 let res = translate_placeholders(sql, PlaceholderStyle::Sqlite, true);
247 assert_eq!(res, "$foo$ select $1 from t $foo$ where a = ?1");
248 }
249
250 #[test]
251 fn respects_disabled_flag() {
252 let sql = "select * from t where a = ?1";
253 let res = translate_placeholders(sql, PlaceholderStyle::Postgres, false);
254 assert!(matches!(res, Cow::Borrowed(_)));
255 assert_eq!(res, sql);
256 }
257
258 #[test]
259 fn translation_mode_resolution() {
260 assert!(TranslationMode::ForceOn.resolve(false));
261 assert!(!TranslationMode::ForceOff.resolve(true));
262 assert!(TranslationMode::PoolDefault.resolve(true));
263 assert!(!TranslationMode::PoolDefault.resolve(false));
264 }
265}