1use crate::error::{Result, RuvectorError};
7use crate::types::{DistanceMetric, VectorId};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10use std::time::{SystemTime, UNIX_EPOCH};
11use uuid::Uuid;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct Hyperedge {
16 pub id: String,
18 pub nodes: Vec<VectorId>,
20 pub description: String,
22 pub embedding: Vec<f32>,
24 pub confidence: f32,
26 pub metadata: HashMap<String, String>,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct TemporalHyperedge {
33 pub hyperedge: Hyperedge,
35 pub timestamp: u64,
37 pub expires_at: Option<u64>,
39 pub granularity: TemporalGranularity,
41}
42
43#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
45pub enum TemporalGranularity {
46 Hourly,
47 Daily,
48 Monthly,
49 Yearly,
50}
51
52impl Hyperedge {
53 pub fn new(
55 nodes: Vec<VectorId>,
56 description: String,
57 embedding: Vec<f32>,
58 confidence: f32,
59 ) -> Self {
60 Self {
61 id: Uuid::new_v4().to_string(),
62 nodes,
63 description,
64 embedding,
65 confidence: confidence.clamp(0.0, 1.0),
66 metadata: HashMap::new(),
67 }
68 }
69
70 pub fn order(&self) -> usize {
72 self.nodes.len()
73 }
74
75 pub fn contains_node(&self, node: &VectorId) -> bool {
77 self.nodes.contains(node)
78 }
79}
80
81impl TemporalHyperedge {
82 pub fn new(hyperedge: Hyperedge, granularity: TemporalGranularity) -> Self {
84 let timestamp = SystemTime::now()
85 .duration_since(UNIX_EPOCH)
86 .unwrap()
87 .as_secs();
88
89 Self {
90 hyperedge,
91 timestamp,
92 expires_at: None,
93 granularity,
94 }
95 }
96
97 pub fn is_expired(&self) -> bool {
99 if let Some(expires_at) = self.expires_at {
100 let now = SystemTime::now()
101 .duration_since(UNIX_EPOCH)
102 .unwrap()
103 .as_secs();
104 now > expires_at
105 } else {
106 false
107 }
108 }
109
110 pub fn time_bucket(&self) -> u64 {
112 match self.granularity {
113 TemporalGranularity::Hourly => self.timestamp / 3600,
114 TemporalGranularity::Daily => self.timestamp / 86400,
115 TemporalGranularity::Monthly => self.timestamp / (86400 * 30),
116 TemporalGranularity::Yearly => self.timestamp / (86400 * 365),
117 }
118 }
119}
120
121pub struct HypergraphIndex {
123 entities: HashMap<VectorId, Vec<f32>>,
125 hyperedges: HashMap<String, Hyperedge>,
127 temporal_index: HashMap<u64, Vec<String>>,
129 entity_to_hyperedges: HashMap<VectorId, HashSet<String>>,
131 hyperedge_to_entities: HashMap<String, HashSet<VectorId>>,
133 distance_metric: DistanceMetric,
135}
136
137impl HypergraphIndex {
138 pub fn new(distance_metric: DistanceMetric) -> Self {
140 Self {
141 entities: HashMap::new(),
142 hyperedges: HashMap::new(),
143 temporal_index: HashMap::new(),
144 entity_to_hyperedges: HashMap::new(),
145 hyperedge_to_entities: HashMap::new(),
146 distance_metric,
147 }
148 }
149
150 pub fn add_entity(&mut self, id: VectorId, embedding: Vec<f32>) {
152 self.entities.insert(id.clone(), embedding);
153 self.entity_to_hyperedges.entry(id).or_default();
154 }
155
156 pub fn add_hyperedge(&mut self, hyperedge: Hyperedge) -> Result<()> {
158 let edge_id = hyperedge.id.clone();
159
160 for node in &hyperedge.nodes {
162 if !self.entities.contains_key(node) {
163 return Err(RuvectorError::InvalidInput(format!(
164 "Entity {} not found in hypergraph",
165 node
166 )));
167 }
168 }
169
170 for node in &hyperedge.nodes {
172 self.entity_to_hyperedges
173 .entry(node.clone())
174 .or_default()
175 .insert(edge_id.clone());
176 }
177
178 let nodes_set: HashSet<VectorId> = hyperedge.nodes.iter().cloned().collect();
179 self.hyperedge_to_entities
180 .insert(edge_id.clone(), nodes_set);
181
182 self.hyperedges.insert(edge_id, hyperedge);
183 Ok(())
184 }
185
186 pub fn remove_entity(&mut self, id: &VectorId, cascade: bool) -> usize {
189 let deleted_edges = if cascade {
190 let edge_ids: Vec<String> = self
191 .entity_to_hyperedges
192 .get(id)
193 .map(|s| s.iter().cloned().collect())
194 .unwrap_or_default();
195 let mut count = 0;
196 for edge_id in &edge_ids {
197 if self.remove_hyperedge(edge_id) {
198 count += 1;
199 }
200 }
201 count
202 } else {
203 if let Some(edge_ids) = self.entity_to_hyperedges.remove(id) {
205 for edge_id in &edge_ids {
206 if let Some(nodes) = self.hyperedge_to_entities.get_mut(edge_id) {
207 nodes.remove(id);
208 }
209 }
210 }
211 0
212 };
213 self.entities.remove(id);
214 deleted_edges
215 }
216
217 pub fn remove_hyperedge(&mut self, id: &str) -> bool {
220 if let Some(hyperedge) = self.hyperedges.remove(id) {
221 for node in &hyperedge.nodes {
223 if let Some(set) = self.entity_to_hyperedges.get_mut(node) {
224 set.remove(id);
225 }
226 }
227 self.hyperedge_to_entities.remove(id);
228
229 for bucket_edges in self.temporal_index.values_mut() {
231 bucket_edges.retain(|eid| eid != id);
232 }
233 true
234 } else {
235 false
236 }
237 }
238
239 pub fn add_temporal_hyperedge(&mut self, temporal_edge: TemporalHyperedge) -> Result<()> {
241 let bucket = temporal_edge.time_bucket();
242 let edge_id = temporal_edge.hyperedge.id.clone();
243
244 self.add_hyperedge(temporal_edge.hyperedge)?;
245
246 self.temporal_index.entry(bucket).or_default().push(edge_id);
247
248 Ok(())
249 }
250
251 pub fn search_hyperedges(&self, query_embedding: &[f32], k: usize) -> Vec<(String, f32)> {
253 let mut results: Vec<(String, f32)> = self
254 .hyperedges
255 .iter()
256 .map(|(id, edge)| {
257 let distance = self.compute_distance(query_embedding, &edge.embedding);
258 (id.clone(), distance)
259 })
260 .collect();
261
262 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
263 results.truncate(k);
264 results
265 }
266
267 pub fn k_hop_neighbors(&self, start_node: VectorId, k: usize) -> HashSet<VectorId> {
270 let mut visited = HashSet::new();
271 let mut current_layer = HashSet::new();
272 current_layer.insert(start_node.clone());
273 visited.insert(start_node); for _hop in 0..k {
276 let mut next_layer = HashSet::new();
277
278 for node in current_layer.iter() {
279 if let Some(hyperedges) = self.entity_to_hyperedges.get(node) {
281 for edge_id in hyperedges {
282 if let Some(nodes) = self.hyperedge_to_entities.get(edge_id) {
284 for neighbor in nodes.iter() {
285 if !visited.contains(neighbor) {
286 visited.insert(neighbor.clone());
287 next_layer.insert(neighbor.clone());
288 }
289 }
290 }
291 }
292 }
293 }
294
295 if next_layer.is_empty() {
296 break;
297 }
298 current_layer = next_layer;
299 }
300
301 visited
302 }
303
304 pub fn query_temporal_range(&self, start_bucket: u64, end_bucket: u64) -> Vec<String> {
306 let mut results = Vec::new();
307 for bucket in start_bucket..=end_bucket {
308 if let Some(edges) = self.temporal_index.get(&bucket) {
309 results.extend(edges.iter().cloned());
310 }
311 }
312 results
313 }
314
315 pub fn get_hyperedge(&self, id: &str) -> Option<&Hyperedge> {
317 self.hyperedges.get(id)
318 }
319
320 pub fn stats(&self) -> HypergraphStats {
322 let total_edges = self.hyperedges.len();
323 let total_entities = self.entities.len();
324 let avg_degree = if total_entities > 0 {
325 self.entity_to_hyperedges
326 .values()
327 .map(|edges| edges.len())
328 .sum::<usize>() as f32
329 / total_entities as f32
330 } else {
331 0.0
332 };
333
334 HypergraphStats {
335 total_entities,
336 total_hyperedges: total_edges,
337 avg_entity_degree: avg_degree,
338 }
339 }
340
341 fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
342 crate::distance::distance(a, b, self.distance_metric).unwrap_or(f32::MAX)
343 }
344}
345
346#[derive(Debug, Clone, Serialize, Deserialize)]
348pub struct HypergraphStats {
349 pub total_entities: usize,
350 pub total_hyperedges: usize,
351 pub avg_entity_degree: f32,
352}
353
354pub struct CausalMemory {
356 index: HypergraphIndex,
358 causal_counts: HashMap<(VectorId, VectorId), u32>,
360 latencies: HashMap<VectorId, f32>,
362 alpha: f32, beta: f32, gamma: f32, }
367
368impl CausalMemory {
369 pub fn new(distance_metric: DistanceMetric) -> Self {
371 Self {
372 index: HypergraphIndex::new(distance_metric),
373 causal_counts: HashMap::new(),
374 latencies: HashMap::new(),
375 alpha: 0.7,
376 beta: 0.2,
377 gamma: 0.1,
378 }
379 }
380
381 pub fn with_weights(mut self, alpha: f32, beta: f32, gamma: f32) -> Self {
383 self.alpha = alpha;
384 self.beta = beta;
385 self.gamma = gamma;
386 self
387 }
388
389 pub fn add_causal_edge(
391 &mut self,
392 cause: VectorId,
393 effect: VectorId,
394 context: Vec<VectorId>,
395 description: String,
396 embedding: Vec<f32>,
397 latency_ms: f32,
398 ) -> Result<()> {
399 let mut nodes = vec![cause.clone(), effect.clone()];
401 nodes.extend(context);
402
403 let hyperedge = Hyperedge::new(nodes, description, embedding, 1.0);
404 self.index.add_hyperedge(hyperedge)?;
405
406 *self
408 .causal_counts
409 .entry((cause.clone(), effect.clone()))
410 .or_insert(0) += 1;
411
412 let entry = self.latencies.entry(cause).or_insert(0.0);
414 *entry = (*entry + latency_ms) / 2.0; Ok(())
417 }
418
419 pub fn query_with_utility(
421 &self,
422 query_embedding: &[f32],
423 action_id: VectorId,
424 k: usize,
425 ) -> Vec<(String, f32)> {
426 let mut results: Vec<(String, f32)> = self
427 .index
428 .hyperedges
429 .iter()
430 .filter(|(_, edge)| edge.contains_node(&action_id))
431 .map(|(id, edge)| {
432 let similarity = 1.0
433 - self
434 .index
435 .compute_distance(query_embedding, &edge.embedding);
436 let causal_uplift = self.compute_causal_uplift(&edge.nodes);
437 let latency = self.latencies.get(&action_id).copied().unwrap_or(0.0);
438
439 let utility = self.alpha * similarity + self.beta * causal_uplift
440 - self.gamma * (latency / 1000.0); (id.clone(), utility)
443 })
444 .collect();
445
446 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); results.truncate(k);
448 results
449 }
450
451 fn compute_causal_uplift(&self, nodes: &[VectorId]) -> f32 {
452 if nodes.len() < 2 {
453 return 0.0;
454 }
455
456 let mut total_uplift = 0.0;
458 let mut count = 0;
459
460 for i in 0..nodes.len() - 1 {
461 for j in i + 1..nodes.len() {
462 if let Some(&success_count) = self
463 .causal_counts
464 .get(&(nodes[i].clone(), nodes[j].clone()))
465 {
466 total_uplift += (success_count as f32).ln_1p(); count += 1;
468 }
469 }
470 }
471
472 if count > 0 {
473 total_uplift / count as f32
474 } else {
475 0.0
476 }
477 }
478
479 pub fn index(&self) -> &HypergraphIndex {
481 &self.index
482 }
483}
484
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building an owned id from a literal.
    fn id(raw: &str) -> String {
        raw.to_string()
    }

    /// Builds an index over `entity_ids`, each with a one-dimensional
    /// embedding.
    fn index_with_entities(entity_ids: &[&str]) -> HypergraphIndex {
        let mut index = HypergraphIndex::new(DistanceMetric::Cosine);
        for entity in entity_ids {
            index.add_entity(id(entity), vec![1.0]);
        }
        index
    }

    #[test]
    fn test_hyperedge_creation() {
        let edge = Hyperedge::new(
            vec![id("1"), id("2"), id("3")],
            "Test relationship".to_string(),
            vec![0.1, 0.2, 0.3],
            0.95,
        );

        assert_eq!(edge.order(), 3);
        assert!(edge.contains_node(&id("1")));
        assert!(!edge.contains_node(&id("4")));
        assert_eq!(edge.confidence, 0.95);
        assert!(!edge.id.is_empty());
        assert!(edge.metadata.is_empty());
    }

    #[test]
    fn test_hyperedge_confidence_is_clamped() {
        let too_high = Hyperedge::new(vec![id("1")], "high".to_string(), vec![1.0], 1.5);
        let too_low = Hyperedge::new(vec![id("1")], "low".to_string(), vec![1.0], -0.5);

        assert_eq!(too_high.confidence, 1.0);
        assert_eq!(too_low.confidence, 0.0);
    }

    #[test]
    fn test_temporal_hyperedge() {
        let edge = Hyperedge::new(
            vec![id("1"), id("2")],
            "Temporal relationship".to_string(),
            vec![0.1, 0.2],
            1.0,
        );

        let temporal = TemporalHyperedge::new(edge, TemporalGranularity::Hourly);

        assert!(temporal.expires_at.is_none());
        assert!(!temporal.is_expired());
        assert!(temporal.time_bucket() > 0);
        assert_eq!(temporal.time_bucket(), temporal.timestamp / 3600);

        // A deadline at the epoch itself is long past.
        let mut stale = temporal.clone();
        stale.expires_at = Some(0);
        assert!(stale.is_expired());
    }

    #[test]
    fn test_hypergraph_index() {
        let mut index = HypergraphIndex::new(DistanceMetric::Cosine);

        index.add_entity(id("1"), vec![1.0, 0.0, 0.0]);
        index.add_entity(id("2"), vec![0.0, 1.0, 0.0]);
        index.add_entity(id("3"), vec![0.0, 0.0, 1.0]);

        let edge = Hyperedge::new(
            vec![id("1"), id("2"), id("3")],
            "Triple relationship".to_string(),
            vec![0.5, 0.5, 0.5],
            0.9,
        );
        let edge_id = edge.id.clone();
        index.add_hyperedge(edge).unwrap();

        let stats = index.stats();
        assert_eq!(stats.total_entities, 3);
        assert_eq!(stats.total_hyperedges, 1);
        // Every entity belongs to exactly one hyperedge.
        assert_eq!(stats.avg_entity_degree, 1.0);
        assert!(index.get_hyperedge(&edge_id).is_some());
    }

    #[test]
    fn test_add_hyperedge_rejects_unknown_entity() {
        let mut index = index_with_entities(&["1"]);

        let dangling = Hyperedge::new(
            vec![id("1"), id("missing")],
            "dangling".to_string(),
            vec![1.0],
            1.0,
        );

        assert!(index.add_hyperedge(dangling).is_err());
        assert_eq!(index.stats().total_hyperedges, 0);
    }

    #[test]
    fn test_remove_hyperedge() {
        let mut index = index_with_entities(&["1", "2"]);

        let edge = Hyperedge::new(vec![id("1"), id("2")], "e".to_string(), vec![1.0], 1.0);
        let edge_id = edge.id.clone();
        index.add_hyperedge(edge).unwrap();

        assert!(index.remove_hyperedge(&edge_id));
        assert!(index.get_hyperedge(&edge_id).is_none());
        // Removing twice reports that nothing was left to remove.
        assert!(!index.remove_hyperedge(&edge_id));
    }

    #[test]
    fn test_k_hop_neighbors() {
        let mut index = index_with_entities(&["1", "2", "3", "4"]);

        // Chain 1 - 2 - 3 - 4.
        for (name, left, right) in [("e1", "1", "2"), ("e2", "2", "3"), ("e3", "3", "4")] {
            let edge = Hyperedge::new(vec![id(left), id(right)], name.to_string(), vec![1.0], 1.0);
            index.add_hyperedge(edge).unwrap();
        }

        let neighbors = index.k_hop_neighbors(id("1"), 2);
        assert!(neighbors.contains(&id("1")));
        assert!(neighbors.contains(&id("2")));
        assert!(neighbors.contains(&id("3")));
        // Node 4 is three hops away and must not be reached in two.
        assert!(!neighbors.contains(&id("4")));
        assert_eq!(neighbors.len(), 3);
    }

    #[test]
    fn test_causal_memory() {
        let mut memory = CausalMemory::new(DistanceMetric::Cosine);

        memory.index.add_entity(id("1"), vec![1.0, 0.0]);
        memory.index.add_entity(id("2"), vec![0.0, 1.0]);

        memory
            .add_causal_edge(
                id("1"),
                id("2"),
                vec![],
                "Action 1 causes effect 2".to_string(),
                vec![0.5, 0.5],
                100.0,
            )
            .unwrap();

        let results = memory.query_with_utility(&[0.6, 0.4], id("1"), 5);
        assert!(!results.is_empty());
        assert_eq!(results.len(), 1);
        assert_eq!(memory.index().stats().total_hyperedges, 1);
    }
}