Skip to main content

sqlite_knowledge_graph/graph/
hyperedge.rs

1//! Hyperedge (higher-order relation) storage module.
2//!
3//! Provides storage and querying for multi-entity relationships (hyperedges).
4//! A hyperedge connects 2 or more entities in a single relation.
5
6use rusqlite::params;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet, VecDeque};
9
10use crate::error::{Error, Result};
11use crate::graph::entity::{get_entity, Entity};
12
/// A hyperedge representing a higher-order relation among multiple entities.
///
/// Values built with [`Hyperedge::new`] are unsaved: `id`, `created_at` and
/// `updated_at` are `None`. The loaders in this module fill those fields from
/// the `kg_hyperedges` row.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Hyperedge {
    /// Row id in `kg_hyperedges`; `None` until the hyperedge has been stored.
    pub id: Option<i64>,
    /// Caller-chosen label for the kind of relation.
    pub hyperedge_type: String,
    /// Ids of the participating entities. The index of each id is what
    /// `insert_hyperedge` / `update_hyperedge` store as its `position`.
    pub entity_ids: Vec<i64>,
    /// Relation weight; `Hyperedge::new` only accepts values in `[0.0, 1.0]`.
    pub weight: f64,
    /// Number of participating entities. `Hyperedge::new` and the loaders set
    /// it to `entity_ids.len()`; because the field is public it can drift if
    /// `entity_ids` is edited without updating it.
    pub arity: usize,
    /// Free-form JSON metadata keyed by property name.
    pub properties: HashMap<String, serde_json::Value>,
    /// Creation timestamp as read from the database.
    /// NOTE(review): presumably Unix seconds filled in by the schema — confirm.
    pub created_at: Option<i64>,
    /// Last-update timestamp; `update_hyperedge` writes Unix seconds here.
    pub updated_at: Option<i64>,
}
25
26impl Hyperedge {
27    /// Create a new hyperedge.
28    ///
29    /// Requires at least 2 entities and weight in [0.0, 1.0].
30    pub fn new(
31        entity_ids: Vec<i64>,
32        hyperedge_type: impl Into<String>,
33        weight: f64,
34    ) -> Result<Self> {
35        if entity_ids.len() < 2 {
36            return Err(Error::InvalidArity(entity_ids.len()));
37        }
38        if !(0.0..=1.0).contains(&weight) {
39            return Err(Error::InvalidWeight(weight));
40        }
41
42        let arity = entity_ids.len();
43        Ok(Self {
44            id: None,
45            hyperedge_type: hyperedge_type.into(),
46            entity_ids,
47            weight,
48            arity,
49            properties: HashMap::new(),
50            created_at: None,
51            updated_at: None,
52        })
53    }
54
55    /// Set a property on the hyperedge.
56    pub fn set_property(&mut self, key: impl Into<String>, value: serde_json::Value) {
57        self.properties.insert(key.into(), value);
58    }
59
60    /// Get a property value.
61    pub fn get_property(&self, key: &str) -> Option<&serde_json::Value> {
62        self.properties.get(key)
63    }
64
65    /// Check if an entity participates in this hyperedge.
66    pub fn contains(&self, entity_id: i64) -> bool {
67        self.entity_ids.contains(&entity_id)
68    }
69
70    /// Get the entity set for efficient set operations.
71    pub fn entity_set(&self) -> HashSet<i64> {
72        self.entity_ids.iter().copied().collect()
73    }
74
75    /// Compute intersection with another hyperedge - O(k1 + k2).
76    pub fn intersection(&self, other: &Hyperedge) -> Vec<i64> {
77        let set1 = self.entity_set();
78        let set2 = other.entity_set();
79        set1.intersection(&set2).copied().collect()
80    }
81
82    /// Check if this hyperedge shares any entity with another - O(k1 + k2).
83    pub fn has_intersection(&self, other: &Hyperedge) -> bool {
84        let set1 = self.entity_set();
85        other.entity_ids.iter().any(|id| set1.contains(id))
86    }
87}
88
/// A higher-order neighbor: an entity connected through a hyperedge.
///
/// Produced by `get_higher_order_neighbors`, which yields one value per
/// (neighbor, shared hyperedge) pair.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HigherOrderNeighbor {
    /// The neighboring entity, loaded via `get_entity`.
    pub entity: Entity,
    /// The hyperedge that contains both the queried entity and `entity`.
    pub hyperedge: Hyperedge,
    /// The neighbor's `position` value from `kg_hyperedge_entities`, i.e. its
    /// index in the hyperedge's entity list at the time it was stored.
    pub position: Option<usize>,
}
96
/// A step in a higher-order path.
///
/// One hop of a path returned by `higher_order_shortest_path`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HigherOrderPathStep {
    /// The hyperedge traversed by this hop.
    pub hyperedge: Hyperedge,
    /// Entity id the hop starts from.
    pub from_entity: i64,
    /// Entity id the hop arrives at.
    pub to_entity: i64,
}
104
/// A higher-order path between two entities.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HigherOrderPath {
    /// Hops ordered from the source entity to the target entity; empty when
    /// source and target are the same entity.
    pub steps: Vec<HigherOrderPathStep>,
    /// Sum of the `weight` of every hyperedge used in `steps`.
    pub total_weight: f64,
}
111
112/// Insert a hyperedge into the database.
113pub fn insert_hyperedge(conn: &rusqlite::Connection, hyperedge: &Hyperedge) -> Result<i64> {
114    // Validate all entities exist
115    for entity_id in &hyperedge.entity_ids {
116        get_entity(conn, *entity_id)?;
117    }
118
119    let entity_ids_json = serde_json::to_string(&hyperedge.entity_ids)?;
120    let properties_json = serde_json::to_string(&hyperedge.properties)?;
121
122    let tx = conn.unchecked_transaction()?;
123
124    tx.execute(
125        r#"
126        INSERT INTO kg_hyperedges (hyperedge_type, entity_ids, weight, arity, properties)
127        VALUES (?1, ?2, ?3, ?4, ?5)
128        "#,
129        params![
130            hyperedge.hyperedge_type,
131            entity_ids_json,
132            hyperedge.weight,
133            hyperedge.arity as i64,
134            properties_json
135        ],
136    )?;
137
138    let hyperedge_id = tx.last_insert_rowid();
139
140    // Insert entity-hyperedge associations
141    for (position, entity_id) in hyperedge.entity_ids.iter().enumerate() {
142        tx.execute(
143            "INSERT INTO kg_hyperedge_entities (hyperedge_id, entity_id, position) VALUES (?1, ?2, ?3)",
144            params![hyperedge_id, entity_id, position as i64],
145        )?;
146    }
147
148    tx.commit()?;
149    Ok(hyperedge_id)
150}
151
152/// Get a hyperedge by ID.
153pub fn get_hyperedge(conn: &rusqlite::Connection, id: i64) -> Result<Hyperedge> {
154    conn.query_row(
155        r#"
156        SELECT id, hyperedge_type, entity_ids, weight, arity, properties, created_at, updated_at
157        FROM kg_hyperedges WHERE id = ?1
158        "#,
159        params![id],
160        |row| {
161            let entity_ids_json: String = row.get(2)?;
162            let entity_ids: Vec<i64> = serde_json::from_str(&entity_ids_json).unwrap_or_default();
163
164            let properties_json: Option<String> = row.get(5)?;
165            let properties: HashMap<String, serde_json::Value> = properties_json
166                .as_deref()
167                .and_then(|s| serde_json::from_str(s).ok())
168                .unwrap_or_default();
169
170            let arity = entity_ids.len();
171            Ok(Hyperedge {
172                id: Some(row.get(0)?),
173                hyperedge_type: row.get(1)?,
174                entity_ids,
175                weight: row.get(3)?,
176                arity,
177                properties,
178                created_at: row.get(6)?,
179                updated_at: row.get(7)?,
180            })
181        },
182    )
183    .map_err(|_| Error::HyperedgeNotFound(id))
184}
185
186/// List hyperedges with optional filtering.
187pub fn list_hyperedges(
188    conn: &rusqlite::Connection,
189    hyperedge_type: Option<&str>,
190    min_arity: Option<usize>,
191    max_arity: Option<usize>,
192    limit: Option<i64>,
193) -> Result<Vec<Hyperedge>> {
194    let mut query = "SELECT id, hyperedge_type, entity_ids, weight, arity, properties, created_at, updated_at FROM kg_hyperedges WHERE 1=1".to_string();
195    let mut params_vec: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
196    let mut param_idx = 1;
197
198    if let Some(ht) = hyperedge_type {
199        query.push_str(&format!(" AND hyperedge_type = ?{param_idx}"));
200        params_vec.push(Box::new(ht.to_string()));
201        param_idx += 1;
202    }
203
204    if let Some(min) = min_arity {
205        query.push_str(&format!(" AND arity >= ?{param_idx}"));
206        params_vec.push(Box::new(min as i64));
207        param_idx += 1;
208    }
209
210    if let Some(max) = max_arity {
211        query.push_str(&format!(" AND arity <= ?{param_idx}"));
212        params_vec.push(Box::new(max as i64));
213        param_idx += 1;
214    }
215
216    query.push_str(" ORDER BY created_at DESC");
217
218    if let Some(lim) = limit {
219        query.push_str(&format!(" LIMIT ?{param_idx}"));
220        params_vec.push(Box::new(lim));
221    }
222
223    let mut stmt = conn.prepare(&query)?;
224    let params_refs: Vec<&dyn rusqlite::ToSql> = params_vec.iter().map(|p| p.as_ref()).collect();
225
226    let rows = stmt.query_map(params_refs.as_slice(), |row| {
227        let entity_ids_json: String = row.get(2)?;
228        let entity_ids: Vec<i64> = serde_json::from_str(&entity_ids_json).unwrap_or_default();
229
230        let properties_json: Option<String> = row.get(5)?;
231        let properties: HashMap<String, serde_json::Value> = properties_json
232            .as_deref()
233            .and_then(|s| serde_json::from_str(s).ok())
234            .unwrap_or_default();
235
236        let arity = entity_ids.len();
237        Ok(Hyperedge {
238            id: Some(row.get(0)?),
239            hyperedge_type: row.get(1)?,
240            entity_ids,
241            weight: row.get(3)?,
242            arity,
243            properties,
244            created_at: row.get(6)?,
245            updated_at: row.get(7)?,
246        })
247    })?;
248
249    let mut result = Vec::new();
250    for row in rows {
251        result.push(row?);
252    }
253    Ok(result)
254}
255
256/// Update a hyperedge.
257pub fn update_hyperedge(conn: &rusqlite::Connection, hyperedge: &Hyperedge) -> Result<()> {
258    let id = hyperedge.id.ok_or(Error::HyperedgeNotFound(0))?;
259
260    // Validate all entities exist
261    for entity_id in &hyperedge.entity_ids {
262        get_entity(conn, *entity_id)?;
263    }
264
265    let entity_ids_json = serde_json::to_string(&hyperedge.entity_ids)?;
266    let properties_json = serde_json::to_string(&hyperedge.properties)?;
267
268    let updated_at = std::time::SystemTime::now()
269        .duration_since(std::time::UNIX_EPOCH)
270        .unwrap()
271        .as_secs() as i64;
272
273    let tx = conn.unchecked_transaction()?;
274
275    let affected = tx.execute(
276        r#"
277        UPDATE kg_hyperedges
278        SET hyperedge_type = ?1, entity_ids = ?2, weight = ?3, arity = ?4, properties = ?5, updated_at = ?6
279        WHERE id = ?7
280        "#,
281        params![
282            hyperedge.hyperedge_type,
283            entity_ids_json,
284            hyperedge.weight,
285            hyperedge.arity as i64,
286            properties_json,
287            updated_at,
288            id
289        ],
290    )?;
291
292    if affected == 0 {
293        return Err(Error::HyperedgeNotFound(id));
294    }
295
296    // Rebuild entity associations
297    tx.execute(
298        "DELETE FROM kg_hyperedge_entities WHERE hyperedge_id = ?1",
299        params![id],
300    )?;
301
302    for (position, entity_id) in hyperedge.entity_ids.iter().enumerate() {
303        tx.execute(
304            "INSERT INTO kg_hyperedge_entities (hyperedge_id, entity_id, position) VALUES (?1, ?2, ?3)",
305            params![id, entity_id, position as i64],
306        )?;
307    }
308
309    tx.commit()?;
310    Ok(())
311}
312
313/// Delete a hyperedge by ID.
314pub fn delete_hyperedge(conn: &rusqlite::Connection, id: i64) -> Result<()> {
315    let affected = conn.execute("DELETE FROM kg_hyperedges WHERE id = ?1", params![id])?;
316    if affected == 0 {
317        return Err(Error::HyperedgeNotFound(id));
318    }
319    Ok(())
320}
321
322/// Get higher-order neighbors of an entity (entities connected through hyperedges).
323pub fn get_higher_order_neighbors(
324    conn: &rusqlite::Connection,
325    entity_id: i64,
326    min_arity: Option<usize>,
327    max_arity: Option<usize>,
328) -> Result<Vec<HigherOrderNeighbor>> {
329    // Validate entity exists
330    get_entity(conn, entity_id)?;
331
332    let min_arity = min_arity.unwrap_or(2) as i64;
333    let max_arity = max_arity.unwrap_or(100) as i64;
334
335    let mut stmt = conn.prepare(
336        r#"
337        SELECT h.id, h.hyperedge_type, h.entity_ids, h.weight, h.arity, h.properties,
338               h.created_at, h.updated_at,
339               he2.entity_id as neighbor_id, he2.position
340        FROM kg_hyperedge_entities he
341        JOIN kg_hyperedges h ON he.hyperedge_id = h.id
342        JOIN kg_hyperedge_entities he2 ON h.id = he2.hyperedge_id
343        WHERE he.entity_id = ?1
344          AND he2.entity_id != ?1
345          AND h.arity >= ?2
346          AND h.arity <= ?3
347        ORDER BY h.weight DESC
348        "#,
349    )?;
350
351    let rows = stmt.query_map(params![entity_id, min_arity, max_arity], |row| {
352        let entity_ids_json: String = row.get(2)?;
353        let entity_ids: Vec<i64> = serde_json::from_str(&entity_ids_json).unwrap_or_default();
354
355        let properties_json: Option<String> = row.get(5)?;
356        let properties: HashMap<String, serde_json::Value> = properties_json
357            .as_deref()
358            .and_then(|s| serde_json::from_str(s).ok())
359            .unwrap_or_default();
360
361        let arity = entity_ids.len();
362        let neighbor_id: i64 = row.get(8)?;
363        let position: i64 = row.get(9)?;
364
365        Ok((
366            Hyperedge {
367                id: Some(row.get(0)?),
368                hyperedge_type: row.get(1)?,
369                entity_ids,
370                weight: row.get(3)?,
371                arity,
372                properties,
373                created_at: row.get(6)?,
374                updated_at: row.get(7)?,
375            },
376            neighbor_id,
377            position as usize,
378        ))
379    })?;
380
381    let mut result = Vec::new();
382    for row in rows {
383        let (hyperedge, neighbor_id, position) = row?;
384        let entity = get_entity(conn, neighbor_id)?;
385        result.push(HigherOrderNeighbor {
386            entity,
387            hyperedge,
388            position: Some(position),
389        });
390    }
391
392    Ok(result)
393}
394
395/// Get all hyperedges that an entity participates in.
396pub fn get_entity_hyperedges(
397    conn: &rusqlite::Connection,
398    entity_id: i64,
399) -> Result<Vec<Hyperedge>> {
400    get_entity(conn, entity_id)?;
401
402    let mut stmt = conn.prepare(
403        r#"
404        SELECT h.id, h.hyperedge_type, h.entity_ids, h.weight, h.arity, h.properties,
405               h.created_at, h.updated_at
406        FROM kg_hyperedge_entities he
407        JOIN kg_hyperedges h ON he.hyperedge_id = h.id
408        WHERE he.entity_id = ?1
409        ORDER BY h.created_at DESC
410        "#,
411    )?;
412
413    let rows = stmt.query_map(params![entity_id], |row| {
414        let entity_ids_json: String = row.get(2)?;
415        let entity_ids: Vec<i64> = serde_json::from_str(&entity_ids_json).unwrap_or_default();
416
417        let properties_json: Option<String> = row.get(5)?;
418        let properties: HashMap<String, serde_json::Value> = properties_json
419            .as_deref()
420            .and_then(|s| serde_json::from_str(s).ok())
421            .unwrap_or_default();
422
423        let arity = entity_ids.len();
424        Ok(Hyperedge {
425            id: Some(row.get(0)?),
426            hyperedge_type: row.get(1)?,
427            entity_ids,
428            weight: row.get(3)?,
429            arity,
430            properties,
431            created_at: row.get(6)?,
432            updated_at: row.get(7)?,
433        })
434    })?;
435
436    let mut result = Vec::new();
437    for row in rows {
438        result.push(row?);
439    }
440    Ok(result)
441}
442
443/// Higher-order BFS traversal through hyperedges.
444pub fn higher_order_bfs(
445    conn: &rusqlite::Connection,
446    start_id: i64,
447    max_depth: u32,
448    min_arity: Option<usize>,
449) -> Result<Vec<crate::graph::traversal::TraversalNode>> {
450    use crate::graph::traversal::TraversalNode;
451
452    if max_depth == 0 {
453        return Ok(Vec::new());
454    }
455    if max_depth > 10 {
456        return Err(Error::InvalidDepth(max_depth));
457    }
458
459    get_entity(conn, start_id)?;
460
461    let mut visited = HashSet::new();
462    let mut queue = VecDeque::new();
463    let mut result = Vec::new();
464
465    visited.insert(start_id);
466    queue.push_back((start_id, 0u32));
467
468    while let Some((current_id, depth)) = queue.pop_front() {
469        if depth >= max_depth {
470            continue;
471        }
472
473        let neighbors = get_higher_order_neighbors(conn, current_id, min_arity, None)?;
474
475        for neighbor in neighbors {
476            let neighbor_id = neighbor.entity.id.ok_or(Error::EntityNotFound(0))?;
477            if !visited.contains(&neighbor_id) {
478                visited.insert(neighbor_id);
479                queue.push_back((neighbor_id, depth + 1));
480                result.push(TraversalNode {
481                    entity_id: neighbor_id,
482                    entity_type: neighbor.entity.entity_type.clone(),
483                    depth: depth + 1,
484                });
485            }
486        }
487    }
488
489    Ok(result)
490}
491
492/// Find shortest path between two entities through hyperedges.
493pub fn higher_order_shortest_path(
494    conn: &rusqlite::Connection,
495    from_id: i64,
496    to_id: i64,
497    max_depth: u32,
498) -> Result<Option<HigherOrderPath>> {
499    if max_depth == 0 {
500        return Ok(None);
501    }
502    if max_depth > 10 {
503        return Err(Error::InvalidDepth(max_depth));
504    }
505
506    get_entity(conn, from_id)?;
507    get_entity(conn, to_id)?;
508
509    if from_id == to_id {
510        return Ok(Some(HigherOrderPath {
511            steps: Vec::new(),
512            total_weight: 0.0,
513        }));
514    }
515
516    let mut visited = HashSet::new();
517    let mut queue: VecDeque<(i64, u32)> = VecDeque::new();
518    // parent map: entity_id -> (parent_entity_id, hyperedge used)
519    let mut parent: HashMap<i64, (i64, Hyperedge)> = HashMap::new();
520
521    visited.insert(from_id);
522    queue.push_back((from_id, 0));
523
524    while let Some((current_id, depth)) = queue.pop_front() {
525        if depth >= max_depth {
526            continue;
527        }
528
529        let neighbors = get_higher_order_neighbors(conn, current_id, None, None)?;
530
531        for neighbor in neighbors {
532            let neighbor_id = neighbor.entity.id.ok_or(Error::EntityNotFound(0))?;
533            if !visited.contains(&neighbor_id) {
534                visited.insert(neighbor_id);
535                parent.insert(neighbor_id, (current_id, neighbor.hyperedge));
536                if neighbor_id == to_id {
537                    // Reconstruct path
538                    return Ok(Some(reconstruct_path(&parent, from_id, to_id)));
539                }
540                queue.push_back((neighbor_id, depth + 1));
541            }
542        }
543    }
544
545    Ok(None)
546}
547
548fn reconstruct_path(
549    parent: &HashMap<i64, (i64, Hyperedge)>,
550    from_id: i64,
551    to_id: i64,
552) -> HigherOrderPath {
553    let mut steps = Vec::new();
554    let mut current = to_id;
555    let mut total_weight = 0.0;
556
557    while current != from_id {
558        // parent was populated for every node we visited, so this entry
559        // is guaranteed to exist; using if-let avoids an unwrap panic.
560        if let Some((prev, hyperedge)) = parent.get(&current) {
561            total_weight += hyperedge.weight;
562            steps.push(HigherOrderPathStep {
563                hyperedge: hyperedge.clone(),
564                from_entity: *prev,
565                to_entity: current,
566            });
567            current = *prev;
568        } else {
569            break; // defensive: should never happen in a well-formed graph
570        }
571    }
572
573    steps.reverse();
574    HigherOrderPath {
575        steps,
576        total_weight,
577    }
578}
579
580/// Compute hyperedge degree centrality for an entity.
581pub fn hyperedge_degree(conn: &rusqlite::Connection, entity_id: i64) -> Result<f64> {
582    get_entity(conn, entity_id)?;
583
584    let count: i64 = conn.query_row(
585        "SELECT COUNT(DISTINCT hyperedge_id) FROM kg_hyperedge_entities WHERE entity_id = ?1",
586        params![entity_id],
587        |row| row.get(0),
588    )?;
589
590    Ok(count as f64)
591}
592
/// Load all hyperedges from the database.
///
/// Equivalent to `list_hyperedges` with no type filter, no arity bounds and
/// no row limit, so results come back in that function's order and any error
/// it produces is returned unchanged.
pub fn load_all_hyperedges(conn: &rusqlite::Connection) -> Result<Vec<Hyperedge>> {
    list_hyperedges(conn, None, None, None, None)
}
597
/// Compute entity-level hypergraph PageRank using Zhou formula.
///
/// Based on Zhou et al. (2006) - "Learning with Hypergraphs".
///
/// PR(v) = (1-d)/n + d * sum_{e: v in e} [w(e)/delta(e) * sum_{u in e, u!=v} PR(u) * (1/d(u)) * (1/delta(e))]
///
/// Simplified: PR(v) = (1-d)/n + d * sum_{e: v in e} [w(e)/delta(e)^2 * sum_{u in e, u!=v} PR(u)/d(u)]
///
/// Here `n` is the number of distinct entities appearing in any hyperedge,
/// `delta(e)` is the hyperedge's arity and `d(u)` is the number of hyperedge
/// memberships of `u`.
///
/// * `damping` - the factor `d` in the formula above.
/// * `max_iter` - upper bound on the number of update rounds.
/// * `tolerance` - iteration stops once the summed absolute change of all
///   scores in a round drops below this value.
///
/// Returns a map from entity id to score covering only entities that occur
/// in at least one hyperedge; the map is empty when there are no hyperedges.
/// After every round the scores are rescaled so that they sum to 1.
///
/// Complexity: O(T * (n + sum_e k_e)) for T rounds - the inner sum over
/// `u != v` is obtained by computing the per-hyperedge total once and
/// subtracting `v`'s own term, rather than re-summing for every `v`.
///
/// # Errors
///
/// Propagates any error from loading the hyperedges.
pub fn hypergraph_entity_pagerank(
    conn: &rusqlite::Connection,
    damping: f64,
    max_iter: usize,
    tolerance: f64,
) -> Result<HashMap<i64, f64>> {
    let hyperedges = load_all_hyperedges(conn)?;

    if hyperedges.is_empty() {
        return Ok(HashMap::new());
    }

    // Collect all entity IDs that appear in hyperedges
    let mut all_entities: HashSet<i64> = HashSet::new();
    for he in &hyperedges {
        for &eid in &he.entity_ids {
            all_entities.insert(eid);
        }
    }

    let n = all_entities.len() as f64;
    // Guards the divisions by `n` below (possible if every hyperedge decoded
    // to an empty entity list).
    if n == 0.0 {
        return Ok(HashMap::new());
    }

    // Compute hyperedge degree d(v) for each entity
    // d(v) = number of hyperedges containing v
    let mut entity_degree: HashMap<i64, usize> = HashMap::new();
    for he in &hyperedges {
        for &eid in &he.entity_ids {
            *entity_degree.entry(eid).or_insert(0) += 1;
        }
    }

    // Initialize PageRank scores uniformly
    let mut scores: HashMap<i64, f64> = all_entities.iter().map(|&id| (id, 1.0 / n)).collect();

    // Iterative update using Zhou formula
    for _ in 0..max_iter {
        let mut new_scores: HashMap<i64, f64> = HashMap::new();

        // Initialize with random jump term
        for &eid in &all_entities {
            new_scores.insert(eid, (1.0 - damping) / n);
        }

        // For each hyperedge e, compute contribution to its entities
        for he in &hyperedges {
            let w_e = he.weight;
            let delta_e = he.arity as f64;
            // Zhou formula uses 1/delta(e) for each vertex in the hyperedge
            let inv_delta = 1.0 / delta_e;

            // Compute sum of PR(u)/d(u) for all u in e
            let sum_pr_d: f64 = he
                .entity_ids
                .iter()
                .map(|&u| {
                    let d_u = *entity_degree.get(&u).unwrap_or(&1) as f64;
                    let pr_u = scores.get(&u).copied().unwrap_or(0.0);
                    pr_u / d_u
                })
                .sum();

            // For each v in e, add contribution from other vertices
            for &v in &he.entity_ids {
                let d_v = *entity_degree.get(&v).unwrap_or(&1) as f64;
                let pr_v = scores.get(&v).copied().unwrap_or(0.0);

                // Subtract v's own contribution to get sum of u != v
                let sum_pr_d_excluding_v = sum_pr_d - pr_v / d_v;

                // Zhou formula: w(e) / delta(e)^2 * sum_{u != v} PR(u)/d(u)
                let contribution = damping * w_e * inv_delta * inv_delta * sum_pr_d_excluding_v;

                *new_scores.entry(v).or_insert(0.0) += contribution;
            }
        }

        // Normalize scores to ensure sum = 1.0
        let total: f64 = new_scores.values().sum();
        if total > 0.0 {
            for score in new_scores.values_mut() {
                *score /= total;
            }
        }

        // Check convergence
        // (L1 distance between this round's and the previous round's scores)
        let diff: f64 = all_entities
            .iter()
            .map(|id| (new_scores.get(id).unwrap_or(&0.0) - scores.get(id).unwrap_or(&0.0)).abs())
            .sum();

        scores = new_scores;

        if diff < tolerance {
            break;
        }
    }

    Ok(scores)
}
709
/// Unit tests for hyperedge construction, persistence, traversal and ranking.
///
/// Database-backed tests run against a fresh in-memory SQLite connection with
/// the crate schema installed.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::entity::insert_entity;
    use rusqlite::Connection;

    /// Open an in-memory database with foreign keys enabled and the schema created.
    fn setup_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute("PRAGMA foreign_keys = ON", []).unwrap();
        crate::schema::create_schema(&conn).unwrap();
        conn
    }

    /// Insert `count` "person" entities and return their ids in insertion order.
    fn create_test_entities(conn: &Connection, count: usize) -> Vec<i64> {
        (0..count)
            .map(|i| insert_entity(conn, &Entity::new("person", format!("Person {i}"))).unwrap())
            .collect()
    }

    #[test]
    fn test_hyperedge_creation() {
        let he = Hyperedge::new(vec![1, 2, 3], "collaboration", 0.8).unwrap();
        assert_eq!(he.arity, 3);
        assert!(he.contains(1));
        assert!(he.contains(2));
        assert!(he.contains(3));
        assert!(!he.contains(4));
    }

    #[test]
    fn test_hyperedge_invalid_arity() {
        let result = Hyperedge::new(vec![1], "test", 0.5);
        assert!(result.is_err());

        let result = Hyperedge::new(vec![], "test", 0.5);
        assert!(result.is_err());
    }

    #[test]
    fn test_hyperedge_invalid_weight() {
        let result = Hyperedge::new(vec![1, 2], "test", 1.5);
        assert!(result.is_err());

        let result = Hyperedge::new(vec![1, 2], "test", -0.1);
        assert!(result.is_err());
    }

    #[test]
    fn test_hyperedge_intersection() {
        let he1 = Hyperedge::new(vec![1, 2, 3], "a", 0.5).unwrap();
        let he2 = Hyperedge::new(vec![2, 3, 4], "b", 0.5).unwrap();
        // `intersection` returns ids in unspecified order, so sort first.
        let mut inter = he1.intersection(&he2);
        inter.sort();
        assert_eq!(inter, vec![2, 3]);
        assert!(he1.has_intersection(&he2));
    }

    #[test]
    fn test_hyperedge_no_intersection() {
        let he1 = Hyperedge::new(vec![1, 2], "a", 0.5).unwrap();
        let he2 = Hyperedge::new(vec![3, 4], "b", 0.5).unwrap();
        assert!(he1.intersection(&he2).is_empty());
        assert!(!he1.has_intersection(&he2));
    }

    #[test]
    fn test_insert_and_get_hyperedge() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 3);

        let he = Hyperedge::new(ids.clone(), "collaboration", 0.8).unwrap();
        let he_id = insert_hyperedge(&conn, &he).unwrap();
        assert!(he_id > 0);

        let retrieved = get_hyperedge(&conn, he_id).unwrap();
        assert_eq!(retrieved.arity, 3);
        assert_eq!(retrieved.hyperedge_type, "collaboration");
        assert_eq!(retrieved.entity_ids, ids);
        assert!((retrieved.weight - 0.8).abs() < f64::EPSILON);
    }

    #[test]
    fn test_list_hyperedges() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 5);

        insert_hyperedge(
            &conn,
            &Hyperedge::new(ids[0..3].to_vec(), "team", 0.9).unwrap(),
        )
        .unwrap();
        insert_hyperedge(
            &conn,
            &Hyperedge::new(ids[2..5].to_vec(), "team", 0.8).unwrap(),
        )
        .unwrap();
        insert_hyperedge(&conn, &Hyperedge::new(ids.clone(), "project", 0.7).unwrap()).unwrap();

        let all = list_hyperedges(&conn, None, None, None, None).unwrap();
        assert_eq!(all.len(), 3);

        let teams = list_hyperedges(&conn, Some("team"), None, None, None).unwrap();
        assert_eq!(teams.len(), 2);

        // Only the 5-member "project" hyperedge has arity >= 4.
        let big = list_hyperedges(&conn, None, Some(4), None, None).unwrap();
        assert_eq!(big.len(), 1);
    }

    #[test]
    fn test_update_hyperedge() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 4);

        let he = Hyperedge::new(ids[0..3].to_vec(), "team", 0.9).unwrap();
        let he_id = insert_hyperedge(&conn, &he).unwrap();

        let mut updated = get_hyperedge(&conn, he_id).unwrap();
        updated.entity_ids = ids.clone();
        updated.arity = ids.len();
        updated.weight = 0.7;
        update_hyperedge(&conn, &updated).unwrap();

        let retrieved = get_hyperedge(&conn, he_id).unwrap();
        assert_eq!(retrieved.arity, 4);
        assert!((retrieved.weight - 0.7).abs() < f64::EPSILON);
    }

    #[test]
    fn test_delete_hyperedge() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 3);

        let he = Hyperedge::new(ids, "team", 0.9).unwrap();
        let he_id = insert_hyperedge(&conn, &he).unwrap();

        delete_hyperedge(&conn, he_id).unwrap();
        assert!(get_hyperedge(&conn, he_id).is_err());
    }

    #[test]
    fn test_delete_hyperedge_not_found() {
        let conn = setup_db();
        assert!(delete_hyperedge(&conn, 999).is_err());
    }

    #[test]
    fn test_higher_order_neighbors() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 5);

        // Team 1: Person 0, 1, 2
        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[0], ids[1], ids[2]], "team", 0.9).unwrap(),
        )
        .unwrap();

        // Team 2: Person 2, 3, 4
        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[2], ids[3], ids[4]], "team", 0.8).unwrap(),
        )
        .unwrap();

        // Neighbors of Person 0 through hyperedges
        let neighbors = get_higher_order_neighbors(&conn, ids[0], None, None).unwrap();
        assert_eq!(neighbors.len(), 2); // Person 1, Person 2

        let neighbor_ids: HashSet<i64> = neighbors.iter().map(|n| n.entity.id.unwrap()).collect();
        assert!(neighbor_ids.contains(&ids[1]));
        assert!(neighbor_ids.contains(&ids[2]));
    }

    #[test]
    fn test_entity_hyperedges() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 4);

        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[0], ids[1], ids[2]], "team", 0.9).unwrap(),
        )
        .unwrap();
        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[0], ids[3]], "pair", 0.5).unwrap(),
        )
        .unwrap();

        let hyperedges = get_entity_hyperedges(&conn, ids[0]).unwrap();
        assert_eq!(hyperedges.len(), 2);
    }

    #[test]
    fn test_higher_order_bfs() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 5);

        // Chain: {0,1,2} -- {2,3,4}
        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[0], ids[1], ids[2]], "team", 0.9).unwrap(),
        )
        .unwrap();
        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[2], ids[3], ids[4]], "team", 0.8).unwrap(),
        )
        .unwrap();

        let traversal = higher_order_bfs(&conn, ids[0], 2, None).unwrap();
        let traversed_ids: HashSet<i64> = traversal.iter().map(|n| n.entity_id).collect();

        // Should reach all other entities through the chain
        assert!(traversed_ids.contains(&ids[1]));
        assert!(traversed_ids.contains(&ids[2]));
        assert!(traversed_ids.contains(&ids[3]));
        assert!(traversed_ids.contains(&ids[4]));
    }

    #[test]
    fn test_higher_order_shortest_path() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 5);

        // {0,1,2} and {2,3,4}
        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[0], ids[1], ids[2]], "team", 0.9).unwrap(),
        )
        .unwrap();
        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[2], ids[3], ids[4]], "team", 0.8).unwrap(),
        )
        .unwrap();

        // Path from 0 to 4 should go through entity 2
        let path = higher_order_shortest_path(&conn, ids[0], ids[4], 5)
            .unwrap()
            .unwrap();
        assert_eq!(path.steps.len(), 2);
        assert_eq!(path.steps[0].from_entity, ids[0]);
        assert_eq!(path.steps[0].to_entity, ids[2]);
        assert_eq!(path.steps[1].from_entity, ids[2]);
        assert_eq!(path.steps[1].to_entity, ids[4]);
        assert!((path.total_weight - 1.7).abs() < 1e-9);

        // A depth limit of exactly two hops is still enough.
        let path = higher_order_shortest_path(&conn, ids[0], ids[4], 2).unwrap();
        assert!(path.is_some());

        // One hop is not enough: this exercises the BFS depth cutoff itself.
        let path = higher_order_shortest_path(&conn, ids[0], ids[4], 1).unwrap();
        assert!(path.is_none());

        // No path if max_depth is zero (early return before any search).
        let path = higher_order_shortest_path(&conn, ids[0], ids[4], 0).unwrap();
        assert!(path.is_none());
    }

    #[test]
    fn test_hyperedge_degree() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 4);

        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[0], ids[1], ids[2]], "team", 0.9).unwrap(),
        )
        .unwrap();
        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[0], ids[3]], "pair", 0.5).unwrap(),
        )
        .unwrap();

        assert!((hyperedge_degree(&conn, ids[0]).unwrap() - 2.0).abs() < f64::EPSILON);
        assert!((hyperedge_degree(&conn, ids[1]).unwrap() - 1.0).abs() < f64::EPSILON);
        assert!((hyperedge_degree(&conn, ids[3]).unwrap() - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_hypergraph_entity_pagerank() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 5);

        // {0,1,2} and {2,3,4} - entity 2 is the bridge
        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[0], ids[1], ids[2]], "team", 0.9).unwrap(),
        )
        .unwrap();
        insert_hyperedge(
            &conn,
            &Hyperedge::new(vec![ids[2], ids[3], ids[4]], "team", 0.8).unwrap(),
        )
        .unwrap();

        let scores = hypergraph_entity_pagerank(&conn, 0.85, 100, 1e-6).unwrap();

        // All 5 entities should have scores
        assert_eq!(scores.len(), 5);

        // Entity 2 (bridge) should have highest score
        let score_2 = scores[&ids[2]];
        for &id in &ids {
            if id != ids[2] {
                assert!(
                    score_2 >= scores[&id],
                    "Bridge entity should have highest PageRank"
                );
            }
        }

        // Scores should sum to approximately 1.0
        let total: f64 = scores.values().sum();
        assert!(
            (total - 1.0).abs() < 0.01,
            "PageRank scores should sum to ~1.0, got {total}"
        );
    }

    #[test]
    fn test_hypergraph_pagerank_empty() {
        let conn = setup_db();
        let scores = hypergraph_entity_pagerank(&conn, 0.85, 100, 1e-6).unwrap();
        assert!(scores.is_empty());
    }

    #[test]
    fn test_hyperedge_properties() {
        let conn = setup_db();
        let ids = create_test_entities(&conn, 3);

        let mut he = Hyperedge::new(ids, "team", 0.9).unwrap();
        he.set_property("project", serde_json::json!("Alpha"));
        he.set_property("start_date", serde_json::json!("2026-01-01"));

        let he_id = insert_hyperedge(&conn, &he).unwrap();
        let retrieved = get_hyperedge(&conn, he_id).unwrap();

        assert_eq!(
            retrieved.get_property("project"),
            Some(&serde_json::json!("Alpha"))
        );
        assert_eq!(
            retrieved.get_property("start_date"),
            Some(&serde_json::json!("2026-01-01"))
        );
    }
}