rustworkx_core/
centrality.rs

1// Licensed under the Apache License, Version 2.0 (the "License"); you may
2// not use this file except in compliance with the License. You may obtain
3// a copy of the License at
4//
5//     http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10// License for the specific language governing permissions and limitations
11// under the License.
12
13use std::collections::VecDeque;
14use std::hash::Hash;
15use std::sync::RwLock;
16
17use hashbrown::HashMap;
18use petgraph::algo::dijkstra;
19use petgraph::visit::{
20    EdgeCount,
21    EdgeIndexable,
22    EdgeRef,
23    GraphBase,
24    GraphProp, // allows is_directed
25    IntoEdges,
26    IntoEdgesDirected,
27    IntoNeighbors,
28    IntoNeighborsDirected,
29    IntoNodeIdentifiers,
30    NodeCount,
31    NodeIndexable,
32    Reversed,
33    Visitable,
34};
35use rayon_cond::CondIterator;
36
37/// Compute the betweenness centrality of all nodes in a graph.
38///
39/// The algorithm used in this function is based on:
40///
41/// Ulrik Brandes, A Faster Algorithm for Betweenness Centrality.
42/// Journal of Mathematical Sociology 25(2):163-177, 2001.
43///
44/// This function is multithreaded and will run in parallel if the number
45/// of nodes in the graph is above the value of ``parallel_threshold``. If the
46/// function will be running in parallel the env var ``RAYON_NUM_THREADS`` can
47/// be used to adjust how many threads will be used.
48///
49/// Arguments:
50///
51/// * `graph` - The graph object to run the algorithm on
52/// * `include_endpoints` - Whether to include the endpoints of paths in the path
53///     lengths used to compute the betweenness
54/// * `normalized` - Whether to normalize the betweenness scores by the number
55///     of distinct paths between all pairs of nodes
56/// * `parallel_threshold` - The number of nodes to calculate the betweenness
57///     centrality in parallel at, if the number of nodes in `graph` is less
58///     than this value it will run in a single thread. A good default to use
59///     here if you're not sure is `50` as that was found to be roughly the
60///     number of nodes where parallelism improves performance
61///
62/// # Example
63/// ```rust
64/// use rustworkx_core::petgraph;
65/// use rustworkx_core::centrality::betweenness_centrality;
66///
67/// let g = petgraph::graph::UnGraph::<i32, ()>::from_edges(&[
68///     (0, 4), (1, 2), (2, 3), (3, 4), (1, 4)
69/// ]);
70/// // Calculate the betweenness centrality
71/// let output = betweenness_centrality(&g, true, true, 200);
72/// assert_eq!(
73///     vec![Some(0.4), Some(0.5), Some(0.45), Some(0.5), Some(0.75)],
74///     output
75/// );
76/// ```
77/// # See Also
78/// [`edge_betweenness_centrality`]
79pub fn betweenness_centrality<G>(
80    graph: G,
81    include_endpoints: bool,
82    normalized: bool,
83    parallel_threshold: usize,
84) -> Vec<Option<f64>>
85where
86    G: NodeIndexable
87        + IntoNodeIdentifiers
88        + IntoNeighborsDirected
89        + NodeCount
90        + GraphProp
91        + GraphBase
92        + std::marker::Sync,
93    <G as GraphBase>::NodeId: std::cmp::Eq + Hash + Send,
94    // rustfmt deletes the following comments if placed inline above
95    // + IntoNodeIdentifiers // for node_identifiers()
96    // + IntoNeighborsDirected // for neighbors()
97    // + NodeCount // for node_count
98    // + GraphProp // for is_directed
99{
100    // Correspondence of variable names to quantities in the paper is as follows:
101    //
102    // P -- predecessors
103    // S -- verts_sorted_by_distance,
104    //      vertices in order of non-decreasing distance from s
105    // Q -- Q
106    // sigma -- sigma
107    // delta -- delta
108    // d -- distance
109    let max_index = graph.node_bound();
110
111    let mut betweenness: Vec<Option<f64>> = vec![None; max_index];
112    for node_s in graph.node_identifiers() {
113        let is: usize = graph.to_index(node_s);
114        betweenness[is] = Some(0.0);
115    }
116    let locked_betweenness = RwLock::new(&mut betweenness);
117    let node_indices: Vec<G::NodeId> = graph.node_identifiers().collect();
118
119    CondIterator::new(node_indices, graph.node_count() >= parallel_threshold)
120        .map(|node_s| (shortest_path_for_centrality(&graph, &node_s), node_s))
121        .for_each(|(mut shortest_path_calc, node_s)| {
122            _accumulate_vertices(
123                &locked_betweenness,
124                max_index,
125                &mut shortest_path_calc,
126                node_s,
127                &graph,
128                include_endpoints,
129            );
130        });
131
132    _rescale(
133        &mut betweenness,
134        graph.node_count(),
135        normalized,
136        graph.is_directed(),
137        include_endpoints,
138    );
139
140    betweenness
141}
142
143/// Compute the edge betweenness centrality of all edges in a graph.
144///
145/// The algorithm used in this function is based on:
146///
147/// Ulrik Brandes: On Variants of Shortest-Path Betweenness
148/// Centrality and their Generic Computation.
149/// Social Networks 30(2):136-145, 2008.
150/// <https://doi.org/10.1016/j.socnet.2007.11.001>.
151///
152/// This function is multithreaded and will run in parallel if the number
153/// of nodes in the graph is above the value of ``parallel_threshold``. If the
154/// function will be running in parallel the env var ``RAYON_NUM_THREADS`` can
155/// be used to adjust how many threads will be used.
156///
157/// Arguments:
158///
159/// * `graph` - The graph object to run the algorithm on
160/// * `normalized` - Whether to normalize the betweenness scores by the number
161///     of distinct paths between all pairs of nodes
162/// * `parallel_threshold` - The number of nodes to calculate the betweenness
163///     centrality in parallel at, if the number of nodes in `graph` is less
164///     than this value it will run in a single thread. A good default to use
165///     here if you're not sure is `50` as that was found to be roughly the
166///     number of nodes where parallelism improves performance
167///
168/// # Example
169/// ```rust
170/// use rustworkx_core::petgraph;
171/// use rustworkx_core::centrality::edge_betweenness_centrality;
172///
173/// let g = petgraph::graph::UnGraph::<i32, ()>::from_edges(&[
174///     (0, 4), (1, 2), (1, 3), (2, 3), (3, 4), (1, 4)
175/// ]);
176///
177/// let output = edge_betweenness_centrality(&g, false, 200);
178/// let expected = vec![Some(4.0), Some(2.0), Some(1.0), Some(2.0), Some(3.0), Some(3.0)];
179/// assert_eq!(output, expected);
180/// ```
181/// # See Also
182/// [`betweenness_centrality`]
183pub fn edge_betweenness_centrality<G>(
184    graph: G,
185    normalized: bool,
186    parallel_threshold: usize,
187) -> Vec<Option<f64>>
188where
189    G: NodeIndexable
190        + EdgeIndexable
191        + IntoEdges
192        + IntoNodeIdentifiers
193        + IntoNeighborsDirected
194        + NodeCount
195        + EdgeCount
196        + GraphProp
197        + Sync,
198    G::NodeId: Eq + Hash + Send,
199    G::EdgeId: Eq + Hash + Send,
200{
201    let max_index = graph.node_bound();
202    let mut betweenness = vec![None; graph.edge_bound()];
203    for edge in graph.edge_references() {
204        let is: usize = EdgeIndexable::to_index(&graph, edge.id());
205        betweenness[is] = Some(0.0);
206    }
207    let locked_betweenness = RwLock::new(&mut betweenness);
208    let node_indices: Vec<G::NodeId> = graph.node_identifiers().collect();
209    CondIterator::new(node_indices, graph.node_count() >= parallel_threshold)
210        .map(|node_s| shortest_path_for_edge_centrality(&graph, &node_s))
211        .for_each(|mut shortest_path_calc| {
212            accumulate_edges(
213                &locked_betweenness,
214                max_index,
215                &mut shortest_path_calc,
216                &graph,
217            );
218        });
219
220    _rescale(
221        &mut betweenness,
222        graph.node_count(),
223        normalized,
224        graph.is_directed(),
225        true,
226    );
227    betweenness
228}
229
/// Rescale raw betweenness scores in place.
///
/// The factor applied to every `Some` entry is:
///
/// * `normalized` with `include_endpoints`: `1 / (n * (n - 1))`, skipped
///   when `n < 2`
/// * `normalized` without `include_endpoints`: `1 / ((n - 1) * (n - 2))`,
///   skipped when `n <= 2`
/// * not normalized, undirected graph: `0.5`
/// * not normalized, directed graph: no rescaling
///
/// where `n` is `node_count`. `None` entries are left untouched.
fn _rescale(
    betweenness: &mut [Option<f64>],
    node_count: usize,
    normalized: bool,
    directed: bool,
    include_endpoints: bool,
) {
    // The products are formed in f64: multiplying in `usize` first can
    // overflow for large node counts (notably on 32-bit targets).
    let scale: Option<f64> = if normalized {
        if include_endpoints {
            (node_count >= 2).then(|| 1.0 / (node_count as f64 * (node_count - 1) as f64))
        } else {
            (node_count > 2).then(|| 1.0 / ((node_count - 1) as f64 * (node_count - 2) as f64))
        }
    } else if !directed {
        Some(0.5)
    } else {
        None
    };

    if let Some(scale) = scale {
        for score in betweenness.iter_mut().flatten() {
            *score *= scale;
        }
    }
}
262
263fn _accumulate_vertices<G>(
264    locked_betweenness: &RwLock<&mut Vec<Option<f64>>>,
265    max_index: usize,
266    path_calc: &mut ShortestPathData<G>,
267    node_s: <G as GraphBase>::NodeId,
268    graph: G,
269    include_endpoints: bool,
270) where
271    G: NodeIndexable
272        + IntoNodeIdentifiers
273        + IntoNeighborsDirected
274        + NodeCount
275        + GraphProp
276        + GraphBase
277        + std::marker::Sync,
278    <G as GraphBase>::NodeId: std::cmp::Eq + Hash,
279{
280    let mut delta = vec![0.0; max_index];
281    for w in &path_calc.verts_sorted_by_distance {
282        let iw = graph.to_index(*w);
283        let coeff = (1.0 + delta[iw]) / path_calc.sigma[w];
284        let p_w = path_calc.predecessors.get(w).unwrap();
285        for v in p_w {
286            let iv = graph.to_index(*v);
287            delta[iv] += path_calc.sigma[v] * coeff;
288        }
289    }
290    let mut betweenness = locked_betweenness.write().unwrap();
291    if include_endpoints {
292        let i_node_s = graph.to_index(node_s);
293        betweenness[i_node_s] = betweenness[i_node_s]
294            .map(|x| x + ((path_calc.verts_sorted_by_distance.len() - 1) as f64));
295        for w in &path_calc.verts_sorted_by_distance {
296            if *w != node_s {
297                let iw = graph.to_index(*w);
298                betweenness[iw] = betweenness[iw].map(|x| x + delta[iw] + 1.0);
299            }
300        }
301    } else {
302        for w in &path_calc.verts_sorted_by_distance {
303            if *w != node_s {
304                let iw = graph.to_index(*w);
305                betweenness[iw] = betweenness[iw].map(|x| x + delta[iw]);
306            }
307        }
308    }
309}
310
311fn accumulate_edges<G>(
312    locked_betweenness: &RwLock<&mut Vec<Option<f64>>>,
313    max_index: usize,
314    path_calc: &mut ShortestPathDataWithEdges<G>,
315    graph: G,
316) where
317    G: NodeIndexable + EdgeIndexable + Sync,
318    G::NodeId: Eq + Hash,
319    G::EdgeId: Eq + Hash,
320{
321    let mut delta = vec![0.0; max_index];
322    for w in &path_calc.verts_sorted_by_distance {
323        let iw = NodeIndexable::to_index(&graph, *w);
324        let coeff = (1.0 + delta[iw]) / path_calc.sigma[w];
325        let p_w = path_calc.predecessors.get(w).unwrap();
326        let e_w = path_calc.predecessor_edges.get(w).unwrap();
327        let mut betweenness = locked_betweenness.write().unwrap();
328        for i in 0..p_w.len() {
329            let v = p_w[i];
330            let iv = NodeIndexable::to_index(&graph, v);
331            let ie = EdgeIndexable::to_index(&graph, e_w[i]);
332            let c = path_calc.sigma[&v] * coeff;
333            betweenness[ie] = betweenness[ie].map(|x| x + c);
334            delta[iv] += c;
335        }
336    }
337}
338/// Compute the degree centrality of all nodes in a graph.
339///
340/// For undirected graphs, this calculates the normalized degree for each node.
341/// For directed graphs, this calculates the normalized out-degree for each node.
342///
343/// Arguments:
344///
345/// * `graph` - The graph object to calculate degree centrality for
346///
347/// # Example
348/// ```rust
349/// use rustworkx_core::petgraph::graph::{UnGraph, DiGraph};
350/// use rustworkx_core::centrality::degree_centrality;
351///
352/// // Undirected graph example
353/// let graph = UnGraph::<i32, ()>::from_edges(&[
354///     (0, 1), (1, 2), (2, 3), (3, 0)
355/// ]);
356/// let centrality = degree_centrality(&graph, None);
357///
358/// // Directed graph example
359/// let digraph = DiGraph::<i32, ()>::from_edges(&[
360///     (0, 1), (1, 2), (2, 3), (3, 0), (0, 2), (1, 3)
361/// ]);
362/// let centrality = degree_centrality(&digraph, None);
363/// ```
364pub fn degree_centrality<G>(graph: G, direction: Option<petgraph::Direction>) -> Vec<f64>
365where
366    G: NodeIndexable
367        + IntoNodeIdentifiers
368        + IntoNeighbors
369        + IntoNeighborsDirected
370        + NodeCount
371        + GraphProp,
372    G::NodeId: Eq + Hash,
373{
374    let node_count = graph.node_count() as f64;
375    let mut centrality = vec![0.0; graph.node_bound()];
376
377    for node in graph.node_identifiers() {
378        let (degree, normalization) = match (graph.is_directed(), direction) {
379            (true, None) => {
380                let out_degree = graph
381                    .neighbors_directed(node, petgraph::Direction::Outgoing)
382                    .count() as f64;
383                let in_degree = graph
384                    .neighbors_directed(node, petgraph::Direction::Incoming)
385                    .count() as f64;
386                let total = in_degree + out_degree;
387                // Use 2(n-1) normalization only if this is a complete graph
388                let norm = if total == 2.0 * (node_count - 1.0) {
389                    2.0 * (node_count - 1.0)
390                } else {
391                    node_count - 1.0
392                };
393                (total, norm)
394            }
395            (true, Some(dir)) => (
396                graph.neighbors_directed(node, dir).count() as f64,
397                node_count - 1.0,
398            ),
399            (false, _) => (graph.neighbors(node).count() as f64, node_count - 1.0),
400        };
401        centrality[graph.to_index(node)] = degree / normalization;
402    }
403
404    centrality
405}
406
407struct ShortestPathData<G>
408where
409    G: GraphBase,
410    <G as GraphBase>::NodeId: std::cmp::Eq + Hash,
411{
412    verts_sorted_by_distance: Vec<G::NodeId>,
413    predecessors: HashMap<G::NodeId, Vec<G::NodeId>>,
414    sigma: HashMap<G::NodeId, f64>,
415}
416
417fn shortest_path_for_centrality<G>(graph: G, node_s: &G::NodeId) -> ShortestPathData<G>
418where
419    G: NodeIndexable + IntoNodeIdentifiers + IntoNeighborsDirected + NodeCount + GraphBase,
420    <G as GraphBase>::NodeId: std::cmp::Eq + Hash,
421{
422    let mut verts_sorted_by_distance: Vec<G::NodeId> = Vec::new(); // a stack
423    let c = graph.node_count();
424    let mut predecessors = HashMap::<G::NodeId, Vec<G::NodeId>>::with_capacity(c);
425    let mut sigma = HashMap::<G::NodeId, f64>::with_capacity(c);
426    let mut distance = HashMap::<G::NodeId, i64>::with_capacity(c);
427    #[allow(non_snake_case)]
428    let mut Q: VecDeque<G::NodeId> = VecDeque::with_capacity(c);
429
430    for node in graph.node_identifiers() {
431        predecessors.insert(node, Vec::new());
432        sigma.insert(node, 0.0);
433        distance.insert(node, -1);
434    }
435    sigma.insert(*node_s, 1.0);
436    distance.insert(*node_s, 0);
437    Q.push_back(*node_s);
438    while let Some(v) = Q.pop_front() {
439        verts_sorted_by_distance.push(v);
440        let distance_v = distance[&v];
441        for w in graph.neighbors(v) {
442            if distance[&w] < 0 {
443                Q.push_back(w);
444                distance.insert(w, distance_v + 1);
445            }
446            if distance[&w] == distance_v + 1 {
447                sigma.insert(w, sigma[&w] + sigma[&v]);
448                let e_p = predecessors.get_mut(&w).unwrap();
449                e_p.push(v);
450            }
451        }
452    }
453    verts_sorted_by_distance.reverse(); // will be effectively popping from the stack
454    ShortestPathData {
455        verts_sorted_by_distance,
456        predecessors,
457        sigma,
458    }
459}
460
461struct ShortestPathDataWithEdges<G>
462where
463    G: GraphBase,
464    G::NodeId: Eq + Hash,
465    G::EdgeId: Eq + Hash,
466{
467    verts_sorted_by_distance: Vec<G::NodeId>,
468    predecessors: HashMap<G::NodeId, Vec<G::NodeId>>,
469    predecessor_edges: HashMap<G::NodeId, Vec<G::EdgeId>>,
470    sigma: HashMap<G::NodeId, f64>,
471}
472
473fn shortest_path_for_edge_centrality<G>(
474    graph: G,
475    node_s: &G::NodeId,
476) -> ShortestPathDataWithEdges<G>
477where
478    G: NodeIndexable
479        + IntoNodeIdentifiers
480        + IntoNeighborsDirected
481        + NodeCount
482        + GraphBase
483        + IntoEdges,
484    G::NodeId: Eq + Hash,
485    G::EdgeId: Eq + Hash,
486{
487    let mut verts_sorted_by_distance: Vec<G::NodeId> = Vec::new(); // a stack
488    let c = graph.node_count();
489    let mut predecessors = HashMap::<G::NodeId, Vec<G::NodeId>>::with_capacity(c);
490    let mut predecessor_edges = HashMap::<G::NodeId, Vec<G::EdgeId>>::with_capacity(c);
491    let mut sigma = HashMap::<G::NodeId, f64>::with_capacity(c);
492    let mut distance = HashMap::<G::NodeId, i64>::with_capacity(c);
493    #[allow(non_snake_case)]
494    let mut Q: VecDeque<G::NodeId> = VecDeque::with_capacity(c);
495
496    for node in graph.node_identifiers() {
497        predecessors.insert(node, Vec::new());
498        predecessor_edges.insert(node, Vec::new());
499        sigma.insert(node, 0.0);
500        distance.insert(node, -1);
501    }
502    sigma.insert(*node_s, 1.0);
503    distance.insert(*node_s, 0);
504    Q.push_back(*node_s);
505    while let Some(v) = Q.pop_front() {
506        verts_sorted_by_distance.push(v);
507        let distance_v = distance[&v];
508        for edge in graph.edges(v) {
509            let w = edge.target();
510            if distance[&w] < 0 {
511                Q.push_back(w);
512                distance.insert(w, distance_v + 1);
513            }
514            if distance[&w] == distance_v + 1 {
515                sigma.insert(w, sigma[&w] + sigma[&v]);
516                let e_p = predecessors.get_mut(&w).unwrap();
517                e_p.push(v);
518                predecessor_edges.get_mut(&w).unwrap().push(edge.id());
519            }
520        }
521    }
522    verts_sorted_by_distance.reverse(); // will be effectively popping from the stack
523    ShortestPathDataWithEdges {
524        verts_sorted_by_distance,
525        predecessors,
526        predecessor_edges,
527        sigma,
528    }
529}
530
#[cfg(test)]
mod test_edge_betweenness_centrality {
    use crate::centrality::edge_betweenness_centrality;
    use petgraph::graph::edge_index;
    use petgraph::prelude::StableGraph;
    use petgraph::Undirected;

