Skip to main content

sqlite_knowledge_graph/algorithms/
pagerank.rs

1use crate::error::Result;
2/// PageRank algorithm implementation
3use rusqlite::Connection;
4use std::collections::HashMap;
5
/// Tuning parameters for the [`pagerank`] computation.
///
/// Every field is a plain number, so the configuration is cheap to copy and
/// can be compared for equality (handy when asserting on configs in tests).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PageRankConfig {
    /// Damping factor (typically 0.85): the share of each node's new score
    /// that comes from incoming edges and redistributed dangling mass; the
    /// remaining `1 - damping` is spread uniformly over all nodes.
    pub damping: f64,
    /// Upper bound on the number of power-iteration rounds.
    pub max_iterations: usize,
    /// Convergence threshold: iteration stops once the sum of absolute score
    /// changes between two consecutive rounds drops below this value.
    pub tolerance: f64,
}
16
17impl Default for PageRankConfig {
18    fn default() -> Self {
19        Self {
20            damping: 0.85,
21            max_iterations: 100,
22            tolerance: 1e-6,
23        }
24    }
25}
26
27/// Compute PageRank scores for all entities
28///
29/// Returns a vector of (entity_id, score) sorted by score descending.
30pub fn pagerank(conn: &Connection, config: PageRankConfig) -> Result<Vec<(i64, f64)>> {
31    // Build adjacency list from relations
32    let mut out_edges: HashMap<i64, Vec<i64>> = HashMap::new();
33    let mut in_edges: HashMap<i64, Vec<i64>> = HashMap::new();
34    let mut all_nodes: HashSet<i64> = HashSet::new();
35
36    let mut stmt = conn.prepare("SELECT source_id, target_id FROM kg_relations")?;
37
38    let rows = stmt.query_map([], |row| Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?)))?;
39
40    for row in rows {
41        let (from, to) = row?;
42        all_nodes.insert(from);
43        all_nodes.insert(to);
44        out_edges.entry(from).or_default().push(to);
45        in_edges.entry(to).or_default().push(from);
46    }
47
48    if all_nodes.is_empty() {
49        return Ok(Vec::new());
50    }
51
52    let n = all_nodes.len() as f64;
53    let initial_score = 1.0 / n;
54
55    // Initialize scores
56    let mut scores: HashMap<i64, f64> = all_nodes.iter().map(|&id| (id, initial_score)).collect();
57    let mut new_scores: HashMap<i64, f64> = HashMap::new();
58
59    // Iterate until convergence
60    for _ in 0..config.max_iterations {
61        let dangling_sum: f64 = all_nodes
62            .iter()
63            .filter(|&&id| match out_edges.get(&id) {
64                None => true,
65                Some(edges) => edges.is_empty(),
66            })
67            .map(|&id| scores[&id])
68            .sum();
69
70        for &node in &all_nodes {
71            let incoming_score: f64 = in_edges
72                .get(&node)
73                .map(|edges| {
74                    edges
75                        .iter()
76                        .map(|&from| {
77                            let out_degree = out_edges.get(&from).map_or(1, |e| e.len()) as f64;
78                            scores[&from] / out_degree
79                        })
80                        .sum()
81                })
82                .unwrap_or(0.0);
83
84            new_scores.insert(
85                node,
86                (1.0 - config.damping) / n + config.damping * (incoming_score + dangling_sum / n),
87            );
88        }
89
90        // Check convergence
91        let diff: f64 = all_nodes
92            .iter()
93            .map(|&id| (scores[&id] - new_scores[&id]).abs())
94            .sum();
95
96        std::mem::swap(&mut scores, &mut new_scores);
97
98        if diff < config.tolerance {
99            break;
100        }
101    }
102
103    // Sort by score descending
104    let mut result: Vec<(i64, f64)> = scores.into_iter().collect();
105    result.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
106
107    Ok(result)
108}
109
110use std::collections::HashSet;
111
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::entity::{insert_entity, Entity};
    use crate::graph::relation::{insert_relation, Relation};
    use rusqlite::Connection;

    /// Builds an in-memory database holding four entities and the edges
    /// 1 -> 2, 2 -> 3 and 1 -> 3; the fourth entity has no relations.
    fn setup_test_db() -> Connection {
        let db = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&db).unwrap();

        let n1 = insert_entity(&db, &Entity::new("node", "Node 1")).unwrap();
        let n2 = insert_entity(&db, &Entity::new("node", "Node 2")).unwrap();
        let n3 = insert_entity(&db, &Entity::new("node", "Node 3")).unwrap();
        // Deliberately left unconnected.
        let _isolated = insert_entity(&db, &Entity::new("node", "Node 4")).unwrap();

        for &(src, dst) in &[(n1, n2), (n2, n3), (n1, n3)] {
            let edge = Relation::new(src, dst, "link", 1.0).unwrap();
            insert_relation(&db, &edge).unwrap();
        }

        db
    }

    #[test]
    fn test_pagerank() {
        let db = setup_test_db();
        let ranked = pagerank(&db, PageRankConfig::default()).unwrap();

        // Only the three entities that take part in a relation are scored.
        assert_eq!(ranked.len(), 3);
        // Scores must be in non-increasing order.
        assert!(ranked.windows(2).all(|pair| pair[0].1 >= pair[1].1));
    }

    #[test]
    fn test_pagerank_empty_graph() {
        let db = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&db).unwrap();

        let ranked = pagerank(&db, PageRankConfig::default()).unwrap();
        assert!(ranked.is_empty());
    }
}
155}