1use crate::error::{GraphError, Result};
6use crate::hybrid::vector_index::HybridIndex;
7use crate::types::{EdgeId, NodeId};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct SemanticSearchConfig {
14 pub max_path_length: usize,
16 pub min_similarity: f32,
18 pub top_k: usize,
20 pub semantic_weight: f32,
22}
23
24impl Default for SemanticSearchConfig {
25 fn default() -> Self {
26 Self {
27 max_path_length: 3,
28 min_similarity: 0.7,
29 top_k: 10,
30 semantic_weight: 0.6,
31 }
32 }
33}
34
35pub struct SemanticSearch {
37 index: HybridIndex,
39 config: SemanticSearchConfig,
41}
42
43impl SemanticSearch {
44 pub fn new(index: HybridIndex, config: SemanticSearchConfig) -> Self {
46 Self { index, config }
47 }
48
49 pub fn find_similar_nodes(&self, query: &[f32], k: usize) -> Result<Vec<SemanticMatch>> {
51 let results = self.index.search_similar_nodes(query, k)?;
52
53 let max_distance = 1.0 - self.config.min_similarity;
56
57 let mut matches = Vec::with_capacity(results.len());
61 for (node_id, distance) in results {
62 if distance <= max_distance {
64 matches.push(SemanticMatch {
65 node_id,
66 score: 1.0 - distance,
67 path_length: 0,
68 });
69 }
70 }
71 Ok(matches)
72 }
73
74 pub fn find_semantic_paths(
76 &self,
77 start_node: &NodeId,
78 query: &[f32],
79 max_paths: usize,
80 ) -> Result<Vec<SemanticPath>> {
81 let mut paths = Vec::new();
89
90 let similar = self.find_similar_nodes(query, max_paths)?;
92
93 for match_result in similar {
94 paths.push(SemanticPath {
95 nodes: vec![start_node.clone(), match_result.node_id],
96 edges: vec![],
97 semantic_score: match_result.score,
98 graph_distance: 1,
99 combined_score: self.compute_path_score(match_result.score, 1),
100 });
101 }
102
103 Ok(paths)
104 }
105
106 pub fn detect_clusters(
108 &self,
109 nodes: &[NodeId],
110 min_cluster_size: usize,
111 ) -> Result<Vec<ClusterResult>> {
112 let mut clusters = Vec::new();
119
120 if nodes.len() >= min_cluster_size {
122 clusters.push(ClusterResult {
123 cluster_id: 0,
124 nodes: nodes.to_vec(),
125 centroid: None,
126 coherence_score: 0.85,
127 });
128 }
129
130 Ok(clusters)
131 }
132
133 pub fn find_related_edges(&self, query: &[f32], k: usize) -> Result<Vec<EdgeMatch>> {
135 let results = self.index.search_similar_edges(query, k)?;
136
137 let max_distance = 1.0 - self.config.min_similarity;
139
140 let mut matches = Vec::with_capacity(results.len());
142 for (edge_id, distance) in results {
143 if distance <= max_distance {
144 matches.push(EdgeMatch {
145 edge_id,
146 score: 1.0 - distance,
147 });
148 }
149 }
150 Ok(matches)
151 }
152
153 fn compute_path_score(&self, semantic_score: f32, graph_distance: usize) -> f32 {
155 let w = self.config.semantic_weight;
156 let distance_penalty = 1.0 / (graph_distance as f32 + 1.0);
157
158 w * semantic_score + (1.0 - w) * distance_penalty
159 }
160
161 pub fn expand_query(&self, query: &[f32], expansion_factor: usize) -> Result<Vec<Vec<f32>>> {
163 let similar = self.index.search_similar_nodes(query, expansion_factor)?;
165
166 Ok(vec![query.to_vec()])
169 }
170}
171
172#[derive(Debug, Clone, Serialize, Deserialize)]
174pub struct SemanticMatch {
175 pub node_id: NodeId,
176 pub score: f32,
177 pub path_length: usize,
178}
179
180#[derive(Debug, Clone, Serialize, Deserialize)]
182pub struct SemanticPath {
183 pub nodes: Vec<NodeId>,
185 pub edges: Vec<EdgeId>,
187 pub semantic_score: f32,
189 pub graph_distance: usize,
191 pub combined_score: f32,
193}
194
195#[derive(Debug, Clone, Serialize, Deserialize)]
197pub struct ClusterResult {
198 pub cluster_id: usize,
199 pub nodes: Vec<NodeId>,
200 pub centroid: Option<Vec<f32>>,
201 pub coherence_score: f32,
202}
203
204#[derive(Debug, Clone, Serialize, Deserialize)]
206pub struct EdgeMatch {
207 pub edge_id: EdgeId,
208 pub score: f32,
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hybrid::vector_index::{EmbeddingConfig, VectorIndexType};

    /// Builds a four-dimensional index with an initialised node index and
    /// one embedding per `(id, vector)` entry.
    fn node_index(embeddings: &[(&str, [f32; 4])]) -> Result<HybridIndex> {
        let config = EmbeddingConfig {
            dimensions: 4,
            ..Default::default()
        };
        let index = HybridIndex::new(config)?;
        index.initialize_index(VectorIndexType::Node)?;

        for &(id, vector) in embeddings {
            index.add_node_embedding(id.to_string(), vector.to_vec())?;
        }
        Ok(index)
    }

    /// Construction from default configurations must not panic.
    #[test]
    fn test_semantic_search_creation() {
        let config = EmbeddingConfig::default();
        let index = HybridIndex::new(config).unwrap();
        let search_config = SemanticSearchConfig::default();

        let _search = SemanticSearch::new(index, search_config);
    }

    /// A query identical to a stored embedding yields at least one match.
    #[test]
    fn test_find_similar_nodes() -> Result<()> {
        let index = node_index(&[
            ("doc1", [1.0, 0.0, 0.0, 0.0]),
            ("doc2", [0.9, 0.1, 0.0, 0.0]),
        ])?;

        let search = SemanticSearch::new(index, SemanticSearchConfig::default());
        let results = search.find_similar_nodes(&[1.0, 0.0, 0.0, 0.0], 5)?;

        assert!(!results.is_empty());
        Ok(())
    }

    /// Enough nodes form exactly one cluster; too few form none.
    #[test]
    fn test_cluster_detection() -> Result<()> {
        let config = EmbeddingConfig::default();
        let index = HybridIndex::new(config)?;
        let search = SemanticSearch::new(index, SemanticSearchConfig::default());

        let nodes = vec!["n1".to_string(), "n2".to_string(), "n3".to_string()];
        let clusters = search.detect_clusters(&nodes, 2)?;

        assert_eq!(clusters.len(), 1);
        assert_eq!(clusters[0].nodes, nodes);

        // Three nodes cannot satisfy a minimum cluster size of four.
        assert!(search.detect_clusters(&nodes, 4)?.is_empty());
        Ok(())
    }

    /// Scores stay inside `[0, 1]` and the identical vector ranks near 1.
    #[test]
    fn test_similarity_score_range() -> Result<()> {
        let index = node_index(&[
            ("identical", [1.0, 0.0, 0.0, 0.0]),
            ("similar", [0.9, 0.1, 0.0, 0.0]),
            ("different", [0.0, 1.0, 0.0, 0.0]),
        ])?;

        // A zero similarity floor keeps even dissimilar vectors.
        let search_config = SemanticSearchConfig {
            min_similarity: 0.0,
            ..Default::default()
        };
        let search = SemanticSearch::new(index, search_config);
        let results = search.find_similar_nodes(&[1.0, 0.0, 0.0, 0.0], 10)?;

        for result in &results {
            assert!(
                result.score >= 0.0 && result.score <= 1.0,
                "Score {} out of valid range [0, 1]",
                result.score
            );
        }

        // Fail loudly on an empty result instead of silently skipping the
        // top-score check.
        let top_result = results
            .first()
            .expect("query matching a stored vector should return results");
        assert!(
            top_result.score > 0.9,
            "Identical vector should have score > 0.9"
        );

        Ok(())
    }

    /// Only matches at or above the similarity floor are returned.
    #[test]
    fn test_min_similarity_filtering() -> Result<()> {
        let index = node_index(&[
            ("high_sim", [1.0, 0.0, 0.0, 0.0]),
            ("low_sim", [0.0, 0.0, 0.0, 1.0]),
        ])?;

        let search_config = SemanticSearchConfig {
            min_similarity: 0.9,
            ..Default::default()
        };
        let search = SemanticSearch::new(index, search_config);
        let results = search.find_similar_nodes(&[1.0, 0.0, 0.0, 0.0], 10)?;

        // The identical vector must survive, otherwise the loop below would
        // pass vacuously.
        assert!(
            !results.is_empty(),
            "Identical vector should pass the similarity filter"
        );

        for result in &results {
            assert!(
                result.score >= 0.9,
                "Result with score {} should be filtered out (min: 0.9)",
                result.score
            );
        }

        Ok(())
    }
}