    /// Assert that `output` holds exactly one score per expected value and
    /// that every score is within `delta` of its expected value.
    ///
    /// Checking the length first makes a missing or extra score fail the test
    /// instead of being silently ignored.
    fn assert_scores_close(output: &[Option<f64>], expected: &[f64], delta: f64) {
        assert_eq!(
            output.len(),
            expected.len(),
            "number of edge scores does not match the number of expected values"
        );
        for (score, want) in output.iter().zip(expected) {
            let got = score.unwrap();
            if (got - want).abs() >= delta {
                panic!("{} != {} within delta of {}", got, want, delta);
            }
        }
    }

    #[test]
    fn test_undirected_graph_normalized() {
        let graph = petgraph::graph::UnGraph::<(), ()>::from_edges([
            (0, 6),
            (0, 4),
            (0, 1),
            (0, 5),
            (1, 6),
            (1, 7),
            (1, 3),
            (1, 4),
            (2, 6),
            (2, 3),
            (3, 5),
            (3, 7),
            (3, 6),
            (4, 5),
            (5, 6),
        ]);
        let output = edge_betweenness_centrality(&graph, true, 50);
        let expected_values = [
            0.1023809, 0.0547619, 0.0922619, 0.05654762, 0.09940476, 0.125, 0.09940476, 0.12440476,
            0.12857143, 0.12142857, 0.13511905, 0.125, 0.06547619, 0.08869048, 0.08154762,
        ];
        assert_scores_close(&output, &expected_values, 1e-4);
    }

