ruvector_graph/hybrid/
vector_index.rs

//! Vector indexing for graph elements.
//!
//! Integrates RuVector's index (HNSW or Flat) with graph nodes, edges, and hyperedges.
4
5use crate::error::{GraphError, Result};
6use crate::types::{EdgeId, NodeId, Properties, PropertyValue};
7use dashmap::DashMap;
8use parking_lot::RwLock;
9use ruvector_core::index::flat::FlatIndex;
10#[cfg(feature = "hnsw_rs")]
11use ruvector_core::index::hnsw::HnswIndex;
12use ruvector_core::index::VectorIndex;
13#[cfg(feature = "hnsw_rs")]
14use ruvector_core::types::HnswConfig;
15use ruvector_core::types::{DistanceMetric, SearchResult};
16use serde::{Deserialize, Serialize};
17use std::sync::Arc;
18
19/// Type of graph element that can be indexed
20#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
21pub enum VectorIndexType {
22    /// Node embeddings
23    Node,
24    /// Edge embeddings
25    Edge,
26    /// Hyperedge embeddings
27    Hyperedge,
28}
29
30/// Configuration for embedding storage
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct EmbeddingConfig {
33    /// Dimension of embeddings
34    pub dimensions: usize,
35    /// Distance metric for similarity
36    pub metric: DistanceMetric,
37    /// HNSW index configuration (only used when hnsw_rs feature is enabled)
38    #[cfg(feature = "hnsw_rs")]
39    pub hnsw_config: HnswConfig,
40    /// Property name where embeddings are stored
41    pub embedding_property: String,
42}
43
44impl Default for EmbeddingConfig {
45    fn default() -> Self {
46        Self {
47            dimensions: 384, // Common for small models like MiniLM
48            metric: DistanceMetric::Cosine,
49            #[cfg(feature = "hnsw_rs")]
50            hnsw_config: HnswConfig::default(),
51            embedding_property: "embedding".to_string(),
52        }
53    }
54}
55
56// Index type alias based on feature flags
57#[cfg(feature = "hnsw_rs")]
58type IndexImpl = HnswIndex;
59#[cfg(not(feature = "hnsw_rs"))]
60type IndexImpl = FlatIndex;
61
62/// Hybrid index combining graph structure with vector search
63pub struct HybridIndex {
64    /// Node embeddings index
65    node_index: Arc<RwLock<Option<IndexImpl>>>,
66    /// Edge embeddings index
67    edge_index: Arc<RwLock<Option<IndexImpl>>>,
68    /// Hyperedge embeddings index
69    hyperedge_index: Arc<RwLock<Option<IndexImpl>>>,
70
71    /// Mapping from node IDs to internal vector IDs
72    node_id_map: Arc<DashMap<NodeId, String>>,
73    /// Mapping from edge IDs to internal vector IDs
74    edge_id_map: Arc<DashMap<EdgeId, String>>,
75    /// Mapping from hyperedge IDs to internal vector IDs
76    hyperedge_id_map: Arc<DashMap<String, String>>,
77
78    /// Configuration
79    config: EmbeddingConfig,
80}
81
82impl HybridIndex {
83    /// Create a new hybrid index
84    pub fn new(config: EmbeddingConfig) -> Result<Self> {
85        Ok(Self {
86            node_index: Arc::new(RwLock::new(None)),
87            edge_index: Arc::new(RwLock::new(None)),
88            hyperedge_index: Arc::new(RwLock::new(None)),
89            node_id_map: Arc::new(DashMap::new()),
90            edge_id_map: Arc::new(DashMap::new()),
91            hyperedge_id_map: Arc::new(DashMap::new()),
92            config,
93        })
94    }
95
96    /// Initialize index for a specific element type
97    #[cfg(feature = "hnsw_rs")]
98    pub fn initialize_index(&self, index_type: VectorIndexType) -> Result<()> {
99        let index = HnswIndex::new(
100            self.config.dimensions,
101            self.config.metric,
102            self.config.hnsw_config.clone(),
103        )
104        .map_err(|e| GraphError::IndexError(format!("Failed to create HNSW index: {}", e)))?;
105
106        match index_type {
107            VectorIndexType::Node => {
108                *self.node_index.write() = Some(index);
109            }
110            VectorIndexType::Edge => {
111                *self.edge_index.write() = Some(index);
112            }
113            VectorIndexType::Hyperedge => {
114                *self.hyperedge_index.write() = Some(index);
115            }
116        }
117
118        Ok(())
119    }
120
121    /// Initialize index for a specific element type (Flat index for WASM)
122    #[cfg(not(feature = "hnsw_rs"))]
123    pub fn initialize_index(&self, index_type: VectorIndexType) -> Result<()> {
124        let index = FlatIndex::new(self.config.dimensions, self.config.metric);
125
126        match index_type {
127            VectorIndexType::Node => {
128                *self.node_index.write() = Some(index);
129            }
130            VectorIndexType::Edge => {
131                *self.edge_index.write() = Some(index);
132            }
133            VectorIndexType::Hyperedge => {
134                *self.hyperedge_index.write() = Some(index);
135            }
136        }
137
138        Ok(())
139    }
140
141    /// Add node embedding to index
142    pub fn add_node_embedding(&self, node_id: NodeId, embedding: Vec<f32>) -> Result<()> {
143        if embedding.len() != self.config.dimensions {
144            return Err(GraphError::InvalidEmbedding(format!(
145                "Expected {} dimensions, got {}",
146                self.config.dimensions,
147                embedding.len()
148            )));
149        }
150
151        let mut index_guard = self.node_index.write();
152        let index = index_guard
153            .as_mut()
154            .ok_or_else(|| GraphError::IndexError("Node index not initialized".to_string()))?;
155
156        let vector_id = format!("node_{}", node_id);
157        index
158            .add(vector_id.clone(), embedding)
159            .map_err(|e| GraphError::IndexError(format!("Failed to add node embedding: {}", e)))?;
160
161        self.node_id_map.insert(node_id, vector_id);
162        Ok(())
163    }
164
165    /// Add edge embedding to index
166    pub fn add_edge_embedding(&self, edge_id: EdgeId, embedding: Vec<f32>) -> Result<()> {
167        if embedding.len() != self.config.dimensions {
168            return Err(GraphError::InvalidEmbedding(format!(
169                "Expected {} dimensions, got {}",
170                self.config.dimensions,
171                embedding.len()
172            )));
173        }
174
175        let mut index_guard = self.edge_index.write();
176        let index = index_guard
177            .as_mut()
178            .ok_or_else(|| GraphError::IndexError("Edge index not initialized".to_string()))?;
179
180        let vector_id = format!("edge_{}", edge_id);
181        index
182            .add(vector_id.clone(), embedding)
183            .map_err(|e| GraphError::IndexError(format!("Failed to add edge embedding: {}", e)))?;
184
185        self.edge_id_map.insert(edge_id, vector_id);
186        Ok(())
187    }
188
189    /// Add hyperedge embedding to index
190    pub fn add_hyperedge_embedding(&self, hyperedge_id: String, embedding: Vec<f32>) -> Result<()> {
191        if embedding.len() != self.config.dimensions {
192            return Err(GraphError::InvalidEmbedding(format!(
193                "Expected {} dimensions, got {}",
194                self.config.dimensions,
195                embedding.len()
196            )));
197        }
198
199        let mut index_guard = self.hyperedge_index.write();
200        let index = index_guard
201            .as_mut()
202            .ok_or_else(|| GraphError::IndexError("Hyperedge index not initialized".to_string()))?;
203
204        let vector_id = format!("hyperedge_{}", hyperedge_id);
205        index.add(vector_id.clone(), embedding).map_err(|e| {
206            GraphError::IndexError(format!("Failed to add hyperedge embedding: {}", e))
207        })?;
208
209        self.hyperedge_id_map.insert(hyperedge_id, vector_id);
210        Ok(())
211    }
212
213    /// Search for similar nodes
214    pub fn search_similar_nodes(&self, query: &[f32], k: usize) -> Result<Vec<(NodeId, f32)>> {
215        let index_guard = self.node_index.read();
216        let index = index_guard
217            .as_ref()
218            .ok_or_else(|| GraphError::IndexError("Node index not initialized".to_string()))?;
219
220        let results = index
221            .search(query, k)
222            .map_err(|e| GraphError::IndexError(format!("Search failed: {}", e)))?;
223
224        Ok(results
225            .into_iter()
226            .filter_map(|result| {
227                // Remove "node_" prefix to get original ID
228                let node_id = result.id.strip_prefix("node_")?.to_string();
229                Some((node_id, result.score))
230            })
231            .collect())
232    }
233
234    /// Search for similar edges
235    pub fn search_similar_edges(&self, query: &[f32], k: usize) -> Result<Vec<(EdgeId, f32)>> {
236        let index_guard = self.edge_index.read();
237        let index = index_guard
238            .as_ref()
239            .ok_or_else(|| GraphError::IndexError("Edge index not initialized".to_string()))?;
240
241        let results = index
242            .search(query, k)
243            .map_err(|e| GraphError::IndexError(format!("Search failed: {}", e)))?;
244
245        Ok(results
246            .into_iter()
247            .filter_map(|result| {
248                let edge_id = result.id.strip_prefix("edge_")?.to_string();
249                Some((edge_id, result.score))
250            })
251            .collect())
252    }
253
254    /// Search for similar hyperedges
255    pub fn search_similar_hyperedges(&self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>> {
256        let index_guard = self.hyperedge_index.read();
257        let index = index_guard
258            .as_ref()
259            .ok_or_else(|| GraphError::IndexError("Hyperedge index not initialized".to_string()))?;
260
261        let results = index
262            .search(query, k)
263            .map_err(|e| GraphError::IndexError(format!("Search failed: {}", e)))?;
264
265        Ok(results
266            .into_iter()
267            .filter_map(|result| {
268                let hyperedge_id = result.id.strip_prefix("hyperedge_")?.to_string();
269                Some((hyperedge_id, result.score))
270            })
271            .collect())
272    }
273
274    /// Extract embedding from properties
275    pub fn extract_embedding(&self, properties: &Properties) -> Result<Option<Vec<f32>>> {
276        let prop_value = match properties.get(&self.config.embedding_property) {
277            Some(v) => v,
278            None => return Ok(None),
279        };
280
281        match prop_value {
282            PropertyValue::Array(arr) => {
283                let embedding: Result<Vec<f32>> = arr
284                    .iter()
285                    .map(|v| match v {
286                        PropertyValue::Float(f) => Ok(*f as f32),
287                        PropertyValue::Integer(i) => Ok(*i as f32),
288                        _ => Err(GraphError::InvalidEmbedding(
289                            "Embedding array must contain numbers".to_string(),
290                        )),
291                    })
292                    .collect();
293                embedding.map(Some)
294            }
295            _ => Err(GraphError::InvalidEmbedding(
296                "Embedding property must be an array".to_string(),
297            )),
298        }
299    }
300
301    /// Get index statistics
302    pub fn stats(&self) -> HybridIndexStats {
303        let node_count = self.node_id_map.len();
304        let edge_count = self.edge_id_map.len();
305        let hyperedge_count = self.hyperedge_id_map.len();
306
307        HybridIndexStats {
308            node_count,
309            edge_count,
310            hyperedge_count,
311            total_embeddings: node_count + edge_count + hyperedge_count,
312        }
313    }
314}
315
316/// Statistics about the hybrid index
317#[derive(Debug, Clone, Serialize, Deserialize)]
318pub struct HybridIndexStats {
319    pub node_count: usize,
320    pub edge_count: usize,
321    pub hyperedge_count: usize,
322    pub total_embeddings: usize,
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a hybrid index over 4-dimensional embeddings with the node index
    /// already initialized.
    fn node_index_4d() -> Result<HybridIndex> {
        let index = HybridIndex::new(EmbeddingConfig {
            dimensions: 4,
            ..Default::default()
        })?;
        index.initialize_index(VectorIndexType::Node)?;
        Ok(index)
    }

    /// A freshly created index tracks no embeddings.
    #[test]
    fn test_hybrid_index_creation() -> Result<()> {
        let index = HybridIndex::new(EmbeddingConfig::default())?;

        assert_eq!(index.stats().total_embeddings, 0);

        Ok(())
    }

    /// Adding one node embedding is reflected in the node count.
    #[test]
    fn test_node_embedding_indexing() -> Result<()> {
        let index = node_index_4d()?;

        index.add_node_embedding("node1".to_string(), vec![1.0, 2.0, 3.0, 4.0])?;

        assert_eq!(index.stats().node_count, 1);

        Ok(())
    }

    /// Searching with node1's own vector returns at most `k` hits, led by node1.
    #[test]
    fn test_similarity_search() -> Result<()> {
        let index = node_index_4d()?;

        // Two near-identical vectors and one orthogonal vector.
        index.add_node_embedding("node1".to_string(), vec![1.0, 0.0, 0.0, 0.0])?;
        index.add_node_embedding("node2".to_string(), vec![0.9, 0.1, 0.0, 0.0])?;
        index.add_node_embedding("node3".to_string(), vec![0.0, 1.0, 0.0, 0.0])?;

        let results = index.search_similar_nodes(&[1.0, 0.0, 0.0, 0.0], 2)?;

        assert!(results.len() <= 2);
        // An empty result set is tolerated; only a non-empty one is checked.
        if let Some(closest) = results.first() {
            assert_eq!(closest.0, "node1");
        }

        Ok(())
    }
}