Skip to main content

uni_algo/algo/algorithms/
katz_centrality.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Katz Centrality Algorithm.
5//!
//! Measures a node's influence by taking into account the total number of walks between nodes.
//! Iterates `x = alpha * A^T * x + beta` until the scores stop changing.
8
9use crate::algo::GraphProjection;
10use crate::algo::algorithms::Algorithm;
11use uni_common::core::id::Vid;
12
/// Marker type for the Katz centrality algorithm.
///
/// It carries no state; the computation is exposed through the [`Algorithm`]
/// implementation for this type, configured by [`KatzCentralityConfig`].
#[derive(Debug, Clone, Copy, Default)]
pub struct KatzCentrality;
14
/// Tunable parameters for [`KatzCentrality`].
#[derive(Debug, Clone, PartialEq)]
pub struct KatzCentralityConfig {
    /// Attenuation factor multiplied into every neighbor contribution.
    ///
    /// Should be smaller than `1 / lambda_max` for the iteration to converge.
    pub alpha: f64,
    /// Base score: the starting value of every vertex and the constant term
    /// added to each vertex on every iteration.
    pub beta: f64,
    /// Upper bound on the number of iterations performed.
    pub max_iterations: usize,
    /// Convergence threshold: iteration stops once the summed absolute change
    /// of all scores between two consecutive iterations is below this value.
    pub tolerance: f64,
}
22
23impl Default for KatzCentralityConfig {
24    fn default() -> Self {
25        Self {
26            alpha: 0.1, // Should be < 1/lambda_max
27            beta: 1.0,
28            max_iterations: 100,
29            tolerance: 1e-6,
30        }
31    }
32}
33
34pub struct KatzCentralityResult {
35    pub scores: Vec<(Vid, f64)>,
36    pub iterations: usize,
37}
38
39impl Algorithm for KatzCentrality {
40    type Config = KatzCentralityConfig;
41    type Result = KatzCentralityResult;
42
43    fn name() -> &'static str {
44        "katz_centrality"
45    }
46
47    fn run(graph: &GraphProjection, config: Self::Config) -> Self::Result {
48        let n = graph.vertex_count();
49        if n == 0 {
50            return KatzCentralityResult {
51                scores: Vec::new(),
52                iterations: 0,
53            };
54        }
55
56        let mut x = vec![config.beta; n]; // Start with beta
57        let mut next_x = vec![0.0; n];
58        let mut iterations = 0;
59
60        for iter in 0..config.max_iterations {
61            iterations = iter + 1;
62            // next_x = beta + alpha * A^T * x
63            // Initialize with beta
64            next_x.fill(config.beta);
65
66            for (u, &x_u) in x.iter().enumerate().take(n) {
67                if x_u == 0.0 {
68                    continue;
69                }
70                for (i, &v_u32) in graph.out_neighbors(u as u32).iter().enumerate() {
71                    let weight = if graph.has_weights() {
72                        graph.out_weight(u as u32, i)
73                    } else {
74                        1.0
75                    };
76                    next_x[v_u32 as usize] += config.alpha * x_u * weight;
77                }
78            }
79
80            // Normalize? Katz usually converges if alpha < 1/lambda.
81            // Normalization (L2) prevents overflow if alpha is large.
82            // Standard Katz doesn't normalize per step, but converges to (I - alpha*A)^-1 * beta.
83            // Let's check convergence.
84
85            let mut diff = 0.0;
86            for i in 0..n {
87                diff += (next_x[i] - x[i]).abs();
88            }
89
90            x.copy_from_slice(&next_x);
91
92            if diff < config.tolerance {
93                break;
94            }
95        }
96
97        let scores = x
98            .into_iter()
99            .enumerate()
100            .map(|(i, s)| (graph.to_vid(i as u32), s))
101            .collect();
102
103        KatzCentralityResult { scores, iterations }
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algo::test_utils::build_test_graph;
    use std::collections::HashMap;

    /// Graph with the single edge 1 -> 0.
    ///
    /// Vertex 1 has no incoming edge, so it keeps `beta`:      x1 = 1.0
    /// Vertex 0 receives `alpha * x1` on top of `beta`:        x0 = 1.0 + 0.1 * 1.0 = 1.1
    #[test]
    fn test_katz_centrality_dag() {
        let approx_eq = |actual: f64, expected: f64| (actual - expected).abs() < 1e-6;

        let graph = build_test_graph(
            vec![Vid::from(0), Vid::from(1)],
            vec![(Vid::from(1), Vid::from(0))],
        );

        let result = KatzCentrality::run(
            &graph,
            KatzCentralityConfig {
                alpha: 0.1,
                beta: 1.0,
                ..Default::default()
            },
        );

        let score_of: HashMap<_, _> = result.scores.into_iter().collect();

        assert!(approx_eq(score_of[&Vid::from(1)], 1.0));
        assert!(approx_eq(score_of[&Vid::from(0)], 1.1));
    }
}