    #[test]
    fn test_undirected_graph_unnormalized() {
        let graph = petgraph::graph::UnGraph::<(), ()>::from_edges([
            (0, 2),
            (0, 4),
            (0, 1),
            (1, 3),
            (1, 5),
            (1, 7),
            (2, 7),
            (2, 3),
            (3, 5),
            (3, 6),
            (4, 6),
            (5, 7),
        ]);
        let output = edge_betweenness_centrality(&graph, false, 50);
        let expected_values = [
            3.83333, 5.5, 5.33333, 3.5, 2.5, 3.0, 3.5, 4.0, 3.66667, 6.5, 3.5, 2.16667,
        ];
        assert_scores_close(&output, &expected_values, 1e-4);
    }

    #[test]
    fn test_directed_graph_normalized() {
        let graph = petgraph::graph::DiGraph::<(), ()>::from_edges([
            (0, 1),
            (1, 0),
            (1, 3),
            (1, 2),
            (1, 4),
            (2, 3),
            (2, 4),
            (2, 1),
            (3, 2),
            (4, 3),
        ]);
        let output = edge_betweenness_centrality(&graph, true, 50);
        let expected_values = [0.2, 0.2, 0.1, 0.1, 0.1, 0.05, 0.1, 0.3, 0.35, 0.2];
        assert_scores_close(&output, &expected_values, 1e-4);
    }

