1use crate::error::{GraphError, Result};
6use crate::types::{EdgeId, NodeId, Properties, PropertyValue};
7use dashmap::DashMap;
8use parking_lot::RwLock;
9use ruvector_core::index::flat::FlatIndex;
10#[cfg(feature = "hnsw_rs")]
11use ruvector_core::index::hnsw::HnswIndex;
12use ruvector_core::index::VectorIndex;
13#[cfg(feature = "hnsw_rs")]
14use ruvector_core::types::HnswConfig;
15use ruvector_core::types::{DistanceMetric, SearchResult};
16use serde::{Deserialize, Serialize};
17use std::sync::Arc;
18
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
21pub enum VectorIndexType {
22 Node,
24 Edge,
26 Hyperedge,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct EmbeddingConfig {
33 pub dimensions: usize,
35 pub metric: DistanceMetric,
37 #[cfg(feature = "hnsw_rs")]
39 pub hnsw_config: HnswConfig,
40 pub embedding_property: String,
42}
43
44impl Default for EmbeddingConfig {
45 fn default() -> Self {
46 Self {
47 dimensions: 384, metric: DistanceMetric::Cosine,
49 #[cfg(feature = "hnsw_rs")]
50 hnsw_config: HnswConfig::default(),
51 embedding_property: "embedding".to_string(),
52 }
53 }
54}
55
56#[cfg(feature = "hnsw_rs")]
58type IndexImpl = HnswIndex;
59#[cfg(not(feature = "hnsw_rs"))]
60type IndexImpl = FlatIndex;
61
62pub struct HybridIndex {
64 node_index: Arc<RwLock<Option<IndexImpl>>>,
66 edge_index: Arc<RwLock<Option<IndexImpl>>>,
68 hyperedge_index: Arc<RwLock<Option<IndexImpl>>>,
70
71 node_id_map: Arc<DashMap<NodeId, String>>,
73 edge_id_map: Arc<DashMap<EdgeId, String>>,
75 hyperedge_id_map: Arc<DashMap<String, String>>,
77
78 config: EmbeddingConfig,
80}
81
82impl HybridIndex {
83 pub fn new(config: EmbeddingConfig) -> Result<Self> {
85 Ok(Self {
86 node_index: Arc::new(RwLock::new(None)),
87 edge_index: Arc::new(RwLock::new(None)),
88 hyperedge_index: Arc::new(RwLock::new(None)),
89 node_id_map: Arc::new(DashMap::new()),
90 edge_id_map: Arc::new(DashMap::new()),
91 hyperedge_id_map: Arc::new(DashMap::new()),
92 config,
93 })
94 }
95
96 #[cfg(feature = "hnsw_rs")]
98 pub fn initialize_index(&self, index_type: VectorIndexType) -> Result<()> {
99 let index = HnswIndex::new(
100 self.config.dimensions,
101 self.config.metric,
102 self.config.hnsw_config.clone(),
103 )
104 .map_err(|e| GraphError::IndexError(format!("Failed to create HNSW index: {}", e)))?;
105
106 match index_type {
107 VectorIndexType::Node => {
108 *self.node_index.write() = Some(index);
109 }
110 VectorIndexType::Edge => {
111 *self.edge_index.write() = Some(index);
112 }
113 VectorIndexType::Hyperedge => {
114 *self.hyperedge_index.write() = Some(index);
115 }
116 }
117
118 Ok(())
119 }
120
121 #[cfg(not(feature = "hnsw_rs"))]
123 pub fn initialize_index(&self, index_type: VectorIndexType) -> Result<()> {
124 let index = FlatIndex::new(self.config.dimensions, self.config.metric);
125
126 match index_type {
127 VectorIndexType::Node => {
128 *self.node_index.write() = Some(index);
129 }
130 VectorIndexType::Edge => {
131 *self.edge_index.write() = Some(index);
132 }
133 VectorIndexType::Hyperedge => {
134 *self.hyperedge_index.write() = Some(index);
135 }
136 }
137
138 Ok(())
139 }
140
141 pub fn add_node_embedding(&self, node_id: NodeId, embedding: Vec<f32>) -> Result<()> {
143 if embedding.len() != self.config.dimensions {
144 return Err(GraphError::InvalidEmbedding(format!(
145 "Expected {} dimensions, got {}",
146 self.config.dimensions,
147 embedding.len()
148 )));
149 }
150
151 let mut index_guard = self.node_index.write();
152 let index = index_guard
153 .as_mut()
154 .ok_or_else(|| GraphError::IndexError("Node index not initialized".to_string()))?;
155
156 let vector_id = format!("node_{}", node_id);
157 index
158 .add(vector_id.clone(), embedding)
159 .map_err(|e| GraphError::IndexError(format!("Failed to add node embedding: {}", e)))?;
160
161 self.node_id_map.insert(node_id, vector_id);
162 Ok(())
163 }
164
165 pub fn add_edge_embedding(&self, edge_id: EdgeId, embedding: Vec<f32>) -> Result<()> {
167 if embedding.len() != self.config.dimensions {
168 return Err(GraphError::InvalidEmbedding(format!(
169 "Expected {} dimensions, got {}",
170 self.config.dimensions,
171 embedding.len()
172 )));
173 }
174
175 let mut index_guard = self.edge_index.write();
176 let index = index_guard
177 .as_mut()
178 .ok_or_else(|| GraphError::IndexError("Edge index not initialized".to_string()))?;
179
180 let vector_id = format!("edge_{}", edge_id);
181 index
182 .add(vector_id.clone(), embedding)
183 .map_err(|e| GraphError::IndexError(format!("Failed to add edge embedding: {}", e)))?;
184
185 self.edge_id_map.insert(edge_id, vector_id);
186 Ok(())
187 }
188
189 pub fn add_hyperedge_embedding(&self, hyperedge_id: String, embedding: Vec<f32>) -> Result<()> {
191 if embedding.len() != self.config.dimensions {
192 return Err(GraphError::InvalidEmbedding(format!(
193 "Expected {} dimensions, got {}",
194 self.config.dimensions,
195 embedding.len()
196 )));
197 }
198
199 let mut index_guard = self.hyperedge_index.write();
200 let index = index_guard
201 .as_mut()
202 .ok_or_else(|| GraphError::IndexError("Hyperedge index not initialized".to_string()))?;
203
204 let vector_id = format!("hyperedge_{}", hyperedge_id);
205 index.add(vector_id.clone(), embedding).map_err(|e| {
206 GraphError::IndexError(format!("Failed to add hyperedge embedding: {}", e))
207 })?;
208
209 self.hyperedge_id_map.insert(hyperedge_id, vector_id);
210 Ok(())
211 }
212
213 pub fn search_similar_nodes(&self, query: &[f32], k: usize) -> Result<Vec<(NodeId, f32)>> {
215 let index_guard = self.node_index.read();
216 let index = index_guard
217 .as_ref()
218 .ok_or_else(|| GraphError::IndexError("Node index not initialized".to_string()))?;
219
220 let results = index
221 .search(query, k)
222 .map_err(|e| GraphError::IndexError(format!("Search failed: {}", e)))?;
223
224 Ok(results
225 .into_iter()
226 .filter_map(|result| {
227 let node_id = result.id.strip_prefix("node_")?.to_string();
229 Some((node_id, result.score))
230 })
231 .collect())
232 }
233
234 pub fn search_similar_edges(&self, query: &[f32], k: usize) -> Result<Vec<(EdgeId, f32)>> {
236 let index_guard = self.edge_index.read();
237 let index = index_guard
238 .as_ref()
239 .ok_or_else(|| GraphError::IndexError("Edge index not initialized".to_string()))?;
240
241 let results = index
242 .search(query, k)
243 .map_err(|e| GraphError::IndexError(format!("Search failed: {}", e)))?;
244
245 Ok(results
246 .into_iter()
247 .filter_map(|result| {
248 let edge_id = result.id.strip_prefix("edge_")?.to_string();
249 Some((edge_id, result.score))
250 })
251 .collect())
252 }
253
254 pub fn search_similar_hyperedges(&self, query: &[f32], k: usize) -> Result<Vec<(String, f32)>> {
256 let index_guard = self.hyperedge_index.read();
257 let index = index_guard
258 .as_ref()
259 .ok_or_else(|| GraphError::IndexError("Hyperedge index not initialized".to_string()))?;
260
261 let results = index
262 .search(query, k)
263 .map_err(|e| GraphError::IndexError(format!("Search failed: {}", e)))?;
264
265 Ok(results
266 .into_iter()
267 .filter_map(|result| {
268 let hyperedge_id = result.id.strip_prefix("hyperedge_")?.to_string();
269 Some((hyperedge_id, result.score))
270 })
271 .collect())
272 }
273
274 pub fn extract_embedding(&self, properties: &Properties) -> Result<Option<Vec<f32>>> {
276 let prop_value = match properties.get(&self.config.embedding_property) {
277 Some(v) => v,
278 None => return Ok(None),
279 };
280
281 match prop_value {
282 PropertyValue::Array(arr) => {
283 let embedding: Result<Vec<f32>> = arr
284 .iter()
285 .map(|v| match v {
286 PropertyValue::Float(f) => Ok(*f as f32),
287 PropertyValue::Integer(i) => Ok(*i as f32),
288 _ => Err(GraphError::InvalidEmbedding(
289 "Embedding array must contain numbers".to_string(),
290 )),
291 })
292 .collect();
293 embedding.map(Some)
294 }
295 _ => Err(GraphError::InvalidEmbedding(
296 "Embedding property must be an array".to_string(),
297 )),
298 }
299 }
300
301 pub fn stats(&self) -> HybridIndexStats {
303 let node_count = self.node_id_map.len();
304 let edge_count = self.edge_id_map.len();
305 let hyperedge_count = self.hyperedge_id_map.len();
306
307 HybridIndexStats {
308 node_count,
309 edge_count,
310 hyperedge_count,
311 total_embeddings: node_count + edge_count + hyperedge_count,
312 }
313 }
314}
315
316#[derive(Debug, Clone, Serialize, Deserialize)]
318pub struct HybridIndexStats {
319 pub node_count: usize,
320 pub edge_count: usize,
321 pub hyperedge_count: usize,
322 pub total_embeddings: usize,
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an index whose embeddings have 4 dimensions and the default metric.
    fn small_index() -> Result<HybridIndex> {
        HybridIndex::new(EmbeddingConfig {
            dimensions: 4,
            ..Default::default()
        })
    }

    #[test]
    fn test_hybrid_index_creation() -> Result<()> {
        let index = HybridIndex::new(EmbeddingConfig::default())?;

        let stats = index.stats();
        assert_eq!(stats.total_embeddings, 0);

        Ok(())
    }

    #[test]
    fn test_node_embedding_indexing() -> Result<()> {
        let index = small_index()?;
        index.initialize_index(VectorIndexType::Node)?;

        index.add_node_embedding("node1".to_string(), vec![1.0, 2.0, 3.0, 4.0])?;

        assert_eq!(index.stats().node_count, 1);
        Ok(())
    }

    #[test]
    fn test_similarity_search() -> Result<()> {
        let index = small_index()?;
        index.initialize_index(VectorIndexType::Node)?;

        index.add_node_embedding("node1".to_string(), vec![1.0, 0.0, 0.0, 0.0])?;
        index.add_node_embedding("node2".to_string(), vec![0.9, 0.1, 0.0, 0.0])?;
        index.add_node_embedding("node3".to_string(), vec![0.0, 1.0, 0.0, 0.0])?;

        let results = index.search_similar_nodes(&[1.0, 0.0, 0.0, 0.0], 2)?;

        // Three vectors are indexed, so a top-2 query must return exactly two hits, with
        // the identical vector ranked first. An empty result is a failure, not a pass.
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, "node1");

        Ok(())
    }

    #[test]
    fn test_wrong_dimensions_rejected() -> Result<()> {
        let index = small_index()?;
        index.initialize_index(VectorIndexType::Node)?;

        let err = index
            .add_node_embedding("node1".to_string(), vec![1.0, 2.0])
            .unwrap_err();
        assert!(matches!(err, GraphError::InvalidEmbedding(_)));
        assert_eq!(index.stats().node_count, 0);

        Ok(())
    }

    #[test]
    fn test_uninitialized_index_errors() -> Result<()> {
        let index = small_index()?;

        let add_err = index
            .add_edge_embedding("edge1".to_string(), vec![1.0, 0.0, 0.0, 0.0])
            .unwrap_err();
        assert!(matches!(add_err, GraphError::IndexError(_)));

        let search_err = index
            .search_similar_hyperedges(&[1.0, 0.0, 0.0, 0.0], 1)
            .unwrap_err();
        assert!(matches!(search_err, GraphError::IndexError(_)));

        Ok(())
    }

    #[test]
    fn test_edge_and_hyperedge_search() -> Result<()> {
        let index = small_index()?;
        index.initialize_index(VectorIndexType::Edge)?;
        index.initialize_index(VectorIndexType::Hyperedge)?;

        index.add_edge_embedding("edge1".to_string(), vec![0.0, 0.0, 1.0, 0.0])?;
        index.add_hyperedge_embedding("he1".to_string(), vec![0.0, 0.0, 0.0, 1.0])?;

        let edges = index.search_similar_edges(&[0.0, 0.0, 1.0, 0.0], 1)?;
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].0, "edge1");

        let hyperedges = index.search_similar_hyperedges(&[0.0, 0.0, 0.0, 1.0], 1)?;
        assert_eq!(hyperedges.len(), 1);
        assert_eq!(hyperedges[0].0, "he1");

        let stats = index.stats();
        assert_eq!(stats.edge_count, 1);
        assert_eq!(stats.hyperedge_count, 1);
        assert_eq!(stats.total_embeddings, 2);

        Ok(())
    }
}