Skip to main content

sqlite_knowledge_graph/algorithms/
pagerank.rs

//! PageRank algorithm implementation.

use crate::error::Result;
use rusqlite::Connection;
use std::collections::HashMap;
5
/// Tuning knobs for the PageRank power iteration.
///
/// Use [`Default::default`] for the conventional settings and override individual
/// fields with struct-update syntax, e.g.
/// `PageRankConfig { max_iterations: 20, ..Default::default() }`.
#[derive(Clone, Debug)]
pub struct PageRankConfig {
    /// Weight given to rank flowing along edges; the remaining `1 - damping` is spread
    /// uniformly over all nodes each round. Conventionally 0.85 (presumably expected to
    /// lie in `[0, 1]` — not validated here).
    pub damping: f64,
    /// Hard cap on the number of iteration rounds.
    pub max_iterations: usize,
    /// Early-exit threshold on the summed absolute change of all scores between two
    /// consecutive rounds.
    pub tolerance: f64,
}

impl Default for PageRankConfig {
    /// Conventional settings: `damping = 0.85`, `max_iterations = 100`,
    /// `tolerance = 1e-6`.
    fn default() -> Self {
        let damping = 0.85;
        let max_iterations = 100;
        let tolerance = 1e-6;
        PageRankConfig {
            damping,
            max_iterations,
            tolerance,
        }
    }
}
26
27/// Compute PageRank scores for all entities
28///
29/// Returns a vector of (entity_id, score) sorted by score descending.
30pub fn pagerank(conn: &Connection, config: PageRankConfig) -> Result<Vec<(i64, f64)>> {
31    // Build adjacency list from relations
32    let mut out_edges: HashMap<i64, Vec<i64>> = HashMap::new();
33    let mut in_edges: HashMap<i64, Vec<i64>> = HashMap::new();
34    let mut all_nodes: HashSet<i64> = HashSet::new();
35
36    let mut stmt = conn.prepare("SELECT from_id, to_id FROM relations")?;
37
38    let rows = stmt.query_map([], |row| Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?)))?;
39
40    for row in rows {
41        let (from, to) = row?;
42        all_nodes.insert(from);
43        all_nodes.insert(to);
44        out_edges.entry(from).or_default().push(to);
45        in_edges.entry(to).or_default().push(from);
46    }
47
48    if all_nodes.is_empty() {
49        return Ok(Vec::new());
50    }
51
52    let n = all_nodes.len() as f64;
53    let initial_score = 1.0 / n;
54
55    // Initialize scores
56    let mut scores: HashMap<i64, f64> = all_nodes.iter().map(|&id| (id, initial_score)).collect();
57    let mut new_scores: HashMap<i64, f64> = HashMap::new();
58
59    // Iterate until convergence
60    for _ in 0..config.max_iterations {
61        let dangling_sum: f64 = all_nodes
62            .iter()
63            .filter(|&&id| out_edges.get(&id).is_none_or(|edges| edges.is_empty()))
64            .map(|&id| scores[&id])
65            .sum();
66
67        for &node in &all_nodes {
68            let incoming_score: f64 = in_edges
69                .get(&node)
70                .map(|edges| {
71                    edges
72                        .iter()
73                        .map(|&from| {
74                            let out_degree = out_edges.get(&from).map_or(1, |e| e.len()) as f64;
75                            scores[&from] / out_degree
76                        })
77                        .sum()
78                })
79                .unwrap_or(0.0);
80
81            new_scores.insert(
82                node,
83                (1.0 - config.damping) / n + config.damping * (incoming_score + dangling_sum / n),
84            );
85        }
86
87        // Check convergence
88        let diff: f64 = all_nodes
89            .iter()
90            .map(|&id| (scores[&id] - new_scores[&id]).abs())
91            .sum();
92
93        std::mem::swap(&mut scores, &mut new_scores);
94
95        if diff < config.tolerance {
96            break;
97        }
98    }
99
100    // Sort by score descending
101    let mut result: Vec<(i64, f64)> = scores.into_iter().collect();
102    result.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
103
104    Ok(result)
105}
106
107use std::collections::HashSet;
108
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Schema shared by every test database.
    const SCHEMA: &str = "CREATE TABLE entities (id INTEGER PRIMARY KEY);
         CREATE TABLE relations (id INTEGER PRIMARY KEY, from_id INTEGER NOT NULL, to_id INTEGER NOT NULL, relation_type TEXT, weight REAL);";

    /// Build an in-memory database whose `relations` table holds exactly `edges`.
    fn db_with_edges(edges: &[(i64, i64)]) -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(SCHEMA).unwrap();
        for &(from, to) in edges {
            conn.execute(
                "INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (?1, ?2, 'link', 1.0)",
                [from, to],
            )
            .unwrap();
        }
        conn
    }

    /// Graph 1 -> 2 -> 3 plus 1 -> 3; entity 4 exists but has no relations.
    fn setup_test_db() -> Connection {
        let conn = db_with_edges(&[(1, 2), (2, 3), (1, 3)]);
        conn.execute("INSERT INTO entities (id) VALUES (1), (2), (3), (4)", [])
            .unwrap();
        conn
    }

    #[test]
    fn test_pagerank() {
        let conn = setup_test_db();
        let result = pagerank(&conn, PageRankConfig::default()).unwrap();

        // Only nodes with relations are included (1, 2, 3); isolated entity 4 is not.
        assert_eq!(result.len(), 3);
        assert!(result.iter().all(|(id, _)| *id != 4));

        // Node 3 receives rank from both 1 and 2, node 2 only from 1, node 1 from nobody:
        // the ranking must reflect that.
        let ids: Vec<i64> = result.iter().map(|(id, _)| *id).collect();
        assert_eq!(ids, vec![3, 2, 1]);
        assert!(result[0].1 > result[1].1 && result[1].1 > result[2].1);

        // Dangling mass is redistributed, so the scores form a probability distribution.
        let total: f64 = result.iter().map(|(_, s)| s).sum();
        assert!((total - 1.0).abs() < 1e-9, "scores sum to {total}");
    }

    #[test]
    fn test_pagerank_empty_graph() {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch("CREATE TABLE entities (id INTEGER PRIMARY KEY); CREATE TABLE relations (id INTEGER PRIMARY KEY, from_id INTEGER, to_id INTEGER, relation_type TEXT, weight REAL);").unwrap();

        let result = pagerank(&conn, PageRankConfig::default()).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_pagerank_symmetric_cycle() {
        // 1 <-> 2 is perfectly symmetric, so both nodes end with half the rank.
        let conn = db_with_edges(&[(1, 2), (2, 1)]);
        let result = pagerank(&conn, PageRankConfig::default()).unwrap();

        assert_eq!(result.len(), 2);
        for (id, score) in &result {
            assert!(*id == 1 || *id == 2);
            assert!((score - 0.5).abs() < 1e-12, "node {id} scored {score}");
        }
    }

    #[test]
    fn test_pagerank_zero_iterations_is_uniform() {
        let conn = setup_test_db();
        let config = PageRankConfig {
            max_iterations: 0,
            ..Default::default()
        };
        let result = pagerank(&conn, config).unwrap();

        assert_eq!(result.len(), 3);
        for (_, score) in &result {
            assert!((score - 1.0 / 3.0).abs() < 1e-15);
        }
    }
}