    #[test]
    fn test_directed_graph_unnormalized() {
        let graph = petgraph::graph::DiGraph::<(), ()>::from_edges([
            (0, 4),
            (1, 0),
            (1, 3),
            (2, 3),
            (2, 4),
            (2, 0),
            (3, 4),
            (3, 2),
            (3, 1),
            (4, 1),
        ]);
        let output = edge_betweenness_centrality(&graph, false, 50);
        let expected_values = [4.5, 3.0, 6.5, 1.5, 1.5, 1.5, 1.5, 4.5, 2.0, 7.5];
        assert_scores_close(&output, &expected_values, 1e-4);
    }

    #[test]
    fn test_stable_graph_with_removed_edges() {
        // Removing an edge from a stable graph leaves a hole at its index,
        // which must be reported as `None`.
        let mut graph: StableGraph<(), (), Undirected> =
            StableGraph::from_edges([(0, 1), (1, 2), (2, 3), (3, 0)]);
        graph.remove_edge(edge_index(1));
        let result = edge_betweenness_centrality(&graph, false, 50);
        let expected_values = vec![Some(3.0), None, Some(3.0), Some(4.0)];
        assert_eq!(result, expected_values);
    }
}
656
657/// Compute the eigenvector centrality of a graph
658///
659/// For details on the eigenvector centrality refer to:
660///
661/// Phillip Bonacich. “Power and Centrality: A Family of Measures.”
662/// American Journal of Sociology 92(5):1170–1182, 1986
663/// <https://doi.org/10.1086/228631>
664///
665/// This function uses a power iteration method to compute the eigenvector
666/// and convergence is not guaranteed. The function will stop when `max_iter`
667/// iterations is reached or when the computed vector between two iterations
668/// is smaller than the error tolerance multiplied by the number of nodes.
669/// The implementation of this algorithm is based on the NetworkX
670/// [`eigenvector_centrality()`](https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.centrality.eigenvector_centrality.html)
671/// function.
672///
673/// In the case of multigraphs the weights of any parallel edges will be
674/// summed when computing the eigenvector centrality.
675///
676/// Arguments:
677///
678/// * `graph` - The graph object to run the algorithm on
679/// * `weight_fn` - An input callable that will be passed the `EdgeRef` for
680///     an edge in the graph and is expected to return a `Result<f64>` of
681///     the weight of that edge.
682/// * `max_iter` - The maximum number of iterations in the power method. If
683///     set to `None` a default value of 100 is used.
684/// * `tol` - The error tolerance used when checking for convergence in the
685///     power method. If set to `None` a default value of 1e-6 is used.
686///
687/// # Example
688/// ```rust
689/// use rustworkx_core::Result;
690/// use rustworkx_core::petgraph;
691/// use rustworkx_core::petgraph::visit::{IntoEdges, IntoNodeIdentifiers};
692/// use rustworkx_core::centrality::eigenvector_centrality;
693///
694/// let g = petgraph::graph::UnGraph::<i32, ()>::from_edges(&[
695///     (0, 1), (1, 2)
696/// ]);
697/// // Calculate the eigenvector centrality
698/// let output: Result<Option<Vec<f64>>> = eigenvector_centrality(&g, |_| {Ok(1.)}, None, None);
699/// ```
700pub fn eigenvector_centrality<G, F, E>(
701    graph: G,
702    mut weight_fn: F,
703    max_iter: Option<usize>,
704    tol: Option<f64>,
705) -> Result<Option<Vec<f64>>, E>
706where
707    G: NodeIndexable + IntoNodeIdentifiers + IntoNeighbors + IntoEdges + NodeCount,
708    G::NodeId: Eq + std::hash::Hash,
709    F: FnMut(G::EdgeRef) -> Result<f64, E>,
710{
711    let tol: f64 = tol.unwrap_or(1e-6);
712    let max_iter = max_iter.unwrap_or(100);
713    let mut x: Vec<f64> = vec![1.; graph.node_bound()];
714    let node_count = graph.node_count();
715    for _ in 0..max_iter {
716        let x_last = x.clone();
717        for node_index in graph.node_identifiers() {
718            let node = graph.to_index(node_index);
719            for edge in graph.edges(node_index) {
720                let w = weight_fn(edge)?;
721                let neighbor = edge.target();
722                x[graph.to_index(neighbor)] += x_last[node] * w;
723            }
724        }
725        let norm: f64 = x.iter().map(|val| val.powi(2)).sum::<f64>().sqrt();
726        if norm == 0. {
727            return Ok(None);
728        }
729        for v in x.iter_mut() {
730            *v /= norm;
731        }
732        if (0..x.len())
733            .map(|node| (x[node] - x_last[node]).abs())
734            .sum::<f64>()
735            < node_count as f64 * tol
736        {
737            return Ok(Some(x));
738        }
739    }
740    Ok(None)
741}
742
743/// Compute the Katz centrality of a graph
744///
745/// For details on the Katz centrality refer to:
746///
747/// Leo Katz. “A New Status Index Derived from Sociometric Index.”
748/// Psychometrika 18(1):39–43, 1953
749/// <https://link.springer.com/content/pdf/10.1007/BF02289026.pdf>
750///
751/// This function uses a power iteration method to compute the eigenvector
752/// and convergence is not guaranteed. The function will stop when `max_iter`
753/// iterations is reached or when the computed vector between two iterations
754/// is smaller than the error tolerance multiplied by the number of nodes.
755/// The implementation of this algorithm is based on the NetworkX
756/// [`katz_centrality()`](https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.centrality.katz_centrality.html)
757/// function.
758///
759/// In the case of multigraphs the weights of any parallel edges will be
760/// summed when computing the eigenvector centrality.
761///
762/// Arguments:
763///
764/// * `graph` - The graph object to run the algorithm on
765/// * `weight_fn` - An input callable that will be passed the `EdgeRef` for
766///     an edge in the graph and is expected to return a `Result<f64>` of
767///     the weight of that edge.
768/// * `alpha` - Attenuation factor. If set to `None`, a default value of 0.1 is used.
769/// * `beta_map` - Immediate neighbourhood weights. Must contain all node indices or be `None`.
770/// * `beta_scalar` - Immediate neighbourhood scalar that replaces `beta_map` in case `beta_map` is None.
771///     Defaults to 1.0 in case `None` is provided.
772/// * `max_iter` - The maximum number of iterations in the power method. If
773///     set to `None` a default value of 100 is used.
774/// * `tol` - The error tolerance used when checking for convergence in the
775///     power method. If set to `None` a default value of 1e-6 is used.
776///
777/// # Example
778/// ```rust
779/// use rustworkx_core::Result;
780/// use rustworkx_core::petgraph;
781/// use rustworkx_core::petgraph::visit::{IntoEdges, IntoNodeIdentifiers};
782/// use rustworkx_core::centrality::katz_centrality;
783///
784/// let g = petgraph::graph::UnGraph::<i32, ()>::from_edges(&[
785///     (0, 1), (1, 2)
786/// ]);
787/// // Calculate the eigenvector centrality
788/// let output: Result<Option<Vec<f64>>> = katz_centrality(&g, |_| {Ok(1.)}, None, None, None, None, None);
789/// let centralities = output.unwrap().unwrap();
790/// assert!(centralities[1] > centralities[0], "Node 1 is more central than node 0");
791/// assert!(centralities[1] > centralities[2], "Node 1 is more central than node 2");
792/// ```
793pub fn katz_centrality<G, F, E>(
794    graph: G,
795    mut weight_fn: F,
796    alpha: Option<f64>,
797    beta_map: Option<HashMap<usize, f64>>,
798    beta_scalar: Option<f64>,
799    max_iter: Option<usize>,
800    tol: Option<f64>,
801) -> Result<Option<Vec<f64>>, E>
802where
803    G: NodeIndexable + IntoNodeIdentifiers + IntoNeighbors + IntoEdges + NodeCount,
804    G::NodeId: Eq + std::hash::Hash,
805    F: FnMut(G::EdgeRef) -> Result<f64, E>,
806{
807    let alpha: f64 = alpha.unwrap_or(0.1);
808
809    let beta: HashMap<usize, f64> = beta_map.unwrap_or_default();
810    //Initialize the beta vector in case a beta map was not provided
811    let mut beta_v = vec![beta_scalar.unwrap_or(1.0); graph.node_bound()];
812
813    if !beta.is_empty() {
814        // Check if beta contains all node indices
815        for node_index in graph.node_identifiers() {
816            let node = graph.to_index(node_index);
817            if !beta.contains_key(&node) {
818                return Ok(None); // beta_map was provided but did not include all nodes
819            }
820            beta_v[node] = *beta.get(&node).unwrap(); //Initialize the beta vector with the provided values
821        }
822    }
823
824    let tol: f64 = tol.unwrap_or(1e-6);
825    let max_iter = max_iter.unwrap_or(1000);
826
827    let mut x: Vec<f64> = vec![0.; graph.node_bound()];
828    let node_count = graph.node_count();
829    for _ in 0..max_iter {
830        let x_last = x.clone();
831        x = vec![0.; graph.node_bound()];
832        for node_index in graph.node_identifiers() {
833            let node = graph.to_index(node_index);
834            for edge in graph.edges(node_index) {
835                let w = weight_fn(edge)?;
836                let neighbor = edge.target();
837                x[graph.to_index(neighbor)] += x_last[node] * w;
838            }
839        }
840        for node_index in graph.node_identifiers() {
841            let node = graph.to_index(node_index);
842            x[node] = alpha * x[node] + beta_v[node];
843        }
844        if (0..x.len())
845            .map(|node| (x[node] - x_last[node]).abs())
846            .sum::<f64>()
847            < node_count as f64 * tol
848        {
849            // Normalize vector
850            let norm: f64 = x.iter().map(|val| val.powi(2)).sum::<f64>().sqrt();
851            if norm == 0. {
852                return Ok(None);
853            }
854            for v in x.iter_mut() {
855                *v /= norm;
856            }
857
858            return Ok(Some(x));
859        }
860    }
861
862    Ok(None)
863}
864
#[cfg(test)]
mod test_eigenvector_centrality {

