1use super::query_builder::{FieldOperator, FieldQuery, Query, QueryNode, QueryValue, VectorQuery};
25
/// A translated RediSearch query: the query text plus the named binary
/// parameters it references.
///
/// Derives `PartialEq`/`Eq` so translations can be compared directly, and
/// `Default` for an empty query with no parameters.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TranslatedQuery {
    /// The RediSearch query text.
    pub query: String,
    /// `(name, bytes)` pairs for parameters referenced as `$name` from
    /// `query` (e.g. the packed KNN query vector).
    pub params: Vec<(String, Vec<u8>)>,
}
36
37impl TranslatedQuery {
38 pub fn simple(query: String) -> Self {
40 Self {
41 query,
42 params: Vec::new(),
43 }
44 }
45
46 pub fn has_params(&self) -> bool {
48 !self.params.is_empty()
49 }
50}
51
/// Stateless translator from [`Query`] trees into RediSearch query strings.
#[derive(Debug, Clone, Copy, Default)]
pub struct RediSearchTranslator;
54
55impl RediSearchTranslator {
56 pub fn translate(query: &Query) -> String {
59 Self::translate_with_params(query).query
60 }
61
62 pub fn translate_with_params(query: &Query) -> TranslatedQuery {
65 Self::translate_node_with_params(&query.root)
66 }
67
68 fn translate_node_with_params(node: &QueryNode) -> TranslatedQuery {
69 match node {
70 QueryNode::Field(field_query) => {
71 TranslatedQuery::simple(Self::translate_field(field_query))
72 }
73 QueryNode::And(nodes) => {
74 let (vector_nodes, filter_nodes): (Vec<_>, Vec<_>) = nodes
76 .iter()
77 .partition(|n| matches!(n, QueryNode::Vector(_)));
78
79 if let Some(QueryNode::Vector(vq)) = vector_nodes.first() {
80 let filter = if filter_nodes.is_empty() {
82 "*".to_string()
83 } else {
84 let parts: Vec<String> = filter_nodes
85 .iter()
86 .map(|n| Self::translate_node_with_params(n).query)
87 .collect();
88 if parts.len() == 1 {
89 parts[0].clone()
90 } else {
91 format!("({})", parts.join(" "))
92 }
93 };
94 Self::translate_vector_query(vq, &filter)
95 } else {
96 let parts: Vec<String> = nodes
98 .iter()
99 .map(|n| Self::translate_node_with_params(n).query)
100 .collect();
101 let query = if parts.len() == 1 {
102 parts[0].clone()
103 } else {
104 format!("({})", parts.join(" "))
105 };
106 TranslatedQuery::simple(query)
107 }
108 }
109 QueryNode::Or(nodes) => {
110 let parts: Vec<String> = nodes
111 .iter()
112 .map(|n| Self::translate_node_with_params(n).query)
113 .collect();
114 let query = if parts.len() == 1 {
115 parts[0].clone()
116 } else {
117 format!("({})", parts.join(" | "))
118 };
119 TranslatedQuery::simple(query)
120 }
121 QueryNode::Not(inner) => {
122 let inner_result = Self::translate_node_with_params(inner);
123 TranslatedQuery {
124 query: format!("-({})", inner_result.query),
125 params: inner_result.params,
126 }
127 }
128 QueryNode::Vector(vq) => {
129 Self::translate_vector_query(vq, "*")
131 }
132 }
133 }
134
135 fn translate_vector_query(vq: &VectorQuery, filter: &str) -> TranslatedQuery {
138 let param_name = "vec_blob";
139
140 let blob: Vec<u8> = vq.vector
142 .iter()
143 .flat_map(|f| f.to_le_bytes())
144 .collect();
145
146 let query = format!(
147 "({})=>[KNN {} @{} ${} AS vector_score]",
148 filter,
149 vq.k,
150 Self::escape_field_name(&vq.field),
151 param_name
152 );
153
154 TranslatedQuery {
155 query,
156 params: vec![(param_name.to_string(), blob)],
157 }
158 }
159
160 #[allow(dead_code)]
161 fn translate_node(node: &QueryNode) -> String {
162 Self::translate_node_with_params(node).query
163 }
164
165 fn translate_field(field: &FieldQuery) -> String {
166 let field_name = Self::escape_field_name(&field.field);
167
168 match (&field.operator, &field.value) {
169 (FieldOperator::Equals, QueryValue::Text(text)) => {
170 let escaped = Self::escape_special_chars(text);
173 if text.contains(' ') {
174 format!("@{}:({})", field_name, escaped)
175 } else {
176 format!("@{}:{}", field_name, escaped)
177 }
178 }
179 (FieldOperator::Equals, QueryValue::Numeric(num)) => {
180 format!("@{}:[{} {}]", field_name, num, num)
181 }
182 (FieldOperator::Equals, QueryValue::Boolean(b)) => {
183 format!("@{}:{}", field_name, if *b { "true" } else { "false" })
184 }
185 (FieldOperator::Contains, QueryValue::Text(text)) => {
186 format!("@{}:*{}*", field_name, Self::escape_special_chars(text))
187 }
188 (FieldOperator::Range, QueryValue::NumericRange { min, max }) => {
189 let min_str = min.map(|v| v.to_string()).unwrap_or_else(|| "-inf".to_string());
190 let max_str = max.map(|v| v.to_string()).unwrap_or_else(|| "+inf".to_string());
191 format!("@{}:[{} {}]", field_name, min_str, max_str)
192 }
193 (FieldOperator::In, QueryValue::Tags(tags)) => {
194 let tag_str = tags
195 .iter()
196 .map(|t| Self::escape_value(t))
197 .collect::<Vec<_>>()
198 .join("|");
199 format!("@{}:{{{}}}", field_name, tag_str)
200 }
201 (FieldOperator::Prefix, QueryValue::Text(text)) => {
202 format!("@{}:{}*", field_name, Self::escape_value(text))
203 }
204 (FieldOperator::Fuzzy, QueryValue::Text(text)) => {
205 format!("@{}:%{}%", field_name, Self::escape_value(text))
206 }
207 _ => {
208 format!("@{}:{:?}", field_name, field.value)
210 }
211 }
212 }
213
214 fn escape_field_name(field: &str) -> String {
215 if field.contains(|c: char| !c.is_alphanumeric() && c != '_') {
217 format!("`{}`", field)
218 } else {
219 field.to_string()
220 }
221 }
222
223 fn escape_special_chars(value: &str) -> String {
225 let mut escaped = String::new();
226 for c in value.chars() {
227 match c {
228 '@' | ':' | '|' | '(' | ')' | '[' | ']' | '{' | '}' | '*' | '%' | '-' | '+' => {
230 escaped.push('\\');
231 escaped.push(c);
232 }
233 _ => escaped.push(c),
234 }
235 }
236 escaped
237 }
238
239 fn escape_value(value: &str) -> String {
241 let mut escaped = String::new();
242 for c in value.chars() {
243 match c {
244 '@' | ':' | '|' | '(' | ')' | '[' | ']' | '{' | '}' | '*' | '%' | '-' | '+' | ' ' => {
246 escaped.push('\\');
247 escaped.push(c);
248 }
249 _ => escaped.push(c),
250 }
251 }
252 escaped
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// Translates `query` and returns only the query text.
    fn render(query: &Query) -> String {
        RediSearchTranslator::translate(query)
    }

    /// Builds an owned tag list from string slices.
    fn owned(values: &[&str]) -> Vec<String> {
        values.iter().map(|v| v.to_string()).collect()
    }

    #[test]
    fn test_simple_field_query() {
        assert_eq!(render(&Query::field_eq("name", "Alice")), "@name:Alice");
    }

    #[test]
    fn test_field_with_spaces() {
        // Multi-word values are grouped so every word stays bound to the field.
        assert_eq!(
            render(&Query::field_eq("name", "Alice Smith")),
            "@name:(Alice Smith)"
        );
    }

    #[test]
    fn test_numeric_range() {
        let bounded = Query::numeric_range("age", Some(25.0), Some(40.0));
        assert_eq!(render(&bounded), "@age:[25 40]");
    }

    #[test]
    fn test_numeric_range_unbounded_min() {
        let open_below = Query::numeric_range("age", None, Some(40.0));
        assert_eq!(render(&open_below), "@age:[-inf 40]");
    }

    #[test]
    fn test_numeric_range_unbounded_max() {
        let open_above = Query::numeric_range("score", Some(100.0), None);
        assert_eq!(render(&open_above), "@score:[100 +inf]");
    }

    #[test]
    fn test_tag_query() {
        let tagged = Query::tags("tags", owned(&["rust", "database"]));
        assert_eq!(render(&tagged), "@tags:{rust|database}");
    }

    #[test]
    fn test_and_query() {
        let both = Query::field_eq("name", "Alice")
            .and(Query::numeric_range("age", Some(25.0), Some(40.0)));
        assert_eq!(render(&both), "(@name:Alice @age:[25 40])");
    }

    #[test]
    fn test_or_query() {
        let either =
            Query::field_eq("status", "active").or(Query::field_eq("status", "pending"));
        assert_eq!(render(&either), "(@status:active | @status:pending)");
    }

    #[test]
    fn test_not_query() {
        let negated = Query::field_eq("deleted", "true").negate();
        assert_eq!(render(&negated), "-(@deleted:true)");
    }

    #[test]
    fn test_contains_query() {
        assert_eq!(
            render(&Query::text_search("description", "database")),
            "@description:*database*"
        );
    }

    #[test]
    fn test_prefix_query() {
        assert_eq!(render(&Query::prefix("email", "admin")), "@email:admin*");
    }

    #[test]
    fn test_fuzzy_query() {
        assert_eq!(render(&Query::fuzzy("name", "alice")), "@name:%alice%");
    }

    #[test]
    fn test_complex_query() {
        let alice = Query::field_eq("name", "Alice")
            .and(Query::numeric_range("age", Some(25.0), Some(40.0)));
        let bob = Query::field_eq("name", "Bob")
            .and(Query::tags("tags", owned(&["rust", "database"])));

        assert_eq!(
            render(&alice.or(bob)),
            "((@name:Alice @age:[25 40]) | (@name:Bob @tags:{rust|database}))"
        );
    }

    #[test]
    fn test_escape_special_chars() {
        assert_eq!(
            render(&Query::field_eq("email", "user@example.com")),
            "@email:user\\@example.com"
        );
    }

    #[test]
    fn test_escape_colon() {
        assert_eq!(render(&Query::field_eq("time", "12:30")), "@time:12\\:30");
    }

    #[test]
    fn test_vector_query_basic() {
        let embedding = vec![0.1, 0.2, 0.3, 0.4];
        let dimensions = embedding.len();
        let result =
            RediSearchTranslator::translate_with_params(&Query::vector("embedding", embedding, 10));

        assert_eq!(
            result.query,
            "(*)=>[KNN 10 @embedding $vec_blob AS vector_score]"
        );
        assert_eq!(result.params.len(), 1);

        let (name, blob) = &result.params[0];
        assert_eq!(name, "vec_blob");
        // Four bytes per f32 component.
        assert_eq!(blob.len(), dimensions * 4);
    }

    #[test]
    fn test_vector_query_with_filter() {
        let query = Query::vector_filtered(
            Query::tags("category", owned(&["tech"])),
            "embedding",
            vec![0.1, 0.2, 0.3, 0.4],
            5,
        );
        let result = RediSearchTranslator::translate_with_params(&query);

        assert_eq!(
            result.query,
            "(@category:{tech})=>[KNN 5 @embedding $vec_blob AS vector_score]"
        );
        assert!(result.has_params());
    }

    #[test]
    fn test_vector_query_complex_filter() {
        let filter = Query::field_eq("status", "active")
            .and(Query::numeric_range("age", Some(18.0), None));
        let query = Query::vector_filtered(filter, "vec", vec![1.0, 2.0], 20);
        let text = RediSearchTranslator::translate_with_params(&query).query;

        let expected_fragments = [
            "@status:active",
            "@age:[18 +inf]",
            "=>[KNN 20 @vec $vec_blob AS vector_score]",
        ];
        for fragment in expected_fragments.iter() {
            assert!(
                text.contains(*fragment),
                "missing {:?} in {:?}",
                fragment,
                text
            );
        }
    }

    #[test]
    fn test_vector_blob_encoding() {
        let result =
            RediSearchTranslator::translate_with_params(&Query::vector("field", vec![1.0f32], 1));
        // 1.0f32 is 0x3f800000; little-endian stores the low byte first.
        assert_eq!(result.params[0].1, vec![0x00u8, 0x00, 0x80, 0x3f]);
    }

    #[test]
    fn test_non_vector_query_has_no_params() {
        let result = RediSearchTranslator::translate_with_params(&Query::field_eq("name", "test"));
        assert!(!result.has_params());
        assert!(result.params.is_empty());
    }
}