1use std::ops::ControlFlow;
33
34use sqlparser::ast::{
35 Expr, Ident, Statement, Value as AstValue, ValueWithSpan, visit_expressions_mut,
36};
37use sqlparser::tokenizer::Span;
38
39use crate::error::{Result, SQLRiteError};
40use crate::sql::db::table::Value;
41
42pub fn rewrite_placeholders(stmt: &mut Statement) -> usize {
51 let mut counter: usize = 0;
52 let _ = visit_expressions_mut(stmt, |expr| {
53 if let Expr::Value(v) = expr
54 && let AstValue::Placeholder(s) = &mut v.value
55 && s == "?"
56 {
57 counter += 1;
58 *s = format!("?{counter}");
59 }
60 ControlFlow::<()>::Continue(())
61 });
62 counter
63}
64
65pub fn substitute_params(stmt: &mut Statement, params: &[Value]) -> Result<()> {
72 let mut bind_err: Option<SQLRiteError> = None;
73 let _ = visit_expressions_mut(stmt, |expr| {
74 let Expr::Value(v) = expr else {
75 return ControlFlow::Continue(());
76 };
77 let placeholder_str = match &v.value {
78 AstValue::Placeholder(s) => s.clone(),
79 _ => return ControlFlow::Continue(()),
80 };
81 let idx = match placeholder_index(&placeholder_str) {
82 Some(i) => i,
83 None => {
84 bind_err = Some(SQLRiteError::NotImplemented(format!(
85 "unsupported placeholder form `{placeholder_str}`; only `?` and `?N` are supported"
86 )));
87 return ControlFlow::Break(());
88 }
89 };
90 let Some(value) = params.get(idx) else {
91 bind_err = Some(SQLRiteError::General(format!(
92 "missing bind value for `?{}` (got {} parameter{})",
93 idx + 1,
94 params.len(),
95 if params.len() == 1 { "" } else { "s" }
96 )));
97 return ControlFlow::Break(());
98 };
99 *expr = value_to_expr(value);
100 ControlFlow::<()>::Continue(())
101 });
102 if let Some(e) = bind_err {
103 return Err(e);
104 }
105 Ok(())
106}
107
/// Decodes a `?N` placeholder into a zero-based parameter index.
///
/// Returns `None` when `s` does not start with `?`, when the remainder does
/// not parse as a `usize`, or when the number is `0` (placeholders are
/// one-based, so `?1` maps to index `0`).
fn placeholder_index(s: &str) -> Option<usize> {
    s.strip_prefix('?')
        .and_then(|digits| digits.parse::<usize>().ok())
        // `checked_sub` yields `None` for `?0`, rejecting the zero ordinal.
        .and_then(|ordinal| ordinal.checked_sub(1))
}
119
120fn value_to_expr(v: &Value) -> Expr {
124 match v {
125 Value::Null => Expr::Value(ValueWithSpan {
126 value: AstValue::Null,
127 span: Span::empty(),
128 }),
129 Value::Integer(i) => Expr::Value(ValueWithSpan {
130 value: AstValue::Number(i.to_string(), false),
131 span: Span::empty(),
132 }),
133 Value::Real(f) => Expr::Value(ValueWithSpan {
134 value: AstValue::Number(f.to_string(), false),
137 span: Span::empty(),
138 }),
139 Value::Text(s) => Expr::Value(ValueWithSpan {
140 value: AstValue::SingleQuotedString(s.clone()),
141 span: Span::empty(),
142 }),
143 Value::Bool(b) => Expr::Value(ValueWithSpan {
144 value: AstValue::Boolean(*b),
145 span: Span::empty(),
146 }),
147 Value::Vector(v) => {
148 let inner = format_vector_inner(v);
153 Expr::Identifier(Ident {
154 value: inner,
155 quote_style: Some('['),
156 span: Span::empty(),
157 })
158 }
159 }
160}
161
/// Renders the components of a vector as a `", "`-separated list, without
/// surrounding brackets.
///
/// Each component is formatted with `f32`'s `Display` implementation; an
/// empty slice produces an empty string.
fn format_vector_inner(v: &[f32]) -> String {
    v.iter()
        .map(|component| component.to_string())
        .collect::<Vec<_>>()
        .join(", ")
}
173
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::dialect::SQLiteDialect;
    use sqlparser::parser::Parser;

    /// Parses `sql` with the SQLite dialect and returns the last statement.
    fn parse_one(sql: &str) -> Statement {
        Parser::parse_sql(&SQLiteDialect {}, sql)
            .unwrap()
            .pop()
            .unwrap()
    }

    /// Parses `query`, numbers its placeholders, binds `params`, and returns
    /// the rendered SQL. Panics if binding fails.
    fn render_bound(query: &str, params: &[Value]) -> String {
        let mut stmt = parse_one(query);
        rewrite_placeholders(&mut stmt);
        substitute_params(&mut stmt, params).unwrap();
        stmt.to_string()
    }

    #[test]
    fn rewrite_assigns_indices_in_source_order() {
        let mut stmt = parse_one("SELECT * FROM t WHERE a = ? AND b = ? AND c = ?");
        let renamed = rewrite_placeholders(&mut stmt);
        assert_eq!(renamed, 3);

        let rendered = stmt.to_string();
        for numbered in ["?1", "?2", "?3"] {
            assert!(rendered.contains(numbered));
        }
    }

    #[test]
    fn rewrite_zero_for_no_placeholders() {
        let mut stmt = parse_one("SELECT * FROM t WHERE a = 1");
        let renamed = rewrite_placeholders(&mut stmt);
        assert_eq!(renamed, 0);
    }

    #[test]
    fn rewrite_idempotent_on_numbered_placeholders() {
        // Already-numbered placeholders must not be counted or renamed.
        let mut stmt = parse_one("SELECT * FROM t WHERE a = ?1 AND b = ?2");
        let renamed = rewrite_placeholders(&mut stmt);
        assert_eq!(renamed, 0);
    }

    #[test]
    fn substitute_replaces_scalar_params() {
        let params = [
            Value::Integer(1),
            Value::Text("x".into()),
            Value::Bool(true),
        ];
        let sql = render_bound("SELECT * FROM t WHERE a = ? AND b = ? AND c = ?", &params);
        assert!(sql.contains("a = 1"), "got: {sql}");
        assert!(sql.contains("b = 'x'"), "got: {sql}");
        assert!(sql.contains("c = true"), "got: {sql}");
    }

    #[test]
    fn substitute_replaces_vector_param_as_bracket_array() {
        let sql = render_bound(
            "SELECT id FROM t ORDER BY vec_distance_l2(v, ?) LIMIT 5",
            &[Value::Vector(vec![0.1, 0.2, 0.3])],
        );
        assert!(sql.contains("[0.1, 0.2, 0.3]"), "got: {sql}");
    }

    #[test]
    fn substitute_errors_on_too_few_params() {
        let mut stmt = parse_one("SELECT * FROM t WHERE a = ? AND b = ?");
        rewrite_placeholders(&mut stmt);

        // Two placeholders but only one bind value.
        let outcome = substitute_params(&mut stmt, &[Value::Integer(1)]);
        let err = outcome.unwrap_err();
        assert!(format!("{err}").contains("missing bind value"));
    }

    #[test]
    fn substitute_replaces_null_param() {
        let sql = render_bound("SELECT * FROM t WHERE a = ?", &[Value::Null]);
        assert!(sql.to_uppercase().contains("NULL"), "got: {sql}");
    }

    #[test]
    fn placeholder_index_decodes_canonical_form() {
        let cases: [(&str, Option<usize>); 6] = [
            ("?1", Some(0)),
            ("?42", Some(41)),
            ("?", None),
            ("?0", None),
            (":name", None),
            ("$1", None),
        ];
        for (input, expected) in cases {
            assert_eq!(placeholder_index(input), expected);
        }
    }
}