1use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
10pub enum QueryMode {
11 VectorSearch,
13 NeuralSearch,
15 SubgraphExtraction,
17 DifferentiableSearch,
19}
20
/// A search request: an optional query vector/text/node plus the mode and
/// tuning parameters that control how it is executed.
///
/// Field order is also the serde field order; see `impl Default` for the
/// default value of each parameter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuvectorQuery {
    /// Query embedding; every mode-specific constructor sets this.
    pub vector: Option<Vec<f32>>,
    /// Optional query text, attached with `with_text`.
    pub text: Option<String>,
    /// Optional node id, attached with `with_node`.
    pub node_id: Option<u64>,
    /// Which retrieval strategy to run.
    pub mode: QueryMode,
    /// Presumably the number of results to return (default 10).
    pub k: usize,
    /// Search-breadth knob (default 50); presumably the HNSW `ef` parameter
    /// — TODO confirm.
    pub ef: usize,
    /// Depth used by neural search (default 2); presumably GNN layers/hops
    /// — TODO confirm.
    pub gnn_depth: usize,
    /// Temperature for differentiable search (default 1.0).
    pub temperature: f32,
    /// Whether attention weights are requested (default false); presumably
    /// surfaces as `QueryResult::attention_weights` — TODO confirm.
    pub return_attention: bool,
}
43
44impl Default for RuvectorQuery {
45 fn default() -> Self {
46 Self {
47 vector: None,
48 text: None,
49 node_id: None,
50 mode: QueryMode::VectorSearch,
51 k: 10,
52 ef: 50,
53 gnn_depth: 2,
54 temperature: 1.0,
55 return_attention: false,
56 }
57 }
58}
59
60impl RuvectorQuery {
61 pub fn vector_search(vector: Vec<f32>, k: usize) -> Self {
75 Self {
76 vector: Some(vector),
77 mode: QueryMode::VectorSearch,
78 k,
79 ..Default::default()
80 }
81 }
82
83 pub fn neural_search(vector: Vec<f32>, k: usize, gnn_depth: usize) -> Self {
98 Self {
99 vector: Some(vector),
100 mode: QueryMode::NeuralSearch,
101 k,
102 gnn_depth,
103 ..Default::default()
104 }
105 }
106
107 pub fn subgraph_search(vector: Vec<f32>, k: usize) -> Self {
121 Self {
122 vector: Some(vector),
123 mode: QueryMode::SubgraphExtraction,
124 k,
125 ..Default::default()
126 }
127 }
128
129 pub fn differentiable_search(vector: Vec<f32>, k: usize, temperature: f32) -> Self {
136 Self {
137 vector: Some(vector),
138 mode: QueryMode::DifferentiableSearch,
139 k,
140 temperature,
141 return_attention: true,
142 ..Default::default()
143 }
144 }
145
146 pub fn with_text(mut self, text: String) -> Self {
148 self.text = Some(text);
149 self
150 }
151
152 pub fn with_node(mut self, node_id: u64) -> Self {
154 self.node_id = Some(node_id);
155 self
156 }
157
158 pub fn with_ef(mut self, ef: usize) -> Self {
160 self.ef = ef;
161 self
162 }
163
164 pub fn with_attention(mut self) -> Self {
166 self.return_attention = true;
167 self
168 }
169}
170
/// A set of node ids together with weighted edges between them.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct SubGraph {
    /// Node ids contained in the subgraph.
    pub nodes: Vec<u64>,
    /// Edges as `(a, b, weight)`; `average_edge_weight` reads the third
    /// element as the weight. The first two are presumably (source, target)
    /// node ids — TODO confirm direction semantics.
    pub edges: Vec<(u64, u64, f32)>,
}
179
180impl SubGraph {
181 pub fn new() -> Self {
183 Self {
184 nodes: Vec::new(),
185 edges: Vec::new(),
186 }
187 }
188
189 pub fn with_edges(nodes: Vec<u64>, edges: Vec<(u64, u64, f32)>) -> Self {
191 Self { nodes, edges }
192 }
193
194 pub fn node_count(&self) -> usize {
196 self.nodes.len()
197 }
198
199 pub fn edge_count(&self) -> usize {
201 self.edges.len()
202 }
203
204 pub fn contains_node(&self, node_id: u64) -> bool {
206 self.nodes.contains(&node_id)
207 }
208
209 pub fn average_edge_weight(&self) -> f32 {
211 if self.edges.is_empty() {
212 return 0.0;
213 }
214 let sum: f32 = self.edges.iter().map(|(_, _, w)| w).sum();
215 sum / self.edges.len() as f32
216 }
217}
218
219impl Default for SubGraph {
220 fn default() -> Self {
221 Self::new()
222 }
223}
224
225#[derive(Debug, Clone, Serialize, Deserialize)]
227pub struct QueryResult {
228 pub nodes: Vec<u64>,
230 pub scores: Vec<f32>,
232 pub embeddings: Option<Vec<Vec<f32>>>,
234 pub attention_weights: Option<Vec<Vec<f32>>>,
236 pub subgraph: Option<SubGraph>,
238 pub latency_ms: u64,
240}
241
242impl QueryResult {
243 pub fn new() -> Self {
245 Self {
246 nodes: Vec::new(),
247 scores: Vec::new(),
248 embeddings: None,
249 attention_weights: None,
250 subgraph: None,
251 latency_ms: 0,
252 }
253 }
254
255 pub fn with_nodes(nodes: Vec<u64>, scores: Vec<f32>) -> Self {
269 Self {
270 nodes,
271 scores,
272 embeddings: None,
273 attention_weights: None,
274 subgraph: None,
275 latency_ms: 0,
276 }
277 }
278
279 pub fn with_embeddings(mut self, embeddings: Vec<Vec<f32>>) -> Self {
281 self.embeddings = Some(embeddings);
282 self
283 }
284
285 pub fn with_attention(mut self, attention: Vec<Vec<f32>>) -> Self {
287 self.attention_weights = Some(attention);
288 self
289 }
290
291 pub fn with_subgraph(mut self, subgraph: SubGraph) -> Self {
293 self.subgraph = Some(subgraph);
294 self
295 }
296
297 pub fn with_latency(mut self, latency_ms: u64) -> Self {
299 self.latency_ms = latency_ms;
300 self
301 }
302
303 pub fn len(&self) -> usize {
305 self.nodes.len()
306 }
307
308 pub fn is_empty(&self) -> bool {
310 self.nodes.is_empty()
311 }
312
313 pub fn top_k(&self, k: usize) -> Self {
315 let k = k.min(self.nodes.len());
316 Self {
317 nodes: self.nodes[..k].to_vec(),
318 scores: self.scores[..k].to_vec(),
319 embeddings: self.embeddings.as_ref().map(|e| e[..k].to_vec()),
320 attention_weights: self.attention_weights.as_ref().map(|a| a[..k].to_vec()),
321 subgraph: self.subgraph.clone(),
322 latency_ms: self.latency_ms,
323 }
324 }
325
326 pub fn best(&self) -> Option<(u64, f32)> {
328 if self.nodes.is_empty() {
329 None
330 } else {
331 Some((self.nodes[0], self.scores[0]))
332 }
333 }
334
335 pub fn filter_by_score(mut self, min_score: f32) -> Self {
337 let mut filtered_nodes = Vec::new();
338 let mut filtered_scores = Vec::new();
339 let mut filtered_embeddings = Vec::new();
340 let mut filtered_attention = Vec::new();
341
342 for i in 0..self.nodes.len() {
343 if self.scores[i] >= min_score {
344 filtered_nodes.push(self.nodes[i]);
345 filtered_scores.push(self.scores[i]);
346
347 if let Some(ref emb) = self.embeddings {
348 filtered_embeddings.push(emb[i].clone());
349 }
350
351 if let Some(ref att) = self.attention_weights {
352 filtered_attention.push(att[i].clone());
353 }
354 }
355 }
356
357 self.nodes = filtered_nodes;
358 self.scores = filtered_scores;
359
360 if !filtered_embeddings.is_empty() {
361 self.embeddings = Some(filtered_embeddings);
362 }
363
364 if !filtered_attention.is_empty() {
365 self.attention_weights = Some(filtered_attention);
366 }
367
368 self
369 }
370}
371
372impl Default for QueryResult {
373 fn default() -> Self {
374 Self::new()
375 }
376}
377
#[cfg(test)]
mod tests {
    //! Unit tests for query construction, subgraph helpers, result
    //! post-processing (`top_k`, `best`, `filter_by_score`) and serde round-trips.

