uni_algo/algo/algorithms/
katz_centrality.rs1use crate::algo::GraphProjection;
10use crate::algo::algorithms::Algorithm;
11use uni_common::core::id::Vid;
12
/// Katz centrality: each vertex scores a constant baseline (`beta`) plus an `alpha`-attenuated,
/// optionally edge-weighted sum of the scores of the vertices pointing at it, iterated until the
/// scores stop changing.
///
/// Parameters come from [`KatzCentralityConfig`]; output is a [`KatzCentralityResult`].
#[derive(Debug, Clone, Copy, Default)]
pub struct KatzCentrality;
14
/// Tuning parameters for the Katz centrality iteration.
#[derive(Clone, Debug)]
pub struct KatzCentralityConfig {
    /// Attenuation factor multiplied into every neighbour contribution.
    pub alpha: f64,
    /// Baseline score each vertex starts every iteration with (also the initial score).
    pub beta: f64,
    /// Hard cap on the number of iterations performed.
    pub max_iterations: usize,
    /// Iteration stops once the L1 change between successive score vectors falls below this.
    pub tolerance: f64,
}

impl Default for KatzCentralityConfig {
    /// `alpha = 0.1`, `beta = 1.0`, at most `100` iterations, `tolerance = 1e-6`.
    fn default() -> Self {
        let (alpha, beta) = (0.1, 1.0);
        let max_iterations = 100;
        let tolerance = 1e-6;
        KatzCentralityConfig {
            alpha,
            beta,
            max_iterations,
            tolerance,
        }
    }
}
33
34pub struct KatzCentralityResult {
35 pub scores: Vec<(Vid, f64)>,
36 pub iterations: usize,
37}
38
39impl Algorithm for KatzCentrality {
40 type Config = KatzCentralityConfig;
41 type Result = KatzCentralityResult;
42
43 fn name() -> &'static str {
44 "katz_centrality"
45 }
46
47 fn run(graph: &GraphProjection, config: Self::Config) -> Self::Result {
48 let n = graph.vertex_count();
49 if n == 0 {
50 return KatzCentralityResult {
51 scores: Vec::new(),
52 iterations: 0,
53 };
54 }
55
56 let mut x = vec![config.beta; n]; let mut next_x = vec![0.0; n];
58 let mut iterations = 0;
59
60 for iter in 0..config.max_iterations {
61 iterations = iter + 1;
62 next_x.fill(config.beta);
65
66 for (u, &x_u) in x.iter().enumerate().take(n) {
67 if x_u == 0.0 {
68 continue;
69 }
70 for (i, &v_u32) in graph.out_neighbors(u as u32).iter().enumerate() {
71 let weight = if graph.has_weights() {
72 graph.out_weight(u as u32, i)
73 } else {
74 1.0
75 };
76 next_x[v_u32 as usize] += config.alpha * x_u * weight;
77 }
78 }
79
80 let mut diff = 0.0;
86 for i in 0..n {
87 diff += (next_x[i] - x[i]).abs();
88 }
89
90 x.copy_from_slice(&next_x);
91
92 if diff < config.tolerance {
93 break;
94 }
95 }
96
97 let scores = x
98 .into_iter()
99 .enumerate()
100 .map(|(i, s)| (graph.to_vid(i as u32), s))
101 .collect();
102
103 KatzCentralityResult { scores, iterations }
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algo::test_utils::build_test_graph;

    /// Two-vertex DAG `1 -> 0` with `alpha = 0.1`, `beta = 1.0`.
    ///
    /// Vertex 1 has no in-edges, so its score stays at `beta = 1.0`. Vertex 0 receives
    /// `beta + alpha * x[1] = 1.1`. The first round moves away from the all-`beta` start
    /// vector; the second reproduces the same scores exactly, so the run stops after two
    /// iterations.
    #[test]
    fn test_katz_centrality_dag() {
        let vids = vec![Vid::from(0), Vid::from(1)];
        let edges = vec![(Vid::from(1), Vid::from(0))];
        let graph = build_test_graph(vids, edges);

        let config = KatzCentralityConfig {
            alpha: 0.1,
            beta: 1.0,
            ..Default::default()
        };

        let result = KatzCentrality::run(&graph, config);
        assert_eq!(result.iterations, 2);

        let map: std::collections::HashMap<_, _> = result.scores.into_iter().collect();
        assert_eq!(map.len(), 2);
        assert!((map[&Vid::from(1)] - 1.0).abs() < 1e-6);
        assert!((map[&Vid::from(0)] - 1.1).abs() < 1e-6);
    }
}