ruvector_gnn/
query.rs

1//! Query API for RuVector GNN
2//!
3//! Provides high-level query interfaces for vector search, neural search,
4//! and subgraph extraction.
5
6use serde::{Deserialize, Serialize};
7
8/// Query mode for different search strategies
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
10pub enum QueryMode {
11    /// Pure HNSW vector search
12    VectorSearch,
13    /// GNN-enhanced neural search
14    NeuralSearch,
15    /// Extract k-hop subgraph around results
16    SubgraphExtraction,
17    /// Differentiable search with soft attention
18    DifferentiableSearch,
19}
20
21/// Query configuration for RuVector searches
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct RuvectorQuery {
24    /// Query vector for similarity search
25    pub vector: Option<Vec<f32>>,
26    /// Text query (requires embedding model)
27    pub text: Option<String>,
28    /// Node ID for subgraph extraction
29    pub node_id: Option<u64>,
30    /// Search mode
31    pub mode: QueryMode,
32    /// Number of results to return
33    pub k: usize,
34    /// HNSW search parameter (exploration factor)
35    pub ef: usize,
36    /// GNN depth for neural search
37    pub gnn_depth: usize,
38    /// Temperature for differentiable search (higher = softer)
39    pub temperature: f32,
40    /// Whether to return attention weights
41    pub return_attention: bool,
42}
43
44impl Default for RuvectorQuery {
45    fn default() -> Self {
46        Self {
47            vector: None,
48            text: None,
49            node_id: None,
50            mode: QueryMode::VectorSearch,
51            k: 10,
52            ef: 50,
53            gnn_depth: 2,
54            temperature: 1.0,
55            return_attention: false,
56        }
57    }
58}
59
60impl RuvectorQuery {
61    /// Create a basic vector search query
62    ///
63    /// # Arguments
64    /// * `vector` - Query vector
65    /// * `k` - Number of results to return
66    ///
67    /// # Example
68    /// ```
69    /// use ruvector_gnn::query::RuvectorQuery;
70    ///
71    /// let query = RuvectorQuery::vector_search(vec![0.1, 0.2, 0.3], 10);
72    /// assert_eq!(query.k, 10);
73    /// ```
74    pub fn vector_search(vector: Vec<f32>, k: usize) -> Self {
75        Self {
76            vector: Some(vector),
77            mode: QueryMode::VectorSearch,
78            k,
79            ..Default::default()
80        }
81    }
82
83    /// Create a GNN-enhanced neural search query
84    ///
85    /// # Arguments
86    /// * `vector` - Query vector
87    /// * `k` - Number of results to return
88    /// * `gnn_depth` - Number of GNN layers to apply
89    ///
90    /// # Example
91    /// ```
92    /// use ruvector_gnn::query::RuvectorQuery;
93    ///
94    /// let query = RuvectorQuery::neural_search(vec![0.1, 0.2, 0.3], 10, 3);
95    /// assert_eq!(query.gnn_depth, 3);
96    /// ```
97    pub fn neural_search(vector: Vec<f32>, k: usize, gnn_depth: usize) -> Self {
98        Self {
99            vector: Some(vector),
100            mode: QueryMode::NeuralSearch,
101            k,
102            gnn_depth,
103            ..Default::default()
104        }
105    }
106
107    /// Create a subgraph extraction query
108    ///
109    /// # Arguments
110    /// * `vector` - Query vector
111    /// * `k` - Number of nodes in subgraph
112    ///
113    /// # Example
114    /// ```
115    /// use ruvector_gnn::query::RuvectorQuery;
116    ///
117    /// let query = RuvectorQuery::subgraph_search(vec![0.1, 0.2, 0.3], 20);
118    /// assert_eq!(query.k, 20);
119    /// ```
120    pub fn subgraph_search(vector: Vec<f32>, k: usize) -> Self {
121        Self {
122            vector: Some(vector),
123            mode: QueryMode::SubgraphExtraction,
124            k,
125            ..Default::default()
126        }
127    }
128
129    /// Create a differentiable search query with temperature
130    ///
131    /// # Arguments
132    /// * `vector` - Query vector
133    /// * `k` - Number of results
134    /// * `temperature` - Softmax temperature (higher = softer distribution)
135    pub fn differentiable_search(vector: Vec<f32>, k: usize, temperature: f32) -> Self {
136        Self {
137            vector: Some(vector),
138            mode: QueryMode::DifferentiableSearch,
139            k,
140            temperature,
141            return_attention: true,
142            ..Default::default()
143        }
144    }
145
146    /// Set text query (requires embedding model)
147    pub fn with_text(mut self, text: String) -> Self {
148        self.text = Some(text);
149        self
150    }
151
152    /// Set node ID for centered queries
153    pub fn with_node(mut self, node_id: u64) -> Self {
154        self.node_id = Some(node_id);
155        self
156    }
157
158    /// Set EF parameter for HNSW search
159    pub fn with_ef(mut self, ef: usize) -> Self {
160        self.ef = ef;
161        self
162    }
163
164    /// Enable attention weight return
165    pub fn with_attention(mut self) -> Self {
166        self.return_attention = true;
167        self
168    }
169}
170
171/// Subgraph representation
172#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
173pub struct SubGraph {
174    /// Node IDs in the subgraph
175    pub nodes: Vec<u64>,
176    /// Edges as (from, to, weight) tuples
177    pub edges: Vec<(u64, u64, f32)>,
178}
179
180impl SubGraph {
181    /// Create a new empty subgraph
182    pub fn new() -> Self {
183        Self {
184            nodes: Vec::new(),
185            edges: Vec::new(),
186        }
187    }
188
189    /// Create subgraph with nodes and edges
190    pub fn with_edges(nodes: Vec<u64>, edges: Vec<(u64, u64, f32)>) -> Self {
191        Self { nodes, edges }
192    }
193
194    /// Get number of nodes
195    pub fn node_count(&self) -> usize {
196        self.nodes.len()
197    }
198
199    /// Get number of edges
200    pub fn edge_count(&self) -> usize {
201        self.edges.len()
202    }
203
204    /// Check if subgraph contains a node
205    pub fn contains_node(&self, node_id: u64) -> bool {
206        self.nodes.contains(&node_id)
207    }
208
209    /// Get average edge weight
210    pub fn average_edge_weight(&self) -> f32 {
211        if self.edges.is_empty() {
212            return 0.0;
213        }
214        let sum: f32 = self.edges.iter().map(|(_, _, w)| w).sum();
215        sum / self.edges.len() as f32
216    }
217}
218
219impl Default for SubGraph {
220    fn default() -> Self {
221        Self::new()
222    }
223}
224
225/// Query result with nodes, scores, and optional metadata
226#[derive(Debug, Clone, Serialize, Deserialize)]
227pub struct QueryResult {
228    /// Matched node IDs
229    pub nodes: Vec<u64>,
230    /// Similarity scores (higher = more similar)
231    pub scores: Vec<f32>,
232    /// Optional node embeddings after GNN processing
233    pub embeddings: Option<Vec<Vec<f32>>>,
234    /// Optional attention weights from differentiable search
235    pub attention_weights: Option<Vec<Vec<f32>>>,
236    /// Optional subgraph extraction
237    pub subgraph: Option<SubGraph>,
238    /// Query latency in milliseconds
239    pub latency_ms: u64,
240}
241
242impl QueryResult {
243    /// Create a new empty query result
244    pub fn new() -> Self {
245        Self {
246            nodes: Vec::new(),
247            scores: Vec::new(),
248            embeddings: None,
249            attention_weights: None,
250            subgraph: None,
251            latency_ms: 0,
252        }
253    }
254
255    /// Create query result with nodes and scores
256    ///
257    /// # Arguments
258    /// * `nodes` - Node IDs
259    /// * `scores` - Similarity scores
260    ///
261    /// # Example
262    /// ```
263    /// use ruvector_gnn::query::QueryResult;
264    ///
265    /// let result = QueryResult::with_nodes(vec![1, 2, 3], vec![0.9, 0.8, 0.7]);
266    /// assert_eq!(result.nodes.len(), 3);
267    /// ```
268    pub fn with_nodes(nodes: Vec<u64>, scores: Vec<f32>) -> Self {
269        Self {
270            nodes,
271            scores,
272            embeddings: None,
273            attention_weights: None,
274            subgraph: None,
275            latency_ms: 0,
276        }
277    }
278
279    /// Add embeddings to the result
280    pub fn with_embeddings(mut self, embeddings: Vec<Vec<f32>>) -> Self {
281        self.embeddings = Some(embeddings);
282        self
283    }
284
285    /// Add attention weights to the result
286    pub fn with_attention(mut self, attention: Vec<Vec<f32>>) -> Self {
287        self.attention_weights = Some(attention);
288        self
289    }
290
291    /// Add subgraph to the result
292    pub fn with_subgraph(mut self, subgraph: SubGraph) -> Self {
293        self.subgraph = Some(subgraph);
294        self
295    }
296
297    /// Set query latency
298    pub fn with_latency(mut self, latency_ms: u64) -> Self {
299        self.latency_ms = latency_ms;
300        self
301    }
302
303    /// Get number of results
304    pub fn len(&self) -> usize {
305        self.nodes.len()
306    }
307
308    /// Check if result is empty
309    pub fn is_empty(&self) -> bool {
310        self.nodes.is_empty()
311    }
312
313    /// Get top-k results
314    pub fn top_k(&self, k: usize) -> Self {
315        let k = k.min(self.nodes.len());
316        Self {
317            nodes: self.nodes[..k].to_vec(),
318            scores: self.scores[..k].to_vec(),
319            embeddings: self.embeddings.as_ref().map(|e| e[..k].to_vec()),
320            attention_weights: self.attention_weights.as_ref().map(|a| a[..k].to_vec()),
321            subgraph: self.subgraph.clone(),
322            latency_ms: self.latency_ms,
323        }
324    }
325
326    /// Get the best result (highest score)
327    pub fn best(&self) -> Option<(u64, f32)> {
328        if self.nodes.is_empty() {
329            None
330        } else {
331            Some((self.nodes[0], self.scores[0]))
332        }
333    }
334
335    /// Filter results by minimum score
336    pub fn filter_by_score(mut self, min_score: f32) -> Self {
337        let mut filtered_nodes = Vec::new();
338        let mut filtered_scores = Vec::new();
339        let mut filtered_embeddings = Vec::new();
340        let mut filtered_attention = Vec::new();
341
342        for i in 0..self.nodes.len() {
343            if self.scores[i] >= min_score {
344                filtered_nodes.push(self.nodes[i]);
345                filtered_scores.push(self.scores[i]);
346
347                if let Some(ref emb) = self.embeddings {
348                    filtered_embeddings.push(emb[i].clone());
349                }
350
351                if let Some(ref att) = self.attention_weights {
352                    filtered_attention.push(att[i].clone());
353                }
354            }
355        }
356
357        self.nodes = filtered_nodes;
358        self.scores = filtered_scores;
359
360        if !filtered_embeddings.is_empty() {
361            self.embeddings = Some(filtered_embeddings);
362        }
363
364        if !filtered_attention.is_empty() {
365            self.attention_weights = Some(filtered_attention);
366        }
367
368        self
369    }
370}
371
372impl Default for QueryResult {
373    fn default() -> Self {
374        Self::new()
375    }
376}
377
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode `value` as JSON and decode it again, panicking on any serde error.
    fn json_round_trip<T>(value: &T) -> T
    where
        T: Serialize + for<'de> Deserialize<'de>,
    {
        let encoded = serde_json::to_string(value).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    /// Five results with strictly descending scores 0.9 .. 0.5.
    fn five_descending() -> QueryResult {
        QueryResult::with_nodes(vec![1, 2, 3, 4, 5], vec![0.9, 0.8, 0.7, 0.6, 0.5])
    }

    #[test]
    fn test_query_mode_serialization() {
        let mode = QueryMode::NeuralSearch;
        assert_eq!(mode, json_round_trip(&mode));
    }

    #[test]
    fn test_ruvector_query_default() {
        let RuvectorQuery {
            k,
            ef,
            gnn_depth,
            temperature,
            mode,
            return_attention,
            ..
        } = RuvectorQuery::default();
        assert_eq!(k, 10);
        assert_eq!(ef, 50);
        assert_eq!(gnn_depth, 2);
        assert_eq!(temperature, 1.0);
        assert_eq!(mode, QueryMode::VectorSearch);
        assert!(!return_attention);
    }

    #[test]
    fn test_vector_search_query() {
        let vector = vec![0.1, 0.2, 0.3, 0.4];
        let query = RuvectorQuery::vector_search(vector.clone(), 5);
        assert_eq!(
            (query.vector, query.k, query.mode),
            (Some(vector), 5, QueryMode::VectorSearch)
        );
    }

    #[test]
    fn test_neural_search_query() {
        let vector = vec![0.1, 0.2, 0.3];
        let query = RuvectorQuery::neural_search(vector.clone(), 10, 3);
        assert_eq!(
            (query.vector, query.k, query.gnn_depth, query.mode),
            (Some(vector), 10, 3, QueryMode::NeuralSearch)
        );
    }

    #[test]
    fn test_subgraph_search_query() {
        let vector = vec![0.5, 0.5];
        let query = RuvectorQuery::subgraph_search(vector.clone(), 20);
        assert_eq!(
            (query.vector, query.k, query.mode),
            (Some(vector), 20, QueryMode::SubgraphExtraction)
        );
    }

    #[test]
    fn test_differentiable_search_query() {
        let vector = vec![0.3, 0.4, 0.5];
        let query = RuvectorQuery::differentiable_search(vector.clone(), 15, 0.5);
        assert!(query.return_attention);
        assert_eq!(query.temperature, 0.5);
        assert_eq!(
            (query.vector, query.k, query.mode),
            (Some(vector), 15, QueryMode::DifferentiableSearch)
        );
    }

    #[test]
    fn test_query_builder_pattern() {
        let query = RuvectorQuery::vector_search(vec![0.1, 0.2], 5)
            .with_text("hello world".to_string())
            .with_node(42)
            .with_ef(100)
            .with_attention();

        assert_eq!(query.text.as_deref(), Some("hello world"));
        assert_eq!((query.node_id, query.ef), (Some(42), 100));
        assert!(query.return_attention);
    }

    #[test]
    fn test_subgraph_new() {
        let graph = SubGraph::new();
        assert_eq!((graph.node_count(), graph.edge_count()), (0, 0));
    }

    #[test]
    fn test_subgraph_with_edges() {
        let nodes = vec![1, 2, 3];
        let edges = vec![(1, 2, 0.8), (2, 3, 0.6), (1, 3, 0.5)];
        let graph = SubGraph::with_edges(nodes.clone(), edges.clone());

        assert_eq!(graph.nodes, nodes);
        assert_eq!(graph.edges, edges);
        assert_eq!((graph.node_count(), graph.edge_count()), (3, 3));
    }

    #[test]
    fn test_subgraph_contains_node() {
        let graph = SubGraph::with_edges(vec![1, 2, 3], vec![]);
        assert!((1..=3).all(|id| graph.contains_node(id)));
        assert!(!graph.contains_node(4));
    }

    #[test]
    fn test_subgraph_average_edge_weight() {
        let graph =
            SubGraph::with_edges(vec![1, 2, 3], vec![(1, 2, 0.8), (2, 3, 0.6), (1, 3, 0.4)]);
        assert!((graph.average_edge_weight() - 0.6).abs() < 0.001);
    }

    #[test]
    fn test_subgraph_empty_average() {
        assert_eq!(SubGraph::new().average_edge_weight(), 0.0);
    }

    #[test]
    fn test_query_result_new() {
        let result = QueryResult::new();
        assert!(result.is_empty());
        assert_eq!(result.len(), 0);
        assert_eq!(result.latency_ms, 0);
    }

    #[test]
    fn test_query_result_with_nodes() {
        let (nodes, scores) = (vec![1, 2, 3], vec![0.9, 0.8, 0.7]);
        let result = QueryResult::with_nodes(nodes.clone(), scores.clone());

        assert_eq!(result.nodes, nodes);
        assert_eq!(result.scores, scores);
        assert_eq!(result.len(), 3);
        assert!(!result.is_empty());
    }

    #[test]
    fn test_query_result_builder_pattern() {
        let embeddings = vec![vec![0.1, 0.2], vec![0.3, 0.4]];
        let attention = vec![vec![0.5, 0.5], vec![0.6, 0.4]];
        let subgraph = SubGraph::with_edges(vec![1, 2], vec![(1, 2, 0.8)]);

        let result = QueryResult::with_nodes(vec![1, 2], vec![0.9, 0.8])
            .with_embeddings(embeddings.clone())
            .with_attention(attention.clone())
            .with_subgraph(subgraph.clone())
            .with_latency(100);

        assert_eq!(result.embeddings, Some(embeddings));
        assert_eq!(result.attention_weights, Some(attention));
        assert_eq!(result.subgraph, Some(subgraph));
        assert_eq!(result.latency_ms, 100);
    }

    #[test]
    fn test_query_result_top_k() {
        let top_3 = five_descending().top_k(3);
        assert_eq!(top_3.len(), 3);
        assert_eq!(top_3.nodes, vec![1, 2, 3]);
        assert_eq!(top_3.scores, vec![0.9, 0.8, 0.7]);
    }

    #[test]
    fn test_query_result_top_k_overflow() {
        // Asking for more than is available yields only what exists.
        let result = QueryResult::with_nodes(vec![1, 2], vec![0.9, 0.8]);
        assert_eq!(result.top_k(10).len(), 2);
    }

    #[test]
    fn test_query_result_best() {
        let result = QueryResult::with_nodes(vec![1, 2, 3], vec![0.9, 0.8, 0.7]);
        assert_eq!(result.best(), Some((1, 0.9)));
    }

    #[test]
    fn test_query_result_best_empty() {
        assert_eq!(QueryResult::new().best(), None);
    }

    #[test]
    fn test_query_result_filter_by_score() {
        let filtered = five_descending().filter_by_score(0.7);
        assert_eq!(filtered.len(), 3);
        assert_eq!(filtered.nodes, vec![1, 2, 3]);
        assert_eq!(filtered.scores, vec![0.9, 0.8, 0.7]);
    }

    #[test]
    fn test_query_result_filter_with_embeddings() {
        let filtered = QueryResult::with_nodes(vec![1, 2, 3], vec![0.9, 0.6, 0.8])
            .with_embeddings(vec![vec![0.1], vec![0.2], vec![0.3]])
            .filter_by_score(0.7);

        assert_eq!(filtered.len(), 2);
        assert_eq!(filtered.nodes, vec![1, 3]);
        assert_eq!(filtered.embeddings, Some(vec![vec![0.1], vec![0.3]]));
    }

    #[test]
    fn test_query_result_filter_with_attention() {
        let filtered = QueryResult::with_nodes(vec![1, 2, 3], vec![0.9, 0.5, 0.8])
            .with_attention(vec![vec![0.5, 0.5], vec![0.6, 0.4], vec![0.7, 0.3]])
            .filter_by_score(0.75);

        assert_eq!(filtered.len(), 2);
        assert_eq!(filtered.nodes, vec![1, 3]);
        assert_eq!(
            filtered.attention_weights,
            Some(vec![vec![0.5, 0.5], vec![0.7, 0.3]])
        );
    }

    #[test]
    fn test_query_serialization() {
        let query = RuvectorQuery::neural_search(vec![0.1, 0.2], 5, 2);
        let restored = json_round_trip(&query);
        assert_eq!(
            (restored.k, restored.gnn_depth, restored.mode),
            (query.k, query.gnn_depth, query.mode)
        );
    }

    #[test]
    fn test_result_serialization() {
        let result = QueryResult::with_nodes(vec![1, 2], vec![0.9, 0.8]).with_latency(50);
        let restored = json_round_trip(&result);

        assert_eq!(restored.nodes, result.nodes);
        assert_eq!(restored.scores, result.scores);
        assert_eq!(restored.latency_ms, result.latency_ms);
    }

    #[test]
    fn test_subgraph_serialization() {
        let subgraph = SubGraph::with_edges(vec![1, 2, 3], vec![(1, 2, 0.8), (2, 3, 0.6)]);
        let restored = json_round_trip(&subgraph);

        assert_eq!(restored.nodes, subgraph.nodes);
        assert_eq!(restored.edges, subgraph.edges);
    }

    #[test]
    fn test_edge_case_empty_filter() {
        let filtered = QueryResult::with_nodes(vec![1, 2], vec![0.5, 0.4]).filter_by_score(0.9);
        assert!(filtered.is_empty());
        assert_eq!(filtered.len(), 0);
    }

    #[test]
    fn test_query_mode_variants() {
        assert_eq!(QueryMode::VectorSearch, QueryMode::VectorSearch);
        let distinct_pairs = [
            (QueryMode::VectorSearch, QueryMode::NeuralSearch),
            (QueryMode::NeuralSearch, QueryMode::SubgraphExtraction),
            (QueryMode::SubgraphExtraction, QueryMode::DifferentiableSearch),
        ];
        for &(left, right) in &distinct_pairs {
            assert_ne!(left, right);
        }
    }
}