    use crate::centrality::eigenvector_centrality;
    use crate::petgraph;
    use crate::Result;

    /// Panic unless every value in `expected` is strictly within `delta` of
    /// the value stored at the same index of `actual`.
    ///
    /// `actual` is indexed directly, so a result vector shorter than
    /// `expected` also fails the test.
    fn assert_all_almost_equal(expected: &[f64], actual: &[f64], delta: f64) {
        for (idx, want) in expected.iter().copied().enumerate() {
            let got = actual[idx];
            if (want - got).abs() >= delta {
                panic!("{} != {} within delta of {}", want, got, delta);
            }
        }
    }

    /// The call is made with `Some(0)` as the third argument and must report
    /// `None` instead of a centrality vector.
    #[test]
    fn test_no_convergence() {
        let graph = petgraph::graph::UnGraph::<i32, ()>::from_edges([(0, 1), (1, 2)]);
        let outcome: Result<Option<Vec<f64>>> =
            eigenvector_centrality(&graph, |_| Ok(1.), Some(0), None);
        assert_eq!(None, outcome.unwrap());
    }

    /// All five nodes of the complete graph K5 are expected to share the
    /// same score, `sqrt(1 / 5)`.
    #[test]
    fn test_undirected_complete_graph() {
        let edges = [
            (0, 1),
            (0, 2),
            (0, 3),
            (0, 4),
            (1, 2),
            (1, 3),
            (1, 4),
            (2, 3),
            (2, 4),
            (3, 4),
        ];
        let graph = petgraph::graph::UnGraph::<i32, ()>::from_edges(edges);
        let outcome: Result<Option<Vec<f64>>> =
            eigenvector_centrality(&graph, |_| Ok(1.), None, None);
        let scores = outcome.unwrap().unwrap();
        let uniform_score: f64 = (1_f64 / 5_f64).sqrt();
        let expected: Vec<f64> = vec![uniform_score; 5];
        assert_all_almost_equal(&expected, &scores, 1e-4);
    }