    use super::*;

    // --- QueryMode / RuvectorQuery -------------------------------------------

    #[test]
    fn test_query_mode_serialization() {
        let mode = QueryMode::NeuralSearch;
        let json = serde_json::to_string(&mode).unwrap();
        let deserialized: QueryMode = serde_json::from_str(&json).unwrap();
        assert_eq!(mode, deserialized);
    }

    #[test]
    fn test_ruvector_query_default() {
        let query = RuvectorQuery::default();
        assert_eq!(query.k, 10);
        assert_eq!(query.ef, 50);
        assert_eq!(query.gnn_depth, 2);
        assert_eq!(query.temperature, 1.0);
        assert_eq!(query.mode, QueryMode::VectorSearch);
        assert!(!query.return_attention);
    }

    #[test]
    fn test_vector_search_query() {
        let vector = vec![0.1, 0.2, 0.3, 0.4];
        let query = RuvectorQuery::vector_search(vector.clone(), 5);

        assert_eq!(query.vector, Some(vector));
        assert_eq!(query.k, 5);
        assert_eq!(query.mode, QueryMode::VectorSearch);
    }

    #[test]
    fn test_neural_search_query() {
        let vector = vec![0.1, 0.2, 0.3];
        let query = RuvectorQuery::neural_search(vector.clone(), 10, 3);

        assert_eq!(query.vector, Some(vector));
        assert_eq!(query.k, 10);
        assert_eq!(query.gnn_depth, 3);
        assert_eq!(query.mode, QueryMode::NeuralSearch);
    }

