sqlite_knowledge_graph/algorithms/
pagerank.rs1use crate::error::Result;
2use rusqlite::Connection;
4use std::collections::HashMap;
5
/// Tuning parameters for [`pagerank`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PageRankConfig {
    /// Weight given to following outgoing relations; the remaining
    /// `1 - damping` of each node's score is spread uniformly over all nodes.
    pub damping: f64,
    /// Upper bound on the number of power-iteration rounds.
    pub max_iterations: usize,
    /// Convergence threshold: iteration stops once the summed absolute change
    /// in scores between two consecutive rounds falls below this value.
    pub tolerance: f64,
}

impl Default for PageRankConfig {
    /// Returns `damping = 0.85`, `max_iterations = 100`, `tolerance = 1e-6`.
    fn default() -> Self {
        Self {
            damping: 0.85,
            max_iterations: 100,
            tolerance: 1e-6,
        }
    }
}
26
27pub fn pagerank(conn: &Connection, config: PageRankConfig) -> Result<Vec<(i64, f64)>> {
31 let mut out_edges: HashMap<i64, Vec<i64>> = HashMap::new();
33 let mut in_edges: HashMap<i64, Vec<i64>> = HashMap::new();
34 let mut all_nodes: HashSet<i64> = HashSet::new();
35
36 let mut stmt = conn.prepare("SELECT from_id, to_id FROM relations")?;
37
38 let rows = stmt.query_map([], |row| Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?)))?;
39
40 for row in rows {
41 let (from, to) = row?;
42 all_nodes.insert(from);
43 all_nodes.insert(to);
44 out_edges.entry(from).or_default().push(to);
45 in_edges.entry(to).or_default().push(from);
46 }
47
48 if all_nodes.is_empty() {
49 return Ok(Vec::new());
50 }
51
52 let n = all_nodes.len() as f64;
53 let initial_score = 1.0 / n;
54
55 let mut scores: HashMap<i64, f64> = all_nodes.iter().map(|&id| (id, initial_score)).collect();
57 let mut new_scores: HashMap<i64, f64> = HashMap::new();
58
59 for _ in 0..config.max_iterations {
61 let dangling_sum: f64 = all_nodes
62 .iter()
63 .filter(|&&id| out_edges.get(&id).is_none_or(|edges| edges.is_empty()))
64 .map(|&id| scores[&id])
65 .sum();
66
67 for &node in &all_nodes {
68 let incoming_score: f64 = in_edges
69 .get(&node)
70 .map(|edges| {
71 edges
72 .iter()
73 .map(|&from| {
74 let out_degree = out_edges.get(&from).map_or(1, |e| e.len()) as f64;
75 scores[&from] / out_degree
76 })
77 .sum()
78 })
79 .unwrap_or(0.0);
80
81 new_scores.insert(
82 node,
83 (1.0 - config.damping) / n + config.damping * (incoming_score + dangling_sum / n),
84 );
85 }
86
87 let diff: f64 = all_nodes
89 .iter()
90 .map(|&id| (scores[&id] - new_scores[&id]).abs())
91 .sum();
92
93 std::mem::swap(&mut scores, &mut new_scores);
94
95 if diff < config.tolerance {
96 break;
97 }
98 }
99
100 let mut result: Vec<(i64, f64)> = scores.into_iter().collect();
102 result.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
103
104 Ok(result)
105}
106
107use std::collections::HashSet;
108
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Builds an in-memory database with entities 1..=4 and the relations
    /// 1 -> 2, 2 -> 3 and 1 -> 3; entity 4 has no relations.
    fn setup_test_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();

        conn.execute_batch(
            "CREATE TABLE entities (id INTEGER PRIMARY KEY);
             CREATE TABLE relations (id INTEGER PRIMARY KEY, from_id INTEGER NOT NULL, to_id INTEGER NOT NULL, relation_type TEXT, weight REAL);",
        )
        .unwrap();

        conn.execute("INSERT INTO entities (id) VALUES (1), (2), (3), (4)", [])
            .unwrap();
        conn.execute("INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (1, 2, 'link', 1.0)", []).unwrap();
        conn.execute("INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (2, 3, 'link', 1.0)", []).unwrap();
        conn.execute("INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (1, 3, 'link', 1.0)", []).unwrap();

        conn
    }

    #[test]
    fn test_pagerank() {
        let conn = setup_test_db();
        let result = pagerank(&conn, PageRankConfig::default()).unwrap();

        // Entity 4 takes part in no relation, so it is not ranked.
        assert_eq!(result.len(), 3);
        assert!(result.iter().any(|(id, _)| *id == 1));
        assert!(result.iter().any(|(id, _)| *id == 2));
        assert!(result.iter().any(|(id, _)| *id == 3));

        // Dangling mass is redistributed, so the scores form a distribution.
        let total: f64 = result.iter().map(|(_, score)| score).sum();
        assert!((total - 1.0).abs() < 1e-6, "scores sum to {total}");

        // Node 3 is linked from both other nodes, node 1 from none.
        let order: Vec<i64> = result.iter().map(|(id, _)| *id).collect();
        assert_eq!(order, vec![3, 2, 1]);
    }

    #[test]
    fn test_pagerank_zero_iterations_returns_uniform_scores() {
        let conn = setup_test_db();
        let config = PageRankConfig {
            max_iterations: 0,
            ..PageRankConfig::default()
        };
        let result = pagerank(&conn, config).unwrap();

        assert_eq!(result.len(), 3);
        assert!(result
            .iter()
            .all(|(_, score)| (score - 1.0 / 3.0).abs() < 1e-12));
    }

    #[test]
    fn test_pagerank_empty_graph() {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch("CREATE TABLE entities (id INTEGER PRIMARY KEY); CREATE TABLE relations (id INTEGER PRIMARY KEY, from_id INTEGER, to_id INTEGER, relation_type TEXT, weight REAL);").unwrap();

        let result = pagerank(&conn, PageRankConfig::default()).unwrap();
        assert!(result.is_empty());
    }
}