    /// On the three-node path the middle node is expected to score
    /// `1 / sqrt(2)` and both end nodes `0.5`.
    #[test]
    fn test_undirected_path_graph() {
        let graph = petgraph::graph::UnGraph::<i32, ()>::from_edges([(0, 1), (1, 2)]);
        let outcome: Result<Option<Vec<f64>>> =
            eigenvector_centrality(&graph, |_| Ok(1.), None, None);
        let scores = outcome.unwrap().unwrap();
        let expected: Vec<f64> = vec![0.5, std::f64::consts::FRAC_1_SQRT_2, 0.5];
        assert_all_almost_equal(&expected, &scores, 1e-4);
    }

    /// Directed eight-node graph in which every edge is given weight `2`,
    /// compared against precomputed reference scores.
    #[test]
    fn test_directed_graph() {
        let edges = [
            (0, 1),
            (0, 2),
            (1, 3),
            (2, 1),
            (2, 4),
            (3, 1),
            (3, 4),
            (3, 5),
            (4, 5),
            (4, 6),
            (4, 7),
            (5, 7),
            (6, 0),
            (6, 4),
            (6, 7),
            (7, 5),
            (7, 6),
        ];
        let graph = petgraph::graph::DiGraph::<i32, ()>::from_edges(edges);
        let outcome: Result<Option<Vec<f64>>> =
            eigenvector_centrality(&graph, |_| Ok(2.), None, None);
        let scores = outcome.unwrap().unwrap();
        let expected: Vec<f64> = vec![
            0.2140437, 0.2009269, 0.1036383, 0.0972886, 0.3113323, 0.4891686, 0.4420605, 0.6016448,
        ];
        assert_all_almost_equal(&expected, &scores, 1e-4);
    }
}
953
#[cfg(test)]
mod test_katz_centrality {