    #[test]
    fn test_subgraph_search_query() {
        let vector = vec![0.5, 0.5];
        let query = RuvectorQuery::subgraph_search(vector.clone(), 20);

        assert_eq!(query.vector, Some(vector));
        assert_eq!(query.k, 20);
        assert_eq!(query.mode, QueryMode::SubgraphExtraction);
    }

    #[test]
    fn test_differentiable_search_query() {
        let vector = vec![0.3, 0.4, 0.5];
        let query = RuvectorQuery::differentiable_search(vector.clone(), 15, 0.5);

        assert_eq!(query.vector, Some(vector));
        assert_eq!(query.k, 15);
        assert_eq!(query.temperature, 0.5);
        assert_eq!(query.mode, QueryMode::DifferentiableSearch);
        assert!(query.return_attention);
    }

    #[test]
    fn test_query_builder_pattern() {
        let query = RuvectorQuery::vector_search(vec![0.1, 0.2], 5)
            .with_text("hello world".to_string())
            .with_node(42)
            .with_ef(100)
            .with_attention();

        assert_eq!(query.text, Some("hello world".to_string()));
        assert_eq!(query.node_id, Some(42));
        assert_eq!(query.ef, 100);
        assert!(query.return_attention);
    }

    // --- SubGraph --------------------------------------------------------------

    #[test]
    fn test_subgraph_new() {
        let subgraph = SubGraph::new();
        assert_eq!(subgraph.node_count(), 0);
        assert_eq!(subgraph.edge_count(), 0);
    }

    #[test]
    fn test_subgraph_with_edges() {
        let nodes = vec![1, 2, 3];
        let edges = vec![(1, 2, 0.8), (2, 3, 0.6), (1, 3, 0.5)];
        let subgraph = SubGraph::with_edges(nodes.clone(), edges.clone());

        assert_eq!(subgraph.nodes, nodes);
        assert_eq!(subgraph.edges, edges);
        assert_eq!(subgraph.node_count(), 3);
        assert_eq!(subgraph.edge_count(), 3);
    }

    #[test]
    fn test_subgraph_contains_node() {
        let nodes = vec![1, 2, 3];
        let subgraph = SubGraph::with_edges(nodes, vec![]);

        assert!(subgraph.contains_node(1));
        assert!(subgraph.contains_node(2));
        assert!(subgraph.contains_node(3));
        assert!(!subgraph.contains_node(4));
    }

    #[test]
    fn test_subgraph_average_edge_weight() {
        let edges = vec![(1, 2, 0.8), (2, 3, 0.6), (1, 3, 0.4)];
        let subgraph = SubGraph::with_edges(vec![1, 2, 3], edges);

        let avg = subgraph.average_edge_weight();
        assert!((avg - 0.6).abs() < 0.001);
    }

