Skip to main content

ruvector_core/advanced/
hypergraph.rs

1//! # Hypergraph Support for N-ary Relationships
2//!
3//! Implements hypergraph structures for representing complex multi-entity relationships
4//! beyond traditional pairwise similarity. Based on HyperGraphRAG (NeurIPS 2025) architecture.
5
6use crate::error::{Result, RuvectorError};
7use crate::types::{DistanceMetric, VectorId};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10use std::time::{SystemTime, UNIX_EPOCH};
11use uuid::Uuid;
12
13/// Hyperedge connecting multiple vectors with description and embedding
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct Hyperedge {
16    /// Unique identifier for the hyperedge
17    pub id: String,
18    /// Vector IDs connected by this hyperedge
19    pub nodes: Vec<VectorId>,
20    /// Natural language description of the relationship
21    pub description: String,
22    /// Embedding of the hyperedge description
23    pub embedding: Vec<f32>,
24    /// Confidence weight (0.0-1.0)
25    pub confidence: f32,
26    /// Optional metadata
27    pub metadata: HashMap<String, String>,
28}
29
30/// Temporal hyperedge with time attributes
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct TemporalHyperedge {
33    /// Base hyperedge
34    pub hyperedge: Hyperedge,
35    /// Creation timestamp (Unix epoch seconds)
36    pub timestamp: u64,
37    /// Optional expiration timestamp
38    pub expires_at: Option<u64>,
39    /// Temporal context (hourly, daily, monthly)
40    pub granularity: TemporalGranularity,
41}
42
43/// Temporal granularity for indexing
44#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
45pub enum TemporalGranularity {
46    Hourly,
47    Daily,
48    Monthly,
49    Yearly,
50}
51
52impl Hyperedge {
53    /// Create a new hyperedge
54    pub fn new(
55        nodes: Vec<VectorId>,
56        description: String,
57        embedding: Vec<f32>,
58        confidence: f32,
59    ) -> Self {
60        Self {
61            id: Uuid::new_v4().to_string(),
62            nodes,
63            description,
64            embedding,
65            confidence: confidence.clamp(0.0, 1.0),
66            metadata: HashMap::new(),
67        }
68    }
69
70    /// Get hyperedge order (number of nodes)
71    pub fn order(&self) -> usize {
72        self.nodes.len()
73    }
74
75    /// Check if hyperedge contains a specific node
76    pub fn contains_node(&self, node: &VectorId) -> bool {
77        self.nodes.contains(node)
78    }
79}
80
81impl TemporalHyperedge {
82    /// Create a new temporal hyperedge with current timestamp
83    pub fn new(hyperedge: Hyperedge, granularity: TemporalGranularity) -> Self {
84        let timestamp = SystemTime::now()
85            .duration_since(UNIX_EPOCH)
86            .unwrap()
87            .as_secs();
88
89        Self {
90            hyperedge,
91            timestamp,
92            expires_at: None,
93            granularity,
94        }
95    }
96
97    /// Check if hyperedge is expired
98    pub fn is_expired(&self) -> bool {
99        if let Some(expires_at) = self.expires_at {
100            let now = SystemTime::now()
101                .duration_since(UNIX_EPOCH)
102                .unwrap()
103                .as_secs();
104            now > expires_at
105        } else {
106            false
107        }
108    }
109
110    /// Get time bucket for indexing
111    pub fn time_bucket(&self) -> u64 {
112        match self.granularity {
113            TemporalGranularity::Hourly => self.timestamp / 3600,
114            TemporalGranularity::Daily => self.timestamp / 86400,
115            TemporalGranularity::Monthly => self.timestamp / (86400 * 30),
116            TemporalGranularity::Yearly => self.timestamp / (86400 * 365),
117        }
118    }
119}
120
121/// Hypergraph index with bipartite graph storage
122pub struct HypergraphIndex {
123    /// Entity nodes
124    entities: HashMap<VectorId, Vec<f32>>,
125    /// Hyperedges
126    hyperedges: HashMap<String, Hyperedge>,
127    /// Temporal hyperedges indexed by time bucket
128    temporal_index: HashMap<u64, Vec<String>>,
129    /// Bipartite graph: entity -> hyperedge IDs
130    entity_to_hyperedges: HashMap<VectorId, HashSet<String>>,
131    /// Bipartite graph: hyperedge -> entity IDs
132    hyperedge_to_entities: HashMap<String, HashSet<VectorId>>,
133    /// Distance metric for embeddings
134    distance_metric: DistanceMetric,
135}
136
137impl HypergraphIndex {
138    /// Create a new hypergraph index
139    pub fn new(distance_metric: DistanceMetric) -> Self {
140        Self {
141            entities: HashMap::new(),
142            hyperedges: HashMap::new(),
143            temporal_index: HashMap::new(),
144            entity_to_hyperedges: HashMap::new(),
145            hyperedge_to_entities: HashMap::new(),
146            distance_metric,
147        }
148    }
149
150    /// Add an entity node
151    pub fn add_entity(&mut self, id: VectorId, embedding: Vec<f32>) {
152        self.entities.insert(id.clone(), embedding);
153        self.entity_to_hyperedges.entry(id).or_default();
154    }
155
156    /// Add a hyperedge
157    pub fn add_hyperedge(&mut self, hyperedge: Hyperedge) -> Result<()> {
158        let edge_id = hyperedge.id.clone();
159
160        // Verify all nodes exist
161        for node in &hyperedge.nodes {
162            if !self.entities.contains_key(node) {
163                return Err(RuvectorError::InvalidInput(format!(
164                    "Entity {} not found in hypergraph",
165                    node
166                )));
167            }
168        }
169
170        // Update bipartite graph
171        for node in &hyperedge.nodes {
172            self.entity_to_hyperedges
173                .entry(node.clone())
174                .or_default()
175                .insert(edge_id.clone());
176        }
177
178        let nodes_set: HashSet<VectorId> = hyperedge.nodes.iter().cloned().collect();
179        self.hyperedge_to_entities
180            .insert(edge_id.clone(), nodes_set);
181
182        self.hyperedges.insert(edge_id, hyperedge);
183        Ok(())
184    }
185
186    /// Remove an entity node and optionally cascade-delete incident hyperedges.
187    /// Returns the count of deleted hyperedges (cascade only).
188    pub fn remove_entity(&mut self, id: &VectorId, cascade: bool) -> usize {
189        let deleted_edges = if cascade {
190            let edge_ids: Vec<String> = self
191                .entity_to_hyperedges
192                .get(id)
193                .map(|s| s.iter().cloned().collect())
194                .unwrap_or_default();
195            let mut count = 0;
196            for edge_id in &edge_ids {
197                if self.remove_hyperedge(edge_id) {
198                    count += 1;
199                }
200            }
201            count
202        } else {
203            // Still clean up the entity → hyperedge mapping without touching edges
204            if let Some(edge_ids) = self.entity_to_hyperedges.remove(id) {
205                for edge_id in &edge_ids {
206                    if let Some(nodes) = self.hyperedge_to_entities.get_mut(edge_id) {
207                        nodes.remove(id);
208                    }
209                }
210            }
211            0
212        };
213        self.entities.remove(id);
214        deleted_edges
215    }
216
217    /// Remove a hyperedge by ID, cleaning up all index entries.
218    /// Returns true if the hyperedge existed.
219    pub fn remove_hyperedge(&mut self, id: &str) -> bool {
220        if let Some(hyperedge) = self.hyperedges.remove(id) {
221            // Clean up entity → hyperedge reverse index
222            for node in &hyperedge.nodes {
223                if let Some(set) = self.entity_to_hyperedges.get_mut(node) {
224                    set.remove(id);
225                }
226            }
227            self.hyperedge_to_entities.remove(id);
228
229            // Clean up temporal index entries that reference this id
230            for bucket_edges in self.temporal_index.values_mut() {
231                bucket_edges.retain(|eid| eid != id);
232            }
233            true
234        } else {
235            false
236        }
237    }
238
239    /// Add a temporal hyperedge
240    pub fn add_temporal_hyperedge(&mut self, temporal_edge: TemporalHyperedge) -> Result<()> {
241        let bucket = temporal_edge.time_bucket();
242        let edge_id = temporal_edge.hyperedge.id.clone();
243
244        self.add_hyperedge(temporal_edge.hyperedge)?;
245
246        self.temporal_index.entry(bucket).or_default().push(edge_id);
247
248        Ok(())
249    }
250
251    /// Search hyperedges by embedding similarity
252    pub fn search_hyperedges(&self, query_embedding: &[f32], k: usize) -> Vec<(String, f32)> {
253        let mut results: Vec<(String, f32)> = self
254            .hyperedges
255            .iter()
256            .map(|(id, edge)| {
257                let distance = self.compute_distance(query_embedding, &edge.embedding);
258                (id.clone(), distance)
259            })
260            .collect();
261
262        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
263        results.truncate(k);
264        results
265    }
266
267    /// Get k-hop neighbors in hypergraph
268    /// Returns all nodes reachable within k hops from the start node
269    pub fn k_hop_neighbors(&self, start_node: VectorId, k: usize) -> HashSet<VectorId> {
270        let mut visited = HashSet::new();
271        let mut current_layer = HashSet::new();
272        current_layer.insert(start_node.clone());
273        visited.insert(start_node); // Start node is at distance 0
274
275        for _hop in 0..k {
276            let mut next_layer = HashSet::new();
277
278            for node in current_layer.iter() {
279                // Get all hyperedges containing this node
280                if let Some(hyperedges) = self.entity_to_hyperedges.get(node) {
281                    for edge_id in hyperedges {
282                        // Get all nodes in this hyperedge
283                        if let Some(nodes) = self.hyperedge_to_entities.get(edge_id) {
284                            for neighbor in nodes.iter() {
285                                if !visited.contains(neighbor) {
286                                    visited.insert(neighbor.clone());
287                                    next_layer.insert(neighbor.clone());
288                                }
289                            }
290                        }
291                    }
292                }
293            }
294
295            if next_layer.is_empty() {
296                break;
297            }
298            current_layer = next_layer;
299        }
300
301        visited
302    }
303
304    /// Query temporal hyperedges in a time range
305    pub fn query_temporal_range(&self, start_bucket: u64, end_bucket: u64) -> Vec<String> {
306        let mut results = Vec::new();
307        for bucket in start_bucket..=end_bucket {
308            if let Some(edges) = self.temporal_index.get(&bucket) {
309                results.extend(edges.iter().cloned());
310            }
311        }
312        results
313    }
314
315    /// Get hyperedge by ID
316    pub fn get_hyperedge(&self, id: &str) -> Option<&Hyperedge> {
317        self.hyperedges.get(id)
318    }
319
320    /// Get statistics
321    pub fn stats(&self) -> HypergraphStats {
322        let total_edges = self.hyperedges.len();
323        let total_entities = self.entities.len();
324        let avg_degree = if total_entities > 0 {
325            self.entity_to_hyperedges
326                .values()
327                .map(|edges| edges.len())
328                .sum::<usize>() as f32
329                / total_entities as f32
330        } else {
331            0.0
332        };
333
334        HypergraphStats {
335            total_entities,
336            total_hyperedges: total_edges,
337            avg_entity_degree: avg_degree,
338        }
339    }
340
341    fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
342        crate::distance::distance(a, b, self.distance_metric).unwrap_or(f32::MAX)
343    }
344}
345
346/// Hypergraph statistics
347#[derive(Debug, Clone, Serialize, Deserialize)]
348pub struct HypergraphStats {
349    pub total_entities: usize,
350    pub total_hyperedges: usize,
351    pub avg_entity_degree: f32,
352}
353
354/// Causal hypergraph memory for agent reasoning
355pub struct CausalMemory {
356    /// Hypergraph index
357    index: HypergraphIndex,
358    /// Causal relationship tracking: (cause_id, effect_id) -> success_count
359    causal_counts: HashMap<(VectorId, VectorId), u32>,
360    /// Action latencies: action_id -> avg_latency_ms
361    latencies: HashMap<VectorId, f32>,
362    /// Utility function weights
363    alpha: f32, // similarity weight
364    beta: f32,  // causal uplift weight
365    gamma: f32, // latency penalty weight
366}
367
368impl CausalMemory {
369    /// Create a new causal memory with default utility weights
370    pub fn new(distance_metric: DistanceMetric) -> Self {
371        Self {
372            index: HypergraphIndex::new(distance_metric),
373            causal_counts: HashMap::new(),
374            latencies: HashMap::new(),
375            alpha: 0.7,
376            beta: 0.2,
377            gamma: 0.1,
378        }
379    }
380
381    /// Set custom utility function weights
382    pub fn with_weights(mut self, alpha: f32, beta: f32, gamma: f32) -> Self {
383        self.alpha = alpha;
384        self.beta = beta;
385        self.gamma = gamma;
386        self
387    }
388
389    /// Add a causal relationship
390    pub fn add_causal_edge(
391        &mut self,
392        cause: VectorId,
393        effect: VectorId,
394        context: Vec<VectorId>,
395        description: String,
396        embedding: Vec<f32>,
397        latency_ms: f32,
398    ) -> Result<()> {
399        // Create hyperedge connecting cause, effect, and context
400        let mut nodes = vec![cause.clone(), effect.clone()];
401        nodes.extend(context);
402
403        let hyperedge = Hyperedge::new(nodes, description, embedding, 1.0);
404        self.index.add_hyperedge(hyperedge)?;
405
406        // Update causal counts
407        *self
408            .causal_counts
409            .entry((cause.clone(), effect.clone()))
410            .or_insert(0) += 1;
411
412        // Update latency
413        let entry = self.latencies.entry(cause).or_insert(0.0);
414        *entry = (*entry + latency_ms) / 2.0; // Running average
415
416        Ok(())
417    }
418
419    /// Query with utility function: U = α·similarity + β·causal_uplift - γ·latency
420    pub fn query_with_utility(
421        &self,
422        query_embedding: &[f32],
423        action_id: VectorId,
424        k: usize,
425    ) -> Vec<(String, f32)> {
426        let mut results: Vec<(String, f32)> = self
427            .index
428            .hyperedges
429            .iter()
430            .filter(|(_, edge)| edge.contains_node(&action_id))
431            .map(|(id, edge)| {
432                let similarity = 1.0
433                    - self
434                        .index
435                        .compute_distance(query_embedding, &edge.embedding);
436                let causal_uplift = self.compute_causal_uplift(&edge.nodes);
437                let latency = self.latencies.get(&action_id).copied().unwrap_or(0.0);
438
439                let utility = self.alpha * similarity + self.beta * causal_uplift
440                    - self.gamma * (latency / 1000.0); // Normalize latency to 0-1 range
441
442                (id.clone(), utility)
443            })
444            .collect();
445
446        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); // Sort by utility descending
447        results.truncate(k);
448        results
449    }
450
451    fn compute_causal_uplift(&self, nodes: &[VectorId]) -> f32 {
452        if nodes.len() < 2 {
453            return 0.0;
454        }
455
456        // Compute average causal strength for pairs in this hyperedge
457        let mut total_uplift = 0.0;
458        let mut count = 0;
459
460        for i in 0..nodes.len() - 1 {
461            for j in i + 1..nodes.len() {
462                if let Some(&success_count) = self
463                    .causal_counts
464                    .get(&(nodes[i].clone(), nodes[j].clone()))
465                {
466                    total_uplift += (success_count as f32).ln_1p(); // Log scale
467                    count += 1;
468                }
469            }
470        }
471
472        if count > 0 {
473            total_uplift / count as f32
474        } else {
475            0.0
476        }
477    }
478
479    /// Get hypergraph index
480    pub fn index(&self) -> &HypergraphIndex {
481        &self.index
482    }
483}
484
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `VectorId` from a literal.
    fn id(raw: &str) -> VectorId {
        raw.to_string()
    }

    /// Builds a two-node hyperedge with a one-dimensional embedding.
    fn pair_edge(a: &str, b: &str, description: &str) -> Hyperedge {
        Hyperedge::new(vec![id(a), id(b)], description.to_string(), vec![1.0], 1.0)
    }

    #[test]
    fn test_hyperedge_creation() {
        let edge = Hyperedge::new(
            vec![id("1"), id("2"), id("3")],
            "Test relationship".to_string(),
            vec![0.1, 0.2, 0.3],
            0.95,
        );

        assert_eq!(edge.order(), 3);
        assert!(edge.contains_node(&id("1")));
        assert!(!edge.contains_node(&id("4")));
        assert_eq!(edge.confidence, 0.95);
        assert!(edge.metadata.is_empty());

        // Out-of-range confidence values are clamped into 0.0..=1.0.
        let too_high = Hyperedge::new(vec![id("1")], "hi".to_string(), vec![0.0], 1.5);
        let too_low = Hyperedge::new(vec![id("1")], "lo".to_string(), vec![0.0], -0.5);
        assert_eq!(too_high.confidence, 1.0);
        assert_eq!(too_low.confidence, 0.0);
        assert_ne!(too_high.id, too_low.id);
    }

    #[test]
    fn test_temporal_hyperedge() {
        let edge = Hyperedge::new(
            vec![id("1"), id("2")],
            "Temporal relationship".to_string(),
            vec![0.1, 0.2],
            1.0,
        );

        let mut temporal = TemporalHyperedge::new(edge, TemporalGranularity::Hourly);

        assert!(!temporal.is_expired());
        assert!(temporal.time_bucket() > 0);
        assert_eq!(temporal.time_bucket(), temporal.timestamp / 3600);

        // An expiration far in the past marks the edge as expired.
        temporal.expires_at = Some(0);
        assert!(temporal.is_expired());
        temporal.expires_at = Some(u64::MAX);
        assert!(!temporal.is_expired());
    }

    #[test]
    fn test_hypergraph_index() {
        let mut index = HypergraphIndex::new(DistanceMetric::Cosine);

        // Add entities
        index.add_entity(id("1"), vec![1.0, 0.0, 0.0]);
        index.add_entity(id("2"), vec![0.0, 1.0, 0.0]);
        index.add_entity(id("3"), vec![0.0, 0.0, 1.0]);

        // Add hyperedge
        let edge = Hyperedge::new(
            vec![id("1"), id("2"), id("3")],
            "Triple relationship".to_string(),
            vec![0.5, 0.5, 0.5],
            0.9,
        );
        let edge_id = edge.id.clone();
        index.add_hyperedge(edge).unwrap();

        let stats = index.stats();
        assert_eq!(stats.total_entities, 3);
        assert_eq!(stats.total_hyperedges, 1);
        assert_eq!(stats.avg_entity_degree, 1.0);
        assert!(index.get_hyperedge(&edge_id).is_some());

        // A hyperedge naming an unregistered entity is rejected and changes nothing.
        let dangling = Hyperedge::new(
            vec![id("1"), id("missing")],
            "Dangling".to_string(),
            vec![0.5, 0.5, 0.5],
            0.9,
        );
        assert!(index.add_hyperedge(dangling).is_err());
        assert_eq!(index.stats().total_hyperedges, 1);

        // Removing the edge makes it unreachable and is not repeatable.
        assert!(index.remove_hyperedge(&edge_id));
        assert!(index.get_hyperedge(&edge_id).is_none());
        assert!(!index.remove_hyperedge(&edge_id));
    }

    #[test]
    fn test_k_hop_neighbors() {
        let mut index = HypergraphIndex::new(DistanceMetric::Cosine);

        // Create a small chain-shaped hypergraph: 1 - 2 - 3 - 4
        for name in ["1", "2", "3", "4"] {
            index.add_entity(id(name), vec![1.0]);
        }
        index.add_hyperedge(pair_edge("1", "2", "e1")).unwrap();
        index.add_hyperedge(pair_edge("2", "3", "e2")).unwrap();
        index.add_hyperedge(pair_edge("3", "4", "e3")).unwrap();

        let neighbors = index.k_hop_neighbors(id("1"), 2);
        assert!(neighbors.contains(&id("1")));
        assert!(neighbors.contains(&id("2")));
        assert!(neighbors.contains(&id("3")));
        // Node 4 is three hops away and must not be reached with k = 2.
        assert!(!neighbors.contains(&id("4")));
        assert_eq!(neighbors.len(), 3);

        // Zero hops yields only the start node.
        let only_start = index.k_hop_neighbors(id("1"), 0);
        assert_eq!(only_start.len(), 1);
        assert!(only_start.contains(&id("1")));
    }

    #[test]
    fn test_causal_memory() {
        let mut memory = CausalMemory::new(DistanceMetric::Cosine);

        memory.index.add_entity(id("1"), vec![1.0, 0.0]);
        memory.index.add_entity(id("2"), vec![0.0, 1.0]);

        memory
            .add_causal_edge(
                id("1"),
                id("2"),
                vec![],
                "Action 1 causes effect 2".to_string(),
                vec![0.5, 0.5],
                100.0,
            )
            .unwrap();

        let results = memory.query_with_utility(&[0.6, 0.4], id("1"), 5);
        assert!(!results.is_empty());
        assert_eq!(results.len(), 1);
        assert_eq!(memory.index().stats().total_hyperedges, 1);

        // An action that appears in no hyperedge yields no results.
        assert!(memory
            .query_with_utility(&[0.6, 0.4], id("unknown"), 5)
            .is_empty());
    }
}