    use crate::centrality::katz_centrality;
    use crate::petgraph;
    use crate::Result;
    use hashbrown::HashMap;

    /// Panic unless every value in `expected` is strictly within `delta` of
    /// the value stored at the same index of `actual`.
    ///
    /// `actual` is indexed directly, so a result vector shorter than
    /// `expected` also fails the test.
    fn assert_all_almost_equal(expected: &[f64], actual: &[f64], delta: f64) {
        for (idx, want) in expected.iter().copied().enumerate() {
            let got = actual[idx];
            if (want - got).abs() >= delta {
                panic!("{} != {} within delta of {}", want, got, delta);
            }
        }
    }

    /// The call is made with `Some(0)` in the second-to-last position and
    /// must report `None` instead of a centrality vector.
    #[test]
    fn test_no_convergence() {
        let graph = petgraph::graph::UnGraph::<i32, ()>::from_edges([(0, 1), (1, 2)]);
        let outcome: Result<Option<Vec<f64>>> =
            katz_centrality(&graph, |_| Ok(1.), None, None, None, Some(0), None);
        assert_eq!(None, outcome.unwrap());
    }

    /// A beta map that only names node 0 of a three-node graph must make the
    /// function return `None`.
    #[test]
    fn test_incomplete_beta() {
        let graph = petgraph::graph::UnGraph::<i32, ()>::from_edges([(0, 1), (1, 2)]);
        let partial_beta: HashMap<usize, f64> = [(0, 1.0)].iter().copied().collect();
        let outcome: Result<Option<Vec<f64>>> =
            katz_centrality(&graph, |_| Ok(1.), None, Some(partial_beta), None, None, None);
        assert_eq!(None, outcome.unwrap());
    }

    /// A beta map that names every node of the three-node path is accepted;
    /// the scores are compared against precomputed reference values.
    #[test]
    fn test_complete_beta() {
        let graph = petgraph::graph::UnGraph::<i32, ()>::from_edges([(0, 1), (1, 2)]);
        let full_beta: HashMap<usize, f64> =
            [(0, 0.5), (1, 1.0), (2, 0.5)].iter().copied().collect();
        let outcome: Result<Option<Vec<f64>>> =
            katz_centrality(&graph, |_| Ok(1.), None, Some(full_beta), None, None, None);
        let scores = outcome.unwrap().unwrap();
        let expected: Vec<f64> = vec![0.4318894504492167, 0.791797325823564, 0.4318894504492167];
        assert_all_almost_equal(&expected, &scores, 1e-4);
    }