    #[test]
    fn test_subgraph_empty_average() {
        let subgraph = SubGraph::new();
        assert_eq!(subgraph.average_edge_weight(), 0.0);
    }

    #[test]
    fn test_defaults_match_constructors() {
        assert_eq!(SubGraph::default(), SubGraph::new());

        let result = QueryResult::default();
        assert!(result.is_empty());
        assert!(result.embeddings.is_none());
        assert!(result.attention_weights.is_none());
        assert!(result.subgraph.is_none());
        assert_eq!(result.latency_ms, 0);
    }

    // --- QueryResult -----------------------------------------------------------

    #[test]
    fn test_query_result_new() {
        let result = QueryResult::new();
        assert!(result.is_empty());
        assert_eq!(result.len(), 0);
        assert_eq!(result.latency_ms, 0);
    }

    #[test]
    fn test_query_result_with_nodes() {
        let nodes = vec![1, 2, 3];
        let scores = vec![0.9, 0.8, 0.7];
        let result = QueryResult::with_nodes(nodes.clone(), scores.clone());

        assert_eq!(result.nodes, nodes);
        assert_eq!(result.scores, scores);
        assert_eq!(result.len(), 3);
        assert!(!result.is_empty());
    }

    #[test]
    fn test_query_result_builder_pattern() {
        let embeddings = vec![vec![0.1, 0.2], vec![0.3, 0.4]];
        let attention = vec![vec![0.5, 0.5], vec![0.6, 0.4]];
        let subgraph = SubGraph::with_edges(vec![1, 2], vec![(1, 2, 0.8)]);

        let result = QueryResult::with_nodes(vec![1, 2], vec![0.9, 0.8])
            .with_embeddings(embeddings.clone())
            .with_attention(attention.clone())
            .with_subgraph(subgraph.clone())
            .with_latency(100);

        assert_eq!(result.embeddings, Some(embeddings));
        assert_eq!(result.attention_weights, Some(attention));
        assert_eq!(result.subgraph, Some(subgraph));
        assert_eq!(result.latency_ms, 100);
    }

    #[test]
    fn test_query_result_top_k() {
        let nodes = vec![1, 2, 3, 4, 5];
        let scores = vec![0.9, 0.8, 0.7, 0.6, 0.5];
        let result = QueryResult::with_nodes(nodes, scores);

        let top_3 = result.top_k(3);
        assert_eq!(top_3.len(), 3);
        assert_eq!(top_3.nodes, vec![1, 2, 3]);
        assert_eq!(top_3.scores, vec![0.9, 0.8, 0.7]);
    }

    #[test]
    fn test_query_result_top_k_overflow() {
        let result = QueryResult::with_nodes(vec![1, 2], vec![0.9, 0.8]);
        let top_10 = result.top_k(10);
        assert_eq!(top_10.len(), 2);
    }

    #[test]
    fn test_query_result_top_k_zero() {
        let result = QueryResult::with_nodes(vec![1, 2], vec![0.9, 0.8]);
        assert!(result.top_k(0).is_empty());
    }

    #[test]
    fn test_query_result_top_k_keeps_side_data_aligned() {
        let subgraph = SubGraph::with_edges(vec![1, 2], vec![(1, 2, 0.8)]);
        let result = QueryResult::with_nodes(vec![1, 2, 3], vec![0.9, 0.8, 0.7])
            .with_embeddings(vec![vec![0.1], vec![0.2], vec![0.3]])
            .with_attention(vec![vec![0.4], vec![0.5], vec![0.6]])
            .with_subgraph(subgraph.clone())
            .with_latency(7);

        let top_2 = result.top_k(2);
        assert_eq!(top_2.nodes, vec![1, 2]);
        assert_eq!(top_2.scores, vec![0.9, 0.8]);
        assert_eq!(top_2.embeddings, Some(vec![vec![0.1], vec![0.2]]));
        assert_eq!(top_2.attention_weights, Some(vec![vec![0.4], vec![0.5]]));
        assert_eq!(top_2.subgraph, Some(subgraph));
        assert_eq!(top_2.latency_ms, 7);
    }

    #[test]
    fn test_query_result_best() {
        let result = QueryResult::with_nodes(vec![1, 2, 3], vec![0.9, 0.8, 0.7]);
        let best = result.best();
        assert_eq!(best, Some((1, 0.9)));
    }

    #[test]
    fn test_query_result_best_empty() {
        let result = QueryResult::new();
        assert_eq!(result.best(), None);
    }

