Skip to main content

sqlite_knowledge_graph/graph/
hyperedge.rs

//! Hyperedge (higher-order relation) storage module.
//!
//! Provides storage and querying for multi-entity relationships (hyperedges).
//! A hyperedge connects 2 or more entities in a single relation.
5
6use rusqlite::params;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet, VecDeque};
9
10use crate::error::{Error, Result};
11use crate::graph::entity::{get_entity, Entity};
12
13/// A hyperedge representing a higher-order relation among multiple entities.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct Hyperedge {
16    pub id: Option<i64>,
17    pub hyperedge_type: String,
18    pub entity_ids: Vec<i64>,
19    pub weight: f64,
20    pub arity: usize,
21    pub properties: HashMap<String, serde_json::Value>,
22    pub created_at: Option<i64>,
23    pub updated_at: Option<i64>,
24}
25
26impl Hyperedge {
27    /// Create a new hyperedge.
28    ///
29    /// Requires at least 2 entities and weight in [0.0, 1.0].
30    pub fn new(
31        entity_ids: Vec<i64>,
32        hyperedge_type: impl Into<String>,
33        weight: f64,
34    ) -> Result<Self> {
35        if entity_ids.len() < 2 {
36            return Err(Error::InvalidArity(entity_ids.len()));
37        }
38        if !(0.0..=1.0).contains(&weight) {
39            return Err(Error::InvalidWeight(weight));
40        }
41
42        let arity = entity_ids.len();
43        Ok(Self {
44            id: None,
45            hyperedge_type: hyperedge_type.into(),
46            entity_ids,
47            weight,
48            arity,
49            properties: HashMap::new(),
50            created_at: None,
51            updated_at: None,
52        })
53    }
54
55    /// Set a property on the hyperedge.
56    pub fn set_property(&mut self, key: impl Into<String>, value: serde_json::Value) {
57        self.properties.insert(key.into(), value);
58    }
59
60    /// Get a property value.
61    pub fn get_property(&self, key: &str) -> Option<&serde_json::Value> {
62        self.properties.get(key)
63    }
64
65    /// Check if an entity participates in this hyperedge.
66    pub fn contains(&self, entity_id: i64) -> bool {
67        self.entity_ids.contains(&entity_id)
68    }
69
70    /// Get the entity set for efficient set operations.
71    pub fn entity_set(&self) -> HashSet<i64> {
72        self.entity_ids.iter().copied().collect()
73    }
74
75    /// Compute intersection with another hyperedge - O(k1 + k2).
76    pub fn intersection(&self, other: &Hyperedge) -> Vec<i64> {
77        let set1 = self.entity_set();
78        let set2 = other.entity_set();
79        set1.intersection(&set2).copied().collect()
80    }
81
82    /// Check if this hyperedge shares any entity with another - O(k1 + k2).
83    pub fn has_intersection(&self, other: &Hyperedge) -> bool {
84        let set1 = self.entity_set();
85        other.entity_ids.iter().any(|id| set1.contains(id))
86    }
87}
88
89/// A higher-order neighbor: an entity connected through a hyperedge.
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct HigherOrderNeighbor {
92    pub entity: Entity,
93    pub hyperedge: Hyperedge,
94    pub position: Option<usize>,
95}
96
97/// A step in a higher-order path.
98#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct HigherOrderPathStep {
100    pub hyperedge: Hyperedge,
101    pub from_entity: i64,
102    pub to_entity: i64,
103}
104
105/// A higher-order path between two entities.
106#[derive(Debug, Clone, Serialize, Deserialize)]
107pub struct HigherOrderPath {
108    pub steps: Vec<HigherOrderPathStep>,
109    pub total_weight: f64,
110}
111
112/// Insert a hyperedge into the database.
113pub fn insert_hyperedge(conn: &rusqlite::Connection, hyperedge: &Hyperedge) -> Result<i64> {
114    // Validate all entities exist
115    for entity_id in &hyperedge.entity_ids {
116        get_entity(conn, *entity_id)?;
117    }
118
119    let entity_ids_json = serde_json::to_string(&hyperedge.entity_ids)?;
120    let properties_json = serde_json::to_string(&hyperedge.properties)?;
121
122    let tx = conn.unchecked_transaction()?;
123
124    tx.execute(
125        r#"
126        INSERT INTO kg_hyperedges (hyperedge_type, entity_ids, weight, arity, properties)
127        VALUES (?1, ?2, ?3, ?4, ?5)
128        "#,
129        params![
130            hyperedge.hyperedge_type,
131            entity_ids_json,
132            hyperedge.weight,
133            hyperedge.arity as i64,
134            properties_json
135        ],
136    )?;
137
138    let hyperedge_id = tx.last_insert_rowid();
139
140    // Insert entity-hyperedge associations
141    for (position, entity_id) in hyperedge.entity_ids.iter().enumerate() {
142        tx.execute(
143            "INSERT INTO kg_hyperedge_entities (hyperedge_id, entity_id, position) VALUES (?1, ?2, ?3)",
144            params![hyperedge_id, entity_id, position as i64],
145        )?;
146    }
147
148    tx.commit()?;
149    Ok(hyperedge_id)
150}
151
152/// Get a hyperedge by ID.
153pub fn get_hyperedge(conn: &rusqlite::Connection, id: i64) -> Result<Hyperedge> {
154    conn.query_row(
155        r#"
156        SELECT id, hyperedge_type, entity_ids, weight, arity, properties, created_at, updated_at
157        FROM kg_hyperedges WHERE id = ?1
158        "#,
159        params![id],
160        |row| {
161            let entity_ids_json: String = row.get(2)?;
162            let entity_ids: Vec<i64> = serde_json::from_str(&entity_ids_json).unwrap_or_default();
163
164            let properties_json: Option<String> = row.get(5)?;
165            let properties: HashMap<String, serde_json::Value> = properties_json
166                .as_deref()
167                .and_then(|s| serde_json::from_str(s).ok())
168                .unwrap_or_default();
169
170            let arity = entity_ids.len();
171            Ok(Hyperedge {
172                id: Some(row.get(0)?),
173                hyperedge_type: row.get(1)?,
174                entity_ids,
175                weight: row.get(3)?,
176                arity,
177                properties,
178                created_at: row.get(6)?,
179                updated_at: row.get(7)?,
180            })
181        },
182    )
183    .map_err(|_| Error::HyperedgeNotFound(id))
184}
185
186/// List hyperedges with optional filtering.
187pub fn list_hyperedges(
188    conn: &rusqlite::Connection,
189    hyperedge_type: Option<&str>,
190    min_arity: Option<usize>,
191    max_arity: Option<usize>,
192    limit: Option<i64>,
193) -> Result<Vec<Hyperedge>> {
194    let mut query = "SELECT id, hyperedge_type, entity_ids, weight, arity, properties, created_at, updated_at FROM kg_hyperedges WHERE 1=1".to_string();
195    let mut params_vec: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
196    let mut param_idx = 1;
197
198    if let Some(ht) = hyperedge_type {
199        query.push_str(&format!(" AND hyperedge_type = ?{param_idx}"));
200        params_vec.push(Box::new(ht.to_string()));
201        param_idx += 1;
202    }
203
204    if let Some(min) = min_arity {
205        query.push_str(&format!(" AND arity >= ?{param_idx}"));
206        params_vec.push(Box::new(min as i64));
207        param_idx += 1;
208    }
209
210    if let Some(max) = max_arity {
211        query.push_str(&format!(" AND arity <= ?{param_idx}"));
212        params_vec.push(Box::new(max as i64));
213        param_idx += 1;
214    }
215
216    query.push_str(" ORDER BY created_at DESC");
217
218    if let Some(lim) = limit {
219        query.push_str(&format!(" LIMIT ?{param_idx}"));
220        params_vec.push(Box::new(lim));
221    }
222
223    let mut stmt = conn.prepare(&query)?;
224    let params_refs: Vec<&dyn rusqlite::ToSql> = params_vec.iter().map(|p| p.as_ref()).collect();
225
226    let rows = stmt.query_map(params_refs.as_slice(), |row| {
227        let entity_ids_json: String = row.get(2)?;
228        let entity_ids: Vec<i64> = serde_json::from_str(&entity_ids_json).unwrap_or_default();
229
230        let properties_json: Option<String> = row.get(5)?;
231        let properties: HashMap<String, serde_json::Value> = properties_json
232            .as_deref()
233            .and_then(|s| serde_json::from_str(s).ok())
234            .unwrap_or_default();
235
236        let arity = entity_ids.len();
237        Ok(Hyperedge {
238            id: Some(row.get(0)?),
239            hyperedge_type: row.get(1)?,
240            entity_ids,
241            weight: row.get(3)?,
242            arity,
243            properties,
244            created_at: row.get(6)?,
245            updated_at: row.get(7)?,
246        })
247    })?;
248
249    let mut result = Vec::new();
250    for row in rows {
251        result.push(row?);
252    }
253    Ok(result)
254}
255
256/// Update a hyperedge.
257pub fn update_hyperedge(conn: &rusqlite::Connection, hyperedge: &Hyperedge) -> Result<()> {
258    let id = hyperedge.id.ok_or(Error::HyperedgeNotFound(0))?;
259
260    // Validate all entities exist
261    for entity_id in &hyperedge.entity_ids {
262        get_entity(conn, *entity_id)?;
263    }
264
265    let entity_ids_json = serde_json::to_string(&hyperedge.entity_ids)?;
266    let properties_json = serde_json::to_string(&hyperedge.properties)?;
267
268    let updated_at = std::time::SystemTime::now()
269        .duration_since(std::time::UNIX_EPOCH)
270        .unwrap()
271        .as_secs() as i64;
272
273    let tx = conn.unchecked_transaction()?;
274
275    let affected = tx.execute(
276        r#"
277        UPDATE kg_hyperedges
278        SET hyperedge_type = ?1, entity_ids = ?2, weight = ?3, arity = ?4, properties = ?5, updated_at = ?6
279        WHERE id = ?7
280        "#,
281        params![
282            hyperedge.hyperedge_type,
283            entity_ids_json,
284            hyperedge.weight,
285            hyperedge.arity as i64,
286            properties_json,
287            updated_at,
288            id
289        ],
290    )?;
291
292    if affected == 0 {
293        return Err(Error::HyperedgeNotFound(id));
294    }
295
296    // Rebuild entity associations
297    tx.execute(
298        "DELETE FROM kg_hyperedge_entities WHERE hyperedge_id = ?1",
299        params![id],
300    )?;
301
302    for (position, entity_id) in hyperedge.entity_ids.iter().enumerate() {
303        tx.execute(
304            "INSERT INTO kg_hyperedge_entities (hyperedge_id, entity_id, position) VALUES (?1, ?2, ?3)",
305            params![id, entity_id, position as i64],
306        )?;
307    }
308
309    tx.commit()?;
310    Ok(())
311}
312
313/// Delete a hyperedge by ID.
314pub fn delete_hyperedge(conn: &rusqlite::Connection, id: i64) -> Result<()> {
315    let affected = conn.execute("DELETE FROM kg_hyperedges WHERE id = ?1", params![id])?;
316    if affected == 0 {
317        return Err(Error::HyperedgeNotFound(id));
318    }
319    Ok(())
320}
321
322/// Get higher-order neighbors of an entity (entities connected through hyperedges).
323pub fn get_higher_order_neighbors(
324    conn: &rusqlite::Connection,
325    entity_id: i64,
326    min_arity: Option<usize>,
327    max_arity: Option<usize>,
328) -> Result<Vec<HigherOrderNeighbor>> {
329    // Validate entity exists
330    get_entity(conn, entity_id)?;
331
332    let min_arity = min_arity.unwrap_or(2) as i64;
333    let max_arity = max_arity.unwrap_or(100) as i64;
334
335    let mut stmt = conn.prepare(
336        r#"
337        SELECT h.id, h.hyperedge_type, h.entity_ids, h.weight, h.arity, h.properties,
338               h.created_at, h.updated_at,
339               he2.entity_id as neighbor_id, he2.position
340        FROM kg_hyperedge_entities he
341        JOIN kg_hyperedges h ON he.hyperedge_id = h.id
342        JOIN kg_hyperedge_entities he2 ON h.id = he2.hyperedge_id
343        WHERE he.entity_id = ?1
344          AND he2.entity_id != ?1
345          AND h.arity >= ?2
346          AND h.arity <= ?3
347        ORDER BY h.weight DESC
348        "#,
349    )?;
350
351    let rows = stmt.query_map(params![entity_id, min_arity, max_arity], |row| {
352        let entity_ids_json: String = row.get(2)?;
353        let entity_ids: Vec<i64> = serde_json::from_str(&entity_ids_json).unwrap_or_default();
354
355        let properties_json: Option<String> = row.get(5)?;
356        let properties: HashMap<String, serde_json::Value> = properties_json
357            .as_deref()
358            .and_then(|s| serde_json::from_str(s).ok())
359            .unwrap_or_default();
360
361        let arity = entity_ids.len();
362        let neighbor_id: i64 = row.get(8)?;
363        let position: i64 = row.get(9)?;
364
365        Ok((
366            Hyperedge {
367                id: Some(row.get(0)?),
368                hyperedge_type: row.get(1)?,
369                entity_ids,
370                weight: row.get(3)?,
371                arity,
372                properties,
373                created_at: row.get(6)?,
374                updated_at: row.get(7)?,
375            },
376            neighbor_id,
377            position as usize,
378        ))
379    })?;
380
381    let mut result = Vec::new();
382    for row in rows {
383        let (hyperedge, neighbor_id, position) = row?;
384        let entity = get_entity(conn, neighbor_id)?;
385        result.push(HigherOrderNeighbor {
386            entity,
387            hyperedge,
388            position: Some(position),
389        });
390    }
391
392    Ok(result)
393}
394
395/// Get all hyperedges that an entity participates in.
396pub fn get_entity_hyperedges(
397    conn: &rusqlite::Connection,
398    entity_id: i64,
399) -> Result<Vec<Hyperedge>> {
400    get_entity(conn, entity_id)?;
401
402    let mut stmt = conn.prepare(
403        r#"
404        SELECT h.id, h.hyperedge_type, h.entity_ids, h.weight, h.arity, h.properties,
405               h.created_at, h.updated_at
406        FROM kg_hyperedge_entities he
407        JOIN kg_hyperedges h ON he.hyperedge_id = h.id
408        WHERE he.entity_id = ?1
409        ORDER BY h.created_at DESC
410        "#,
411    )?;
412
413    let rows = stmt.query_map(params![entity_id], |row| {
414        let entity_ids_json: String = row.get(2)?;
415        let entity_ids: Vec<i64> = serde_json::from_str(&entity_ids_json).unwrap_or_default();
416
417        let properties_json: Option<String> = row.get(5)?;
418        let properties: HashMap<String, serde_json::Value> = properties_json
419            .as_deref()
420            .and_then(|s| serde_json::from_str(s).ok())
421            .unwrap_or_default();
422
423        let arity = entity_ids.len();
424        Ok(Hyperedge {
425            id: Some(row.get(0)?),
426            hyperedge_type: row.get(1)?,
427            entity_ids,
428            weight: row.get(3)?,
429            arity,
430            properties,
431            created_at: row.get(6)?,
432            updated_at: row.get(7)?,
433        })
434    })?;
435
436    let mut result = Vec::new();
437    for row in rows {
438        result.push(row?);
439    }
440    Ok(result)
441}
442
443/// Higher-order BFS traversal through hyperedges.
444pub fn higher_order_bfs(
445    conn: &rusqlite::Connection,
446    start_id: i64,
447    max_depth: u32,
448    min_arity: Option<usize>,
449) -> Result<Vec<crate::graph::traversal::TraversalNode>> {
450    use crate::graph::traversal::TraversalNode;
451
452    if max_depth == 0 {
453        return Ok(Vec::new());
454    }
455    if max_depth > 10 {
456        return Err(Error::InvalidDepth(max_depth));
457    }
458
459    get_entity(conn, start_id)?;
460
461    let mut visited = HashSet::new();
462    let mut queue = VecDeque::new();
463    let mut result = Vec::new();
464
465    visited.insert(start_id);
466    queue.push_back((start_id, 0u32));
467
468    while let Some((current_id, depth)) = queue.pop_front() {
469        if depth >= max_depth {
470            continue;
471        }
472
473        let neighbors = get_higher_order_neighbors(conn, current_id, min_arity, None)?;
474
475        for neighbor in neighbors {
476            let neighbor_id = neighbor.entity.id.unwrap();
477            if !visited.contains(&neighbor_id) {
478                visited.insert(neighbor_id);
479                queue.push_back((neighbor_id, depth + 1));
480                result.push(TraversalNode {
481                    entity_id: neighbor_id,
482                    entity_type: neighbor.entity.entity_type.clone(),
483                    depth: depth + 1,
484                });
485            }
486        }
487    }
488
489    Ok(result)
490}
491
492/// Find shortest path between two entities through hyperedges.
493pub fn higher_order_shortest_path(
494    conn: &rusqlite::Connection,
495    from_id: i64,
496    to_id: i64,
497    max_depth: u32,
498) -> Result<Option<HigherOrderPath>> {
499    if max_depth == 0 {
500        return Ok(None);
501    }
502    if max_depth > 10 {
503        return Err(Error::InvalidDepth(max_depth));
504    }
505
506    get_entity(conn, from_id)?;
507    get_entity(conn, to_id)?;
508
509    if from_id == to_id {
510        return Ok(Some(HigherOrderPath {
511            steps: Vec::new(),
512            total_weight: 0.0,
513        }));
514    }
515
516    let mut visited = HashSet::new();
517    let mut queue: VecDeque<(i64, u32)> = VecDeque::new();
518    // parent map: entity_id -> (parent_entity_id, hyperedge used)
519    let mut parent: HashMap<i64, (i64, Hyperedge)> = HashMap::new();
520
521    visited.insert(from_id);
522    queue.push_back((from_id, 0));
523
524    while let Some((current_id, depth)) = queue.pop_front() {
525        if depth >= max_depth {
526            continue;
527        }
528
529        let neighbors = get_higher_order_neighbors(conn, current_id, None, None)?;
530
531        for neighbor in neighbors {
532            let neighbor_id = neighbor.entity.id.unwrap();
533            if !visited.contains(&neighbor_id) {
534                visited.insert(neighbor_id);
535                parent.insert(neighbor_id, (current_id, neighbor.hyperedge));
536                if neighbor_id == to_id {
537                    // Reconstruct path
538                    return Ok(Some(reconstruct_path(&parent, from_id, to_id)));
539                }
540                queue.push_back((neighbor_id, depth + 1));
541            }
542        }
543    }
544
545    Ok(None)
546}
547
548fn reconstruct_path(
549    parent: &HashMap<i64, (i64, Hyperedge)>,
550    from_id: i64,
551    to_id: i64,
552) -> HigherOrderPath {
553    let mut steps = Vec::new();
554    let mut current = to_id;
555    let mut total_weight = 0.0;
556
557    while current != from_id {
558        let (prev, hyperedge) = parent.get(&current).unwrap();
559        total_weight += hyperedge.weight;
560        steps.push(HigherOrderPathStep {
561            hyperedge: hyperedge.clone(),
562            from_entity: *prev,
563            to_entity: current,
564        });
565        current = *prev;
566    }
567
568    steps.reverse();
569    HigherOrderPath {
570        steps,
571        total_weight,
572    }
573}
574
575/// Compute hyperedge degree centrality for an entity.
576pub fn hyperedge_degree(conn: &rusqlite::Connection, entity_id: i64) -> Result<f64> {
577    get_entity(conn, entity_id)?;
578
579    let count: i64 = conn.query_row(
580        "SELECT COUNT(DISTINCT hyperedge_id) FROM kg_hyperedge_entities WHERE entity_id = ?1",
581        params![entity_id],
582        |row| row.get(0),
583    )?;
584
585    Ok(count as f64)
586}
587
588/// Load all hyperedges from the database.
589pub fn load_all_hyperedges(conn: &rusqlite::Connection) -> Result<Vec<Hyperedge>> {
590    list_hyperedges(conn, None, None, None, None)
591}
592
593/// Compute entity-level hypergraph PageRank using Zhou formula.
594///
595/// Based on Zhou et al. (2006) - "Learning with Hypergraphs".
596///
597/// PR(v) = (1-d)/n + d * sum_{e: v in e} [w(e)/delta(e) * sum_{u in e, u!=v} PR(u) * (1/d(u)) * (1/delta(e))]
598///
599/// Simplified: PR(v) = (1-d)/n + d * sum_{e: v in e} [w(e)/delta(e)^2 * sum_{u in e, u!=v} PR(u)/d(u)]
600///
601/// Complexity: O(T * sum_e k_e^2), much faster than naive O(n^2) approaches.
602pub fn hypergraph_entity_pagerank(
603    conn: &rusqlite::Connection,
604    damping: f64,
605    max_iter: usize,
606    tolerance: f64,
607) -> Result<HashMap<i64, f64>> {
608    let hyperedges = load_all_hyperedges(conn)?;
609
610    if hyperedges.is_empty() {
611        return Ok(HashMap::new());
612    }
613
614    // Collect all entity IDs that appear in hyperedges
615    let mut all_entities: HashSet<i64> = HashSet::new();
616    for he in &hyperedges {
617        for &eid in &he.entity_ids {
618            all_entities.insert(eid);
619        }
620    }
621
622    let n = all_entities.len() as f64;
623    if n == 0.0 {
624        return Ok(HashMap::new());
625    }
626
627    // Compute hyperedge degree d(v) for each entity
628    // d(v) = number of hyperedges containing v
629    let mut entity_degree: HashMap<i64, usize> = HashMap::new();
630    for he in &hyperedges {
631        for &eid in &he.entity_ids {
632            *entity_degree.entry(eid).or_insert(0) += 1;
633        }
634    }
635
636    // Initialize PageRank scores uniformly
637    let mut scores: HashMap<i64, f64> = all_entities.iter().map(|&id| (id, 1.0 / n)).collect();
638
639    // Iterative update using Zhou formula
640    for _ in 0..max_iter {
641        let mut new_scores: HashMap<i64, f64> = HashMap::new();
642
643        // Initialize with random jump term
644        for &eid in &all_entities {
645            new_scores.insert(eid, (1.0 - damping) / n);
646        }
647
648        // For each hyperedge e, compute contribution to its entities
649        for he in &hyperedges {
650            let w_e = he.weight;
651            let delta_e = he.arity as f64;
652            // Zhou formula uses 1/delta(e) for each vertex in the hyperedge
653            let inv_delta = 1.0 / delta_e;
654
655            // Compute sum of PR(u)/d(u) for all u in e
656            let sum_pr_d: f64 = he
657                .entity_ids
658                .iter()
659                .map(|&u| {
660                    let d_u = *entity_degree.get(&u).unwrap_or(&1) as f64;
661                    let pr_u = scores.get(&u).copied().unwrap_or(0.0);
662                    pr_u / d_u
663                })
664                .sum();
665
666            // For each v in e, add contribution from other vertices
667            for &v in &he.entity_ids {
668                let d_v = *entity_degree.get(&v).unwrap_or(&1) as f64;
669                let pr_v = scores.get(&v).copied().unwrap_or(0.0);
670
671                // Subtract v's own contribution to get sum of u != v
672                let sum_pr_d_excluding_v = sum_pr_d - pr_v / d_v;
673
674                // Zhou formula: w(e) / delta(e)^2 * sum_{u != v} PR(u)/d(u)
675                let contribution = damping * w_e * inv_delta * inv_delta * sum_pr_d_excluding_v;
676
677                *new_scores.entry(v).or_insert(0.0) += contribution;
678            }
679        }
680
681        // Normalize scores to ensure sum = 1.0
682        let total: f64 = new_scores.values().sum();
683        if total > 0.0 {
684            for score in new_scores.values_mut() {
685                *score /= total;
686            }
687        }
688
689        // Check convergence
690        let diff: f64 = all_entities
691            .iter()
692            .map(|id| (new_scores.get(id).unwrap_or(&0.0) - scores.get(id).unwrap_or(&0.0)).abs())
693            .sum();
694
695        scores = new_scores;
696
697        if diff < tolerance {
698            break;
699        }
700    }
701
702    Ok(scores)
703}
704
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::entity::insert_entity;
    use rusqlite::Connection;

    fn setup_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute("PRAGMA foreign_keys = ON", []).unwrap();
        crate::schema::create_schema(&conn).unwrap();
        conn
    }

    fn create_test_entities(conn: &Connection, count: usize) -> Vec<i64> {
        (0..count)
            .map(|i| insert_entity(conn, &Entity::new("person", format!("Person {i}"))).unwrap())
            .collect()
    }

    #[test]
    fn test_hyperedge_creation() {
        let he = Hyperedge::new(vec![1, 2, 3], "collaboration", 0.8).unwrap();
        assert_eq!(he.arity, 3);
        assert!(he.contains(1));
        assert!(he.contains(2));
        assert!(he.contains(3));
        assert!(!he.contains(4));
    }

    #[test]
    fn test_hyperedge_invalid_arity() {
        let result = Hyperedge::new(vec![1], "test", 0.5);
        assert!(result.is_err());

        let result = Hyperedge::new(vec![], "test", 0.5);
        assert!(result.is_err());
    }

    #[test]
    fn test_hyperedge_invalid_weight() {
        let result = Hyperedge::new(vec![1, 2], "test", 1.5);
        assert!(result.is_err());

        let result = Hyperedge::new(vec![1, 2], "test", -0.1);
        assert!(result.is_err());
    }

    #[test]
    fn test_hyperedge_intersection() {
        let he1 = Hyperedge::new(vec![1, 2, 3], "a", 0.5).unwrap();
        let he2 = Hyperedge::new(vec![2, 3, 4], "b", 0.5).unwrap();
        let mut inter = he1.intersection(&he2);
        inter.sort();
        assert_eq!(inter, vec![2, 3]);
        assert!(he1.has_intersection(&he2));
    }

    #[test]
    fn test_hyperedge_no_intersection() {
        let he1 = Hyperedge::new(vec![1, 2], "a", 0.5).unwrap();
        let he2 = Hyperedge::new(vec![3, 4], "b", 0.5).unwrap();
        assert!(he1.intersection(&he2).is_empty());
        assert!(!he1.has_intersection(&he2));
    }

    #[test]
    fn test_insert_and_get_hyperedge() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 3);

        let he = Hyperedge::new(ids.clone(), "collaboration", 0.8).unwrap();
        let he_id = insert_hyperedge(&conn, &he).unwrap();
        assert!(he_id > 0);

        let retrieved = get_hyperedge(&conn, he_id).unwrap();
        assert_eq!(retrieved.arity, 3);
        assert_eq!(retrieved.hyperedge_type, "collaboration");
        assert_eq!(retrieved.entity_ids, ids);
        assert!((retrieved.weight - 0.8).abs() < f64::EPSILON);
    }

    #[test]
    fn test_get_hyperedge_not_found() {
        let conn = setup_db();
        assert!(matches!(
            get_hyperedge(&conn, 999),
            Err(Error::HyperedgeNotFound(999))
        ));
    }

    #[test]
    fn test_insert_hyperedge_missing_entity() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 1);

        let he = Hyperedge::new(vec![ids[0], 9999], "pair", 0.5).unwrap();
        assert!(insert_hyperedge(&conn, &he).is_err());
        assert!(list_hyperedges(&conn, None, None, None, None)
            .unwrap()
            .is_empty());
    }

    #[test]
    fn test_list_hyperedges() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 5);

        insert_hyperedge(
            &conn,
            &Hyperedge::new(ids[0..3].to_vec(), "team", 0.9).unwrap(),
        )
        .unwrap();
        insert_hyperedge(
            &conn,
            &Hyperedge::new(ids[2..5].to_vec(), "team", 0.8).unwrap(),
        )
        .unwrap();
        insert_hyperedge(&conn, &Hyperedge::new(ids.clone(), "project", 0.7).unwrap()).unwrap();

        let all = list_hyperedges(&conn, None, None, None, None).unwrap();
        assert_eq!(all.len(), 3);

        let teams = list_hyperedges(&conn, Some("team"), None, None, None).unwrap();
        assert_eq!(teams.len(), 2);

        let big = list_hyperedges(&conn, None, Some(4), None, None).unwrap();
        assert_eq!(big.len(), 1);
    }

    #[test]
    fn test_update_hyperedge() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 4);

        let he = Hyperedge::new(ids[0..3].to_vec(), "team", 0.9).unwrap();
        let he_id = insert_hyperedge(&conn, &he).unwrap();

        let mut updated = get_hyperedge(&conn, he_id).unwrap();
        updated.entity_ids = ids.clone();
        updated.arity = ids.len();
        updated.weight = 0.7;
        update_hyperedge(&conn, &updated).unwrap();

        let retrieved = get_hyperedge(&conn, he_id).unwrap();
        assert_eq!(retrieved.arity, 4);
        assert!((retrieved.weight - 0.7).abs() < f64::EPSILON);
    }

    #[test]
    fn test_update_hyperedge_requires_id() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 2);

        let he = Hyperedge::new(ids, "pair", 0.5).unwrap();
        assert!(matches!(
            update_hyperedge(&conn, &he),
            Err(Error::HyperedgeNotFound(0))
        ));
    }

    #[test]
    fn test_update_hyperedge_not_found() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 2);

        let mut he = Hyperedge::new(ids, "pair", 0.5).unwrap();
        he.id = Some(999);
        assert!(matches!(
            update_hyperedge(&conn, &he),
            Err(Error::HyperedgeNotFound(999))
        ));
    }

    #[test]
    fn test_delete_hyperedge() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 3);

        let he = Hyperedge::new(ids, "team", 0.9).unwrap();
        let he_id = insert_hyperedge(&conn, &he).unwrap();

        delete_hyperedge(&conn, he_id).unwrap();
        assert!(get_hyperedge(&conn, he_id).is_err());
    }

    #[test]
    fn test_delete_hyperedge_not_found() {
        let conn = setup_db();
        assert!(delete_hyperedge(&conn, 999).is_err());
    }

    #[test]
    fn test_higher_order_neighbors() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 5);

        // Team 1: Person 0, 1, 2
        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[0], ids[1], ids[2]], "team", 0.9).unwrap(),
        )
        .unwrap();

        // Team 2: Person 2, 3, 4
        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[2], ids[3], ids[4]], "team", 0.8).unwrap(),
        )
        .unwrap();

        // Neighbors of Person 0 through hyperedges
        let neighbors = get_higher_order_neighbors(&conn, ids[0], None, None).unwrap();
        assert_eq!(neighbors.len(), 2); // Person 1, Person 2

        let neighbor_ids: HashSet<i64> = neighbors.iter().map(|n| n.entity.id.unwrap()).collect();
        assert!(neighbor_ids.contains(&ids[1]));
        assert!(neighbor_ids.contains(&ids[2]));
    }

    #[test]
    fn test_entity_hyperedges() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 4);

        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[0], ids[1], ids[2]], "team", 0.9).unwrap(),
        )
        .unwrap();
        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[0], ids[3]], "pair", 0.5).unwrap(),
        )
        .unwrap();

        let hyperedges = get_entity_hyperedges(&conn, ids[0]).unwrap();
        assert_eq!(hyperedges.len(), 2);
    }

    #[test]
    fn test_higher_order_bfs() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 5);

        // Chain: {0,1,2} -- {2,3,4}
        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[0], ids[1], ids[2]], "team", 0.9).unwrap(),
        )
        .unwrap();
        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[2], ids[3], ids[4]], "team", 0.8).unwrap(),
        )
        .unwrap();

        let traversal = higher_order_bfs(&conn, ids[0], 2, None).unwrap();
        let traversed_ids: HashSet<i64> = traversal.iter().map(|n| n.entity_id).collect();

        // Should reach all other entities through the chain
        assert!(traversed_ids.contains(&ids[1]));
        assert!(traversed_ids.contains(&ids[2]));
        assert!(traversed_ids.contains(&ids[3]));
        assert!(traversed_ids.contains(&ids[4]));
    }

    #[test]
    fn test_higher_order_bfs_depth_limits() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 5);

        // Chain: {0,1,2} -- {2,3,4}
        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[0], ids[1], ids[2]], "team", 0.9).unwrap(),
        )
        .unwrap();
        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[2], ids[3], ids[4]], "team", 0.8).unwrap(),
        )
        .unwrap();

        // Depth 0 explores nothing; depth above 10 is rejected.
        assert!(higher_order_bfs(&conn, ids[0], 0, None).unwrap().is_empty());
        assert!(matches!(
            higher_order_bfs(&conn, ids[0], 11, None),
            Err(Error::InvalidDepth(11))
        ));

        // One hop only reaches the members of the first hyperedge, all at depth 1.
        let one_hop = higher_order_bfs(&conn, ids[0], 1, None).unwrap();
        let one_hop_ids: HashSet<i64> = one_hop.iter().map(|n| n.entity_id).collect();
        assert_eq!(one_hop_ids, HashSet::from([ids[1], ids[2]]));
        assert!(one_hop.iter().all(|n| n.depth == 1));
    }

    #[test]
    fn test_higher_order_shortest_path() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 5);

        // {0,1,2} and {2,3,4}
        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[0], ids[1], ids[2]], "team", 0.9).unwrap(),
        )
        .unwrap();
        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[2], ids[3], ids[4]], "team", 0.8).unwrap(),
        )
        .unwrap();

        // Path from 0 to 4 should go through entity 2
        let path = higher_order_shortest_path(&conn, ids[0], ids[4], 5)
            .unwrap()
            .unwrap();
        assert_eq!(path.steps.len(), 2);

        // No path if max_depth is too small
        let path = higher_order_shortest_path(&conn, ids[0], ids[4], 0).unwrap();
        assert!(path.is_none());
    }

    #[test]
    fn test_higher_order_shortest_path_same_entity() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 1);

        let path = higher_order_shortest_path(&conn, ids[0], ids[0], 3)
            .unwrap()
            .unwrap();
        assert!(path.steps.is_empty());
        assert!(path.total_weight.abs() < f64::EPSILON);
    }

    #[test]
    fn test_hyperedge_degree() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 4);

        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[0], ids[1], ids[2]], "team", 0.9).unwrap(),
        )
        .unwrap();
        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[0], ids[3]], "pair", 0.5).unwrap(),
        )
        .unwrap();

        assert!((hyperedge_degree(&conn, ids[0]).unwrap() - 2.0).abs() < f64::EPSILON);
        assert!((hyperedge_degree(&conn, ids[1]).unwrap() - 1.0).abs() < f64::EPSILON);
        assert!((hyperedge_degree(&conn, ids[3]).unwrap() - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_hyperedge_degree_zero() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 1);

        assert!(hyperedge_degree(&conn, ids[0]).unwrap().abs() < f64::EPSILON);
    }

    #[test]
    fn test_hypergraph_entity_pagerank() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 5);

        // {0,1,2} and {2,3,4} - entity 2 is the bridge
        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[0], ids[1], ids[2]], "team", 0.9).unwrap(),
        )
        .unwrap();
        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[2], ids[3], ids[4]], "team", 0.8).unwrap(),
        )
        .unwrap();

        let scores = hypergraph_entity_pagerank(&conn, 0.85, 100, 1e-6).unwrap();

        // All 5 entities should have scores
        assert_eq!(scores.len(), 5);

        // Entity 2 (bridge) should have highest score
        let score_2 = scores[&ids[2]];
        for &id in &ids {
            if id != ids[2] {
                assert!(
                    score_2 >= scores[&id],
                    "Bridge entity should have highest PageRank"
                );
            }
        }

        // Scores should sum to approximately 1.0
        let total: f64 = scores.values().sum();
        assert!(
            (total - 1.0).abs() < 0.01,
            "PageRank scores should sum to ~1.0, got {total}"
        );
    }

    #[test]
    fn test_hypergraph_pagerank_empty() {
        let conn = setup_db();
        let scores = hypergraph_entity_pagerank(&conn, 0.85, 100, 1e-6).unwrap();
        assert!(scores.is_empty());
    }

    #[test]
    fn test_hyperedge_properties() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 3);

        let mut he = Hyperedge::new(ids, "team", 0.9).unwrap();
        he.set_property("project", serde_json::json!("Alpha"));
        he.set_property("start_date", serde_json::json!("2026-01-01"));

        let he_id = insert_hyperedge(&conn, &he).unwrap();
        let retrieved = get_hyperedge(&conn, he_id).unwrap();

        assert_eq!(
            retrieved.get_property("project"),
            Some(&serde_json::json!("Alpha"))
        );
        assert_eq!(
            retrieved.get_property("start_date"),
            Some(&serde_json::json!("2026-01-01"))
        );
    }
}