sqlite_knowledge_graph/algorithms/
pagerank.rs1use crate::error::Result;
2use rusqlite::Connection;
4use std::collections::HashMap;
5
/// Tuning parameters for the [`pagerank`] computation.
///
/// All fields are plain scalars, so the type is `Copy`: a single config value
/// can be handed to several runs without cloning.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PageRankConfig {
    /// Damping factor: the weight applied to score that flows along edges.
    /// The remaining `1 - damping` is spread evenly over all nodes each round.
    pub damping: f64,
    /// Upper bound on the number of iteration rounds.
    pub max_iterations: usize,
    /// Convergence threshold: iteration stops early once the summed absolute
    /// per-node score change of a round drops below this value.
    pub tolerance: f64,
}
16
17impl Default for PageRankConfig {
18 fn default() -> Self {
19 Self {
20 damping: 0.85,
21 max_iterations: 100,
22 tolerance: 1e-6,
23 }
24 }
25}
26
27pub fn pagerank(conn: &Connection, config: PageRankConfig) -> Result<Vec<(i64, f64)>> {
31 let mut out_edges: HashMap<i64, Vec<i64>> = HashMap::new();
33 let mut in_edges: HashMap<i64, Vec<i64>> = HashMap::new();
34 let mut all_nodes: HashSet<i64> = HashSet::new();
35
36 let mut stmt = conn.prepare("SELECT from_id, to_id FROM relations")?;
37
38 let rows = stmt.query_map([], |row| Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?)))?;
39
40 for row in rows {
41 let (from, to) = row?;
42 all_nodes.insert(from);
43 all_nodes.insert(to);
44 out_edges.entry(from).or_default().push(to);
45 in_edges.entry(to).or_default().push(from);
46 }
47
48 if all_nodes.is_empty() {
49 return Ok(Vec::new());
50 }
51
52 let n = all_nodes.len() as f64;
53 let initial_score = 1.0 / n;
54
55 let mut scores: HashMap<i64, f64> = all_nodes.iter().map(|&id| (id, initial_score)).collect();
57 let mut new_scores: HashMap<i64, f64> = HashMap::new();
58
59 for _ in 0..config.max_iterations {
61 let dangling_sum: f64 = all_nodes
62 .iter()
63 .filter(|&&id| match out_edges.get(&id) {
64 None => true,
65 Some(edges) => edges.is_empty(),
66 })
67 .map(|&id| scores[&id])
68 .sum();
69
70 for &node in &all_nodes {
71 let incoming_score: f64 = in_edges
72 .get(&node)
73 .map(|edges| {
74 edges
75 .iter()
76 .map(|&from| {
77 let out_degree = out_edges.get(&from).map_or(1, |e| e.len()) as f64;
78 scores[&from] / out_degree
79 })
80 .sum()
81 })
82 .unwrap_or(0.0);
83
84 new_scores.insert(
85 node,
86 (1.0 - config.damping) / n + config.damping * (incoming_score + dangling_sum / n),
87 );
88 }
89
90 let diff: f64 = all_nodes
92 .iter()
93 .map(|&id| (scores[&id] - new_scores[&id]).abs())
94 .sum();
95
96 std::mem::swap(&mut scores, &mut new_scores);
97
98 if diff < config.tolerance {
99 break;
100 }
101 }
102
103 let mut result: Vec<(i64, f64)> = scores.into_iter().collect();
105 result.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
106
107 Ok(result)
108}
109
110use std::collections::HashSet;
111
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Builds an in-memory database with four entities and the edges
    /// `1 -> 2`, `2 -> 3`, `1 -> 3`. Entity 4 takes part in no relation.
    fn setup_test_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();

        conn.execute_batch(
            "CREATE TABLE entities (id INTEGER PRIMARY KEY);
             CREATE TABLE relations (id INTEGER PRIMARY KEY, from_id INTEGER NOT NULL, to_id INTEGER NOT NULL, relation_type TEXT, weight REAL);",
        )
        .unwrap();

        conn.execute("INSERT INTO entities (id) VALUES (1), (2), (3), (4)", [])
            .unwrap();
        conn.execute("INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (1, 2, 'link', 1.0)", []).unwrap();
        conn.execute("INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (2, 3, 'link', 1.0)", []).unwrap();
        conn.execute("INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (1, 3, 'link', 1.0)", []).unwrap();

        conn
    }

    #[test]
    fn test_pagerank() {
        let conn = setup_test_db();
        let result = pagerank(&conn, PageRankConfig::default()).unwrap();

        // Entity 4 has no relations, so only the three connected nodes are ranked.
        assert_eq!(result.len(), 3);

        assert!(result.iter().any(|(id, _)| *id == 1));
        assert!(result.iter().any(|(id, _)| *id == 2));
        assert!(result.iter().any(|(id, _)| *id == 3));

        // Dangling mass is redistributed, so the scores form a probability distribution.
        let total: f64 = result.iter().map(|(_, score)| *score).sum();
        assert!((total - 1.0).abs() < 1e-9, "scores sum to {total}");

        // Node 3 is linked from both other nodes, node 2 from one, node 1 from none;
        // the result must come back sorted by descending score.
        let ranked: Vec<i64> = result.iter().map(|(id, _)| *id).collect();
        assert_eq!(ranked, vec![3, 2, 1]);
        assert!(result.windows(2).all(|pair| pair[0].1 >= pair[1].1));
    }

    #[test]
    fn test_pagerank_symmetric_cycle() {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(
            "CREATE TABLE relations (id INTEGER PRIMARY KEY, from_id INTEGER NOT NULL, to_id INTEGER NOT NULL, relation_type TEXT, weight REAL);
             INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (1, 2, 'link', 1.0);
             INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (2, 1, 'link', 1.0);",
        )
        .unwrap();

        let result = pagerank(&conn, PageRankConfig::default()).unwrap();

        // Two nodes pointing at each other are indistinguishable: each holds half.
        assert_eq!(result.len(), 2);
        for (id, score) in &result {
            assert!((score - 0.5).abs() < 1e-9, "node {id} scored {score}");
        }
    }

    #[test]
    fn test_pagerank_empty_graph() {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch("CREATE TABLE entities (id INTEGER PRIMARY KEY); CREATE TABLE relations (id INTEGER PRIMARY KEY, from_id INTEGER, to_id INTEGER, relation_type TEXT, weight REAL);").unwrap();

        let result = pagerank(&conn, PageRankConfig::default()).unwrap();
        assert!(result.is_empty());
    }
}