    #[test]
    fn test_query_result_filter_by_score() {
        let nodes = vec![1, 2, 3, 4, 5];
        let scores = vec![0.9, 0.8, 0.7, 0.6, 0.5];
        let result = QueryResult::with_nodes(nodes, scores);

        let filtered = result.filter_by_score(0.7);
        assert_eq!(filtered.len(), 3);
        assert_eq!(filtered.nodes, vec![1, 2, 3]);
        assert_eq!(filtered.scores, vec![0.9, 0.8, 0.7]);
    }

    #[test]
    fn test_query_result_filter_below_all_scores_keeps_everything() {
        let result = QueryResult::with_nodes(vec![1, 2], vec![0.9, 0.8]);
        let filtered = result.filter_by_score(0.0);
        assert_eq!(filtered.nodes, vec![1, 2]);
        assert_eq!(filtered.scores, vec![0.9, 0.8]);
    }

    #[test]
    fn test_query_result_filter_with_embeddings() {
        let nodes = vec![1, 2, 3];
        let scores = vec![0.9, 0.6, 0.8];
        let embeddings = vec![vec![0.1], vec![0.2], vec![0.3]];

        let result = QueryResult::with_nodes(nodes, scores).with_embeddings(embeddings);

        let filtered = result.filter_by_score(0.7);
        assert_eq!(filtered.len(), 2);
        assert_eq!(filtered.nodes, vec![1, 3]);
        assert_eq!(filtered.embeddings, Some(vec![vec![0.1], vec![0.3]]));
    }

    #[test]
    fn test_query_result_filter_with_attention() {
        let nodes = vec![1, 2, 3];
        let scores = vec![0.9, 0.5, 0.8];
        let attention = vec![vec![0.5, 0.5], vec![0.6, 0.4], vec![0.7, 0.3]];

        let result = QueryResult::with_nodes(nodes, scores).with_attention(attention);

        let filtered = result.filter_by_score(0.75);
        assert_eq!(filtered.len(), 2);
        assert_eq!(filtered.nodes, vec![1, 3]);
        assert_eq!(
            filtered.attention_weights,
            Some(vec![vec![0.5, 0.5], vec![0.7, 0.3]])
        );
    }

    // --- Serde round-trips -----------------------------------------------------

    #[test]
    fn test_query_serialization() {
        let query = RuvectorQuery::neural_search(vec![0.1, 0.2], 5, 2);
        let json = serde_json::to_string(&query).unwrap();
        let deserialized: RuvectorQuery = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.k, query.k);
        assert_eq!(deserialized.gnn_depth, query.gnn_depth);
        assert_eq!(deserialized.mode, query.mode);
    }

    #[test]
    fn test_result_serialization() {
        let result = QueryResult::with_nodes(vec![1, 2], vec![0.9, 0.8]).with_latency(50);

        let json = serde_json::to_string(&result).unwrap();
        let deserialized: QueryResult = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.nodes, result.nodes);
        assert_eq!(deserialized.scores, result.scores);
        assert_eq!(deserialized.latency_ms, result.latency_ms);
    }

    #[test]
    fn test_subgraph_serialization() {
        let subgraph = SubGraph::with_edges(vec![1, 2, 3], vec![(1, 2, 0.8), (2, 3, 0.6)]);

        let json = serde_json::to_string(&subgraph).unwrap();
        let deserialized: SubGraph = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.nodes, subgraph.nodes);
        assert_eq!(deserialized.edges, subgraph.edges);
    }

    // --- Edge cases ------------------------------------------------------------

    #[test]
    fn test_edge_case_empty_filter() {
        let result = QueryResult::with_nodes(vec![1, 2], vec![0.5, 0.4]);
        let filtered = result.filter_by_score(0.9);

        assert!(filtered.is_empty());
        assert_eq!(filtered.len(), 0);
    }

    #[test]
    fn test_query_mode_variants() {
        assert_eq!(QueryMode::VectorSearch, QueryMode::VectorSearch);
        assert_ne!(QueryMode::VectorSearch, QueryMode::NeuralSearch);
        assert_ne!(QueryMode::NeuralSearch, QueryMode::SubgraphExtraction);
        assert_ne!(
            QueryMode::SubgraphExtraction,
            QueryMode::DifferentiableSearch
        );
    }
}