// sqlite_knowledge_graph/algorithms/pagerank.rs
use crate::error::Result;
use rusqlite::Connection;
use std::collections::HashMap;
5
/// Tunable parameters for the `pagerank` computation.
#[derive(Debug, Clone)]
pub struct PageRankConfig {
    /// Damping factor: the weight applied to score flowing along edges. The
    /// remaining `1 - damping` is spread uniformly over all ranked nodes.
    pub damping: f64,
    /// Upper bound on the number of iteration rounds.
    pub max_iterations: usize,
    /// Convergence threshold: iteration stops once the summed absolute change
    /// in scores between two consecutive rounds drops below this value.
    pub tolerance: f64,
}
16
17impl Default for PageRankConfig {
18 fn default() -> Self {
19 Self {
20 damping: 0.85,
21 max_iterations: 100,
22 tolerance: 1e-6,
23 }
24 }
25}
26
27pub fn pagerank(conn: &Connection, config: PageRankConfig) -> Result<Vec<(i64, f64)>> {
31 let mut out_edges: HashMap<i64, Vec<i64>> = HashMap::new();
33 let mut in_edges: HashMap<i64, Vec<i64>> = HashMap::new();
34 let mut all_nodes: HashSet<i64> = HashSet::new();
35
36 let mut stmt = conn.prepare("SELECT source_id, target_id FROM kg_relations")?;
37
38 let rows = stmt.query_map([], |row| Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?)))?;
39
40 for row in rows {
41 let (from, to) = row?;
42 all_nodes.insert(from);
43 all_nodes.insert(to);
44 out_edges.entry(from).or_default().push(to);
45 in_edges.entry(to).or_default().push(from);
46 }
47
48 if all_nodes.is_empty() {
49 return Ok(Vec::new());
50 }
51
52 let n = all_nodes.len() as f64;
53 let initial_score = 1.0 / n;
54
55 let mut scores: HashMap<i64, f64> = all_nodes.iter().map(|&id| (id, initial_score)).collect();
57 let mut new_scores: HashMap<i64, f64> = HashMap::new();
58
59 for _ in 0..config.max_iterations {
61 let dangling_sum: f64 = all_nodes
62 .iter()
63 .filter(|&&id| match out_edges.get(&id) {
64 None => true,
65 Some(edges) => edges.is_empty(),
66 })
67 .map(|&id| scores[&id])
68 .sum();
69
70 for &node in &all_nodes {
71 let incoming_score: f64 = in_edges
72 .get(&node)
73 .map(|edges| {
74 edges
75 .iter()
76 .map(|&from| {
77 let out_degree = out_edges.get(&from).map_or(1, |e| e.len()) as f64;
78 scores[&from] / out_degree
79 })
80 .sum()
81 })
82 .unwrap_or(0.0);
83
84 new_scores.insert(
85 node,
86 (1.0 - config.damping) / n + config.damping * (incoming_score + dangling_sum / n),
87 );
88 }
89
90 let diff: f64 = all_nodes
92 .iter()
93 .map(|&id| (scores[&id] - new_scores[&id]).abs())
94 .sum();
95
96 std::mem::swap(&mut scores, &mut new_scores);
97
98 if diff < config.tolerance {
99 break;
100 }
101 }
102
103 let mut result: Vec<(i64, f64)> = scores.into_iter().collect();
105 result.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
106
107 Ok(result)
108}
109
use std::collections::HashSet;
111
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Builds an in-memory database holding four entities and three relations:
    /// 1 -> 2, 2 -> 3 and 1 -> 3. The fourth entity takes part in no relation.
    fn setup_test_db() -> Connection {
        use crate::graph::entity::{insert_entity, Entity};
        use crate::graph::relation::{insert_relation, Relation};

        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();

        let id1 = insert_entity(&conn, &Entity::new("node", "Node 1")).unwrap();
        let id2 = insert_entity(&conn, &Entity::new("node", "Node 2")).unwrap();
        let id3 = insert_entity(&conn, &Entity::new("node", "Node 3")).unwrap();
        let _id4 = insert_entity(&conn, &Entity::new("node", "Node 4")).unwrap();
        insert_relation(&conn, &Relation::new(id1, id2, "link", 1.0).unwrap()).unwrap();
        insert_relation(&conn, &Relation::new(id2, id3, "link", 1.0).unwrap()).unwrap();
        insert_relation(&conn, &Relation::new(id1, id3, "link", 1.0).unwrap()).unwrap();

        conn
    }

    #[test]
    fn test_pagerank() {
        let conn = setup_test_db();
        let result = pagerank(&conn, PageRankConfig::default()).unwrap();

        // Only the three entities that appear in a relation are ranked; the
        // isolated fourth entity is not part of the result.
        assert_eq!(result.len(), 3);

        // Results are ordered by descending score.
        assert!(result[0].1 >= result[1].1);
        assert!(result[1].1 >= result[2].1);

        // Scores form a probability distribution over the ranked nodes.
        let total: f64 = result.iter().map(|(_, score)| score).sum();
        assert!((total - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_pagerank_empty_graph() {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();

        let result = pagerank(&conn, PageRankConfig::default()).unwrap();
        assert!(result.is_empty());
    }
}