Skip to main content

sqlite_knowledge_graph/algorithms/
pagerank.rs

1use crate::error::Result;
2/// PageRank algorithm implementation
3use rusqlite::Connection;
4use std::collections::HashMap;
5
/// Tunable parameters for a [`pagerank`] run.
///
/// [`Default`] yields `damping = 0.85`, `max_iterations = 100` and
/// `tolerance = 1e-6`.
#[derive(Debug, Clone)]
pub struct PageRankConfig {
    /// Weight given to score arriving over edges (and from dangling nodes);
    /// the remaining `1 - damping` is spread evenly over all nodes.
    /// Typically 0.85.
    pub damping: f64,
    /// Upper bound on the number of power-iteration rounds.
    pub max_iterations: usize,
    /// Iteration stops early once the summed absolute score change of a
    /// round falls below this value.
    pub tolerance: f64,
}
16
17impl Default for PageRankConfig {
18    fn default() -> Self {
19        Self {
20            damping: 0.85,
21            max_iterations: 100,
22            tolerance: 1e-6,
23        }
24    }
25}
26
27/// Compute PageRank scores for all entities
28///
29/// Returns a vector of (entity_id, score) sorted by score descending.
30pub fn pagerank(conn: &Connection, config: PageRankConfig) -> Result<Vec<(i64, f64)>> {
31    // Build adjacency list from relations
32    let mut out_edges: HashMap<i64, Vec<i64>> = HashMap::new();
33    let mut in_edges: HashMap<i64, Vec<i64>> = HashMap::new();
34    let mut all_nodes: HashSet<i64> = HashSet::new();
35
36    let mut stmt = conn.prepare("SELECT from_id, to_id FROM relations")?;
37
38    let rows = stmt.query_map([], |row| Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?)))?;
39
40    for row in rows {
41        let (from, to) = row?;
42        all_nodes.insert(from);
43        all_nodes.insert(to);
44        out_edges.entry(from).or_default().push(to);
45        in_edges.entry(to).or_default().push(from);
46    }
47
48    if all_nodes.is_empty() {
49        return Ok(Vec::new());
50    }
51
52    let n = all_nodes.len() as f64;
53    let initial_score = 1.0 / n;
54
55    // Initialize scores
56    let mut scores: HashMap<i64, f64> = all_nodes.iter().map(|&id| (id, initial_score)).collect();
57    let mut new_scores: HashMap<i64, f64> = HashMap::new();
58
59    // Iterate until convergence
60    for _ in 0..config.max_iterations {
61        let dangling_sum: f64 = all_nodes
62            .iter()
63            .filter(|&&id| match out_edges.get(&id) {
64                None => true,
65                Some(edges) => edges.is_empty(),
66            })
67            .map(|&id| scores[&id])
68            .sum();
69
70        for &node in &all_nodes {
71            let incoming_score: f64 = in_edges
72                .get(&node)
73                .map(|edges| {
74                    edges
75                        .iter()
76                        .map(|&from| {
77                            let out_degree = out_edges.get(&from).map_or(1, |e| e.len()) as f64;
78                            scores[&from] / out_degree
79                        })
80                        .sum()
81                })
82                .unwrap_or(0.0);
83
84            new_scores.insert(
85                node,
86                (1.0 - config.damping) / n + config.damping * (incoming_score + dangling_sum / n),
87            );
88        }
89
90        // Check convergence
91        let diff: f64 = all_nodes
92            .iter()
93            .map(|&id| (scores[&id] - new_scores[&id]).abs())
94            .sum();
95
96        std::mem::swap(&mut scores, &mut new_scores);
97
98        if diff < config.tolerance {
99            break;
100        }
101    }
102
103    // Sort by score descending
104    let mut result: Vec<(i64, f64)> = scores.into_iter().collect();
105    result.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
106
107    Ok(result)
108}
109
110use std::collections::HashSet;
111
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Builds an in-memory database with entities 1-4 and the edges
    /// 1 -> 2, 2 -> 3 and 1 -> 3. Entity 4 takes part in no relation.
    fn setup_test_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();

        conn.execute_batch(
            "CREATE TABLE entities (id INTEGER PRIMARY KEY);
             CREATE TABLE relations (id INTEGER PRIMARY KEY, from_id INTEGER NOT NULL, to_id INTEGER NOT NULL, relation_type TEXT, weight REAL);"
        ).unwrap();

        // Create a simple graph: 1 -> 2 -> 3, 1 -> 3
        conn.execute("INSERT INTO entities (id) VALUES (1), (2), (3), (4)", [])
            .unwrap();
        conn.execute("INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (1, 2, 'link', 1.0)", []).unwrap();
        conn.execute("INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (2, 3, 'link', 1.0)", []).unwrap();
        conn.execute("INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (1, 3, 'link', 1.0)", []).unwrap();

        conn
    }

    #[test]
    fn test_pagerank() {
        let conn = setup_test_db();
        let result = pagerank(&conn, PageRankConfig::default()).unwrap();

        // Only nodes with relations are included (1, 2, 3); entity 4 is not.
        assert_eq!(result.len(), 3);

        // Node 3 receives score from both 1 and 2, node 2 only from 1, and
        // node 1 from no edge at all, so the ranking is strictly 3 > 2 > 1.
        let ranked_ids: Vec<i64> = result.iter().map(|&(id, _)| id).collect();
        assert_eq!(ranked_ids, vec![3, 2, 1]);

        // Dangling-node redistribution keeps the total score mass at 1.
        let total: f64 = result.iter().map(|&(_, score)| score).sum();
        assert!((total - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_pagerank_empty_graph() {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch("CREATE TABLE entities (id INTEGER PRIMARY KEY); CREATE TABLE relations (id INTEGER PRIMARY KEY, from_id INTEGER, to_id INTEGER, relation_type TEXT, weight REAL);").unwrap();

        // No relations means no ranked nodes.
        let result = pagerank(&conn, PageRankConfig::default()).unwrap();
        assert!(result.is_empty());
    }
}