    /// All five nodes of the complete graph K5 are expected to share the
    /// same score, `sqrt(1 / 5)`.
    #[test]
    fn test_undirected_complete_graph() {
        let edges = [
            (0, 1),
            (0, 2),
            (0, 3),
            (0, 4),
            (1, 2),
            (1, 3),
            (1, 4),
            (2, 3),
            (2, 4),
            (3, 4),
        ];
        let graph = petgraph::graph::UnGraph::<i32, ()>::from_edges(edges);
        let outcome: Result<Option<Vec<f64>>> =
            katz_centrality(&graph, |_| Ok(1.), Some(0.2), None, Some(1.1), None, None);
        let scores = outcome.unwrap().unwrap();
        let uniform_score: f64 = (1_f64 / 5_f64).sqrt();
        let expected: Vec<f64> = vec![uniform_score; 5];
        assert_all_almost_equal(&expected, &scores, 1e-4);
    }

    /// Directed eight-node graph run with every optional argument left as
    /// `None`, compared against precomputed reference scores.
    #[test]
    fn test_directed_graph() {
        let edges = [
            (0, 1),
            (0, 2),
            (1, 3),
            (2, 1),
            (2, 4),
            (3, 1),
            (3, 4),
            (3, 5),
            (4, 5),
            (4, 6),
            (4, 7),
            (5, 7),
            (6, 0),
            (6, 4),
            (6, 7),
            (7, 5),
            (7, 6),
        ];
        let graph = petgraph::graph::DiGraph::<i32, ()>::from_edges(edges);
        let outcome: Result<Option<Vec<f64>>> =
            katz_centrality(&graph, |_| Ok(1.), None, None, None, None, None);
        let scores = outcome.unwrap().unwrap();
        let expected: Vec<f64> = vec![
            0.3135463087489011,
            0.3719056758615039,
            0.3094350787809586,
            0.31527101632646026,
            0.3760169058294464,
            0.38618584417917906,
            0.35465874858087904,
            0.38976653416801743,
        ];

        assert_all_almost_equal(&expected, &scores, 1e-4);
    }
}
1067
1068/// Compute the closeness centrality of each node in the graph.
1069///
1070/// The closeness centrality of a node `u` is the reciprocal of the average
1071/// shortest path distance to `u` over all `n-1` reachable nodes.
1072///
1073/// In the case of a graphs with more than one connected component there is
1074/// an alternative improved formula that calculates the closeness centrality
1075/// as "a ratio of the fraction of actors in the group who are reachable, to
1076/// the average distance".[^WF]
1077/// You can enable this by setting `wf_improved` to `true`.
1078///
1079/// [^WF]: Wasserman, S., & Faust, K. (1994). Social Network Analysis:
1080///     Methods and Applications (Structural Analysis in the Social Sciences).
1081///     Cambridge: Cambridge University Press.
1082///     <https://doi.org/10.1017/CBO9780511815478>
1083///
1084/// # Arguments
1085///
1086/// * `graph` - The graph object to run the algorithm on
1087/// * `wf_improved` - If `true`, scale by the fraction of nodes reachable.
1088///
1089/// # Example
1090///
1091/// ```rust
1092/// use rustworkx_core::petgraph;
1093/// use rustworkx_core::centrality::closeness_centrality;
1094///
1095/// // Calculate the closeness centrality of Graph
1096/// let g = petgraph::graph::UnGraph::<i32, ()>::from_edges(&[
1097///     (0, 4), (1, 2), (2, 3), (3, 4), (1, 4)
1098/// ]);
1099/// let output = closeness_centrality(&g, true);
1100/// assert_eq!(
1101///     vec![Some(1./2.), Some(2./3.), Some(4./7.), Some(2./3.), Some(4./5.)],
1102///     output
1103/// );
1104///
1105/// // Calculate the closeness centrality of DiGraph
1106/// let dg = petgraph::graph::DiGraph::<i32, ()>::from_edges(&[
1107///     (0, 4), (1, 2), (2, 3), (3, 4), (1, 4)
1108/// ]);
1109/// let output = closeness_centrality(&dg, true);
1110/// assert_eq!(
1111///     vec![Some(0.), Some(0.), Some(1./4.), Some(1./3.), Some(4./5.)],
1112///     output
1113/// );
1114/// ```
1115pub fn closeness_centrality<G>(graph: G, wf_improved: bool) -> Vec<Option<f64>>
1116where
1117    G: NodeIndexable
1118        + IntoNodeIdentifiers
1119        + GraphBase
1120        + IntoEdges
1121        + Visitable
1122        + NodeCount
1123        + IntoEdgesDirected,
1124    G::NodeId: std::hash::Hash + Eq,
1125{
1126    let max_index = graph.node_bound();
1127    let mut closeness: Vec<Option<f64>> = vec![None; max_index];
1128    for node_s in graph.node_identifiers() {
1129        let is = graph.to_index(node_s);
1130        let map = dijkstra(Reversed(&graph), node_s, None, |_| 1);
1131        let reachable_nodes_count = map.len();
1132        let dists_sum: usize = map.into_values().sum();
1133        if reachable_nodes_count == 1 {
1134            closeness[is] = Some(0.0);
1135            continue;
1136        }
1137        closeness[is] = Some((reachable_nodes_count - 1) as f64 / dists_sum as f64);
1138        if wf_improved {
1139            let node_count = graph.node_count();
1140            closeness[is] = closeness[is]
1141                .map(|c| c * (reachable_nodes_count - 1) as f64 / (node_count - 1) as f64);
1142        }
1143    }
1144    closeness
1145}