Skip to main content

rustworkx_core/
centrality.rs

1// Licensed under the Apache License, Version 2.0 (the "License"); you may
2// not use this file except in compliance with the License. You may obtain
3// a copy of the License at
4//
5//     http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
9// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
10// License for the specific language governing permissions and limitations
11// under the License.
12
13use std::collections::VecDeque;
14use std::hash::Hash;
15use std::sync::RwLock;
16
17use hashbrown::HashSet;
18
19use hashbrown::HashMap;
20use petgraph::algo::dijkstra;
21use petgraph::visit::{
22    Bfs, EdgeCount, EdgeIndexable, EdgeRef, GraphBase, GraphProp, IntoEdgeReferences, IntoEdges,
23    IntoEdgesDirected, IntoNeighbors, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCount,
24    NodeIndexable, Reversed, ReversedEdgeReference, Visitable,
25};
26use rayon_cond::CondIterator;
27
28/// Compute the betweenness centrality of all nodes in a graph.
29///
30/// The algorithm used in this function is based on:
31///
32/// Ulrik Brandes, A Faster Algorithm for Betweenness Centrality.
33/// Journal of Mathematical Sociology 25(2):163-177, 2001.
34///
35/// This function is multithreaded and will run in parallel if the number
36/// of nodes in the graph is above the value of ``parallel_threshold``. If the
37/// function will be running in parallel the env var ``RAYON_NUM_THREADS`` can
38/// be used to adjust how many threads will be used.
39///
40/// Arguments:
41///
42/// * `graph` - The graph object to run the algorithm on
43/// * `include_endpoints` - Whether to include the endpoints of paths in the path
44///   lengths used to compute the betweenness
45/// * `normalized` - Whether to normalize the betweenness scores by the number
46///   of distinct paths between all pairs of nodes
47/// * `parallel_threshold` - The number of nodes to calculate the betweenness
48///   centrality in parallel at, if the number of nodes in `graph` is less
49///   than this value it will run in a single thread. A good default to use
50///   here if you're not sure is `50` as that was found to be roughly the
51///   number of nodes where parallelism improves performance
52///
53/// # Example
54/// ```rust
55/// use rustworkx_core::petgraph;
56/// use rustworkx_core::centrality::betweenness_centrality;
57///
58/// let g = petgraph::graph::UnGraph::<i32, ()>::from_edges(&[
59///     (0, 4), (1, 2), (2, 3), (3, 4), (1, 4)
60/// ]);
61/// // Calculate the betweenness centrality
62/// let output = betweenness_centrality(&g, true, true, 200);
63/// assert_eq!(
64///     vec![Some(0.4), Some(0.5), Some(0.45), Some(0.5), Some(0.75)],
65///     output
66/// );
67/// ```
68/// # See Also
69/// [`edge_betweenness_centrality`]
70pub fn betweenness_centrality<G>(
71    graph: G,
72    include_endpoints: bool,
73    normalized: bool,
74    parallel_threshold: usize,
75) -> Vec<Option<f64>>
76where
77    G: NodeIndexable
78        + IntoNodeIdentifiers
79        + IntoNeighborsDirected
80        + NodeCount
81        + GraphProp
82        + GraphBase
83        + std::marker::Sync,
84    <G as GraphBase>::NodeId: std::cmp::Eq + Send,
85    // rustfmt deletes the following comments if placed inline above
86    // + IntoNodeIdentifiers // for node_identifiers()
87    // + IntoNeighborsDirected // for neighbors()
88    // + NodeCount // for node_count
89    // + GraphProp // for is_directed
90{
91    // Correspondence of variable names to quantities in the paper is as follows:
92    //
93    // P -- predecessors
94    // S -- verts_sorted_by_distance,
95    //      vertices in order of non-decreasing distance from s
96    // Q -- Q
97    // sigma -- sigma
98    // delta -- delta
99    // d -- distance
100    let max_index = graph.node_bound();
101
102    let mut betweenness: Vec<Option<f64>> = vec![None; max_index];
103    for node_s in graph.node_identifiers() {
104        let is: usize = graph.to_index(node_s);
105        betweenness[is] = Some(0.0);
106    }
107    let locked_betweenness = RwLock::new(&mut betweenness);
108    let node_indices: Vec<G::NodeId> = graph.node_identifiers().collect();
109
110    CondIterator::new(node_indices, graph.node_count() >= parallel_threshold)
111        .map(|node_s| (shortest_path_for_centrality(&graph, &node_s), node_s))
112        .for_each(|(mut shortest_path_calc, node_s)| {
113            _accumulate_vertices(
114                &locked_betweenness,
115                max_index,
116                &mut shortest_path_calc,
117                node_s,
118                &graph,
119                include_endpoints,
120            );
121        });
122
123    _rescale(
124        &mut betweenness,
125        graph.node_count(),
126        normalized,
127        graph.is_directed(),
128        include_endpoints,
129    );
130
131    betweenness
132}
133
134/// Compute the edge betweenness centrality of all edges in a graph.
135///
136/// The algorithm used in this function is based on:
137///
138/// Ulrik Brandes: On Variants of Shortest-Path Betweenness
139/// Centrality and their Generic Computation.
140/// Social Networks 30(2):136-145, 2008.
141/// <https://doi.org/10.1016/j.socnet.2007.11.001>.
142///
143/// This function is multithreaded and will run in parallel if the number
144/// of nodes in the graph is above the value of ``parallel_threshold``. If the
145/// function will be running in parallel the env var ``RAYON_NUM_THREADS`` can
146/// be used to adjust how many threads will be used.
147///
148/// Arguments:
149///
150/// * `graph` - The graph object to run the algorithm on
151/// * `normalized` - Whether to normalize the betweenness scores by the number
152///   of distinct paths between all pairs of nodes
153/// * `parallel_threshold` - The number of nodes to calculate the betweenness
154///   centrality in parallel at, if the number of nodes in `graph` is less
155///   than this value it will run in a single thread. A good default to use
156///   here if you're not sure is `50` as that was found to be roughly the
157///   number of nodes where parallelism improves performance
158///
159/// # Example
160/// ```rust
161/// use rustworkx_core::petgraph;
162/// use rustworkx_core::centrality::edge_betweenness_centrality;
163///
164/// let g = petgraph::graph::UnGraph::<i32, ()>::from_edges(&[
165///     (0, 4), (1, 2), (1, 3), (2, 3), (3, 4), (1, 4)
166/// ]);
167///
168/// let output = edge_betweenness_centrality(&g, false, 200);
169/// let expected = vec![Some(4.0), Some(2.0), Some(1.0), Some(2.0), Some(3.0), Some(3.0)];
170/// assert_eq!(output, expected);
171/// ```
172/// # See Also
173/// [`betweenness_centrality`]
174pub fn edge_betweenness_centrality<G>(
175    graph: G,
176    normalized: bool,
177    parallel_threshold: usize,
178) -> Vec<Option<f64>>
179where
180    G: NodeIndexable
181        + EdgeIndexable
182        + IntoEdges
183        + IntoNodeIdentifiers
184        + IntoNeighborsDirected
185        + NodeCount
186        + EdgeCount
187        + GraphProp
188        + Sync,
189    G::NodeId: Eq + Send,
190    G::EdgeId: Eq + Send,
191{
192    let max_index = graph.node_bound();
193    let mut betweenness = vec![None; graph.edge_bound()];
194    for edge in graph.edge_references() {
195        let is: usize = EdgeIndexable::to_index(&graph, edge.id());
196        betweenness[is] = Some(0.0);
197    }
198    let locked_betweenness = RwLock::new(&mut betweenness);
199    let node_indices: Vec<G::NodeId> = graph.node_identifiers().collect();
200    CondIterator::new(node_indices, graph.node_count() >= parallel_threshold)
201        .map(|node_s| shortest_path_for_edge_centrality(&graph, &node_s))
202        .for_each(|mut shortest_path_calc| {
203            accumulate_edges(
204                &locked_betweenness,
205                max_index,
206                &mut shortest_path_calc,
207                &graph,
208            );
209        });
210
211    _rescale(
212        &mut betweenness,
213        graph.node_count(),
214        normalized,
215        graph.is_directed(),
216        true,
217    );
218    betweenness
219}
220
/// Scale accumulated betweenness scores in place.
///
/// * `normalized` with `include_endpoints`: multiply by `1 / (n(n - 1))`,
///   skipped when `node_count < 2`.
/// * `normalized` without `include_endpoints`: multiply by
///   `1 / ((n - 1)(n - 2))`, skipped when `node_count <= 2`.
/// * not `normalized`, undirected: multiply by `0.5`.
/// * not `normalized`, directed: leave the scores untouched.
///
/// `None` entries (indices with no node/edge behind them) are left as `None`.
fn _rescale(
    betweenness: &mut [Option<f64>],
    node_count: usize,
    normalized: bool,
    directed: bool,
    include_endpoints: bool,
) {
    let factor: Option<f64> = match (normalized, include_endpoints) {
        (true, true) if node_count >= 2 => Some(1.0 / (node_count * (node_count - 1)) as f64),
        (true, false) if node_count > 2 => {
            Some(1.0 / ((node_count - 1) * (node_count - 2)) as f64)
        }
        // Too few nodes for the normalization denominator to be meaningful.
        (true, _) => None,
        (false, _) if directed => None,
        (false, _) => Some(0.5),
    };
    if let Some(factor) = factor {
        betweenness
            .iter_mut()
            .for_each(|slot| *slot = slot.map(|value| value * factor));
    }
}
253
254fn _accumulate_vertices<G>(
255    locked_betweenness: &RwLock<&mut Vec<Option<f64>>>,
256    max_index: usize,
257    path_calc: &mut ShortestPathData<G>,
258    node_s: <G as GraphBase>::NodeId,
259    graph: G,
260    include_endpoints: bool,
261) where
262    G: NodeIndexable
263        + IntoNodeIdentifiers
264        + IntoNeighborsDirected
265        + NodeCount
266        + GraphProp
267        + GraphBase
268        + std::marker::Sync,
269    <G as GraphBase>::NodeId: std::cmp::Eq,
270{
271    let mut delta = vec![0.0; max_index];
272    for w in &path_calc.verts_sorted_by_distance {
273        let iw = graph.to_index(*w);
274        let coeff = (1.0 + delta[iw]) / path_calc.sigma[iw];
275        let p_w = path_calc.predecessors.get(iw).unwrap();
276        for iv in p_w {
277            delta[*iv] += path_calc.sigma[*iv] * coeff;
278        }
279    }
280    let mut betweenness = locked_betweenness.write().unwrap();
281    if include_endpoints {
282        let i_node_s = graph.to_index(node_s);
283        betweenness[i_node_s] = betweenness[i_node_s]
284            .map(|x| x + ((path_calc.verts_sorted_by_distance.len() - 1) as f64));
285        for w in &path_calc.verts_sorted_by_distance {
286            if *w != node_s {
287                let iw = graph.to_index(*w);
288                betweenness[iw] = betweenness[iw].map(|x| x + delta[iw] + 1.0);
289            }
290        }
291    } else {
292        for w in &path_calc.verts_sorted_by_distance {
293            if *w != node_s {
294                let iw = graph.to_index(*w);
295                betweenness[iw] = betweenness[iw].map(|x| x + delta[iw]);
296            }
297        }
298    }
299}
300
301fn accumulate_edges<G>(
302    locked_betweenness: &RwLock<&mut Vec<Option<f64>>>,
303    max_index: usize,
304    path_calc: &mut ShortestPathDataWithEdges<G>,
305    graph: G,
306) where
307    G: NodeIndexable + EdgeIndexable + Sync,
308    G::NodeId: Eq,
309    G::EdgeId: Eq,
310{
311    let mut delta = vec![0.0; max_index];
312    for w in &path_calc.verts_sorted_by_distance {
313        let iw = NodeIndexable::to_index(&graph, *w);
314        let coeff = (1.0 + delta[iw]) / path_calc.sigma[iw];
315        let p_w = path_calc.predecessors.get(iw).unwrap();
316        let e_w = path_calc.predecessor_edges.get(iw).unwrap();
317        let mut betweenness = locked_betweenness.write().unwrap();
318        for i in 0..p_w.len() {
319            let v = p_w[i];
320            let iv = NodeIndexable::to_index(&graph, v);
321            let ie = EdgeIndexable::to_index(&graph, e_w[i]);
322            let c = path_calc.sigma[iv] * coeff;
323            betweenness[ie] = betweenness[ie].map(|x| x + c);
324            delta[iv] += c;
325        }
326    }
327}
328/// Compute the degree centrality of all nodes in a graph.
329///
330/// For undirected graphs, this calculates the normalized degree for each node.
331/// For directed graphs, this calculates the normalized out-degree for each node.
332///
333/// Arguments:
334///
335/// * `graph` - The graph object to calculate degree centrality for
336///
337/// # Example
338/// ```rust
339/// use rustworkx_core::petgraph::graph::{UnGraph, DiGraph};
340/// use rustworkx_core::centrality::degree_centrality;
341///
342/// // Undirected graph example
343/// let graph = UnGraph::<i32, ()>::from_edges(&[
344///     (0, 1), (1, 2), (2, 3), (3, 0)
345/// ]);
346/// let centrality = degree_centrality(&graph, None);
347///
348/// // Directed graph example
349/// let digraph = DiGraph::<i32, ()>::from_edges(&[
350///     (0, 1), (1, 2), (2, 3), (3, 0), (0, 2), (1, 3)
351/// ]);
352/// let centrality = degree_centrality(&digraph, None);
353/// ```
354pub fn degree_centrality<G>(graph: G, direction: Option<petgraph::Direction>) -> Vec<f64>
355where
356    G: NodeIndexable
357        + IntoNodeIdentifiers
358        + IntoNeighbors
359        + IntoNeighborsDirected
360        + NodeCount
361        + GraphProp,
362    G::NodeId: Eq,
363{
364    let node_count = graph.node_count() as f64;
365    let mut centrality = vec![0.0; graph.node_bound()];
366
367    for node in graph.node_identifiers() {
368        let (degree, normalization) = match (graph.is_directed(), direction) {
369            (true, None) => {
370                let out_degree = graph
371                    .neighbors_directed(node, petgraph::Direction::Outgoing)
372                    .count() as f64;
373                let in_degree = graph
374                    .neighbors_directed(node, petgraph::Direction::Incoming)
375                    .count() as f64;
376                let total = in_degree + out_degree;
377                // Use 2(n-1) normalization only if this is a complete graph
378                let norm = if total == 2.0 * (node_count - 1.0) {
379                    2.0 * (node_count - 1.0)
380                } else {
381                    node_count - 1.0
382                };
383                (total, norm)
384            }
385            (true, Some(dir)) => (
386                graph.neighbors_directed(node, dir).count() as f64,
387                node_count - 1.0,
388            ),
389            (false, _) => (graph.neighbors(node).count() as f64, node_count - 1.0),
390        };
391        centrality[graph.to_index(node)] = degree / normalization;
392    }
393
394    centrality
395}
396
397struct ShortestPathData<G>
398where
399    G: GraphBase,
400    <G as GraphBase>::NodeId: std::cmp::Eq,
401{
402    verts_sorted_by_distance: Vec<G::NodeId>,
403    predecessors: Vec<Vec<usize>>,
404    sigma: Vec<f64>,
405}
406fn shortest_path_for_centrality<G>(graph: G, node_s: &G::NodeId) -> ShortestPathData<G>
407where
408    G: NodeIndexable + IntoNodeIdentifiers + IntoNeighborsDirected + NodeCount + GraphBase,
409    <G as GraphBase>::NodeId: std::cmp::Eq,
410{
411    let c = graph.node_count();
412    let max_index = graph.node_bound();
413    let mut verts_sorted_by_distance: Vec<G::NodeId> = Vec::with_capacity(c); // a stack
414    let mut predecessors: Vec<Vec<usize>> = vec![Vec::new(); max_index];
415    let mut sigma: Vec<f64> = vec![0.; max_index];
416    let mut distance: Vec<Option<usize>> = vec![None; max_index];
417    #[allow(non_snake_case)]
418    let mut Q: VecDeque<G::NodeId> = VecDeque::with_capacity(c);
419    let node_s_index = graph.to_index(*node_s);
420    sigma[node_s_index] = 1.0;
421    distance[node_s_index] = Some(0);
422    Q.push_back(*node_s);
423    while let Some(v) = Q.pop_front() {
424        verts_sorted_by_distance.push(v);
425        let v_idx = graph.to_index(v);
426        let distance_v = distance[v_idx].unwrap();
427        for w in graph.neighbors(v) {
428            let w_idx = graph.to_index(w);
429            if distance[w_idx].is_none() {
430                Q.push_back(w);
431                distance[w_idx] = Some(distance_v + 1);
432            }
433            if distance[w_idx] == Some(distance_v + 1) {
434                sigma[w_idx] += sigma[v_idx];
435                predecessors[w_idx].push(v_idx);
436            }
437        }
438    }
439    verts_sorted_by_distance.reverse(); // will be effectively popping from the stack
440    ShortestPathData {
441        verts_sorted_by_distance,
442        predecessors,
443        sigma,
444    }
445}
446
447struct ShortestPathDataWithEdges<G>
448where
449    G: GraphBase,
450    G::NodeId: Eq,
451    G::EdgeId: Eq,
452{
453    verts_sorted_by_distance: Vec<G::NodeId>,
454    predecessors: Vec<Vec<G::NodeId>>,
455    predecessor_edges: Vec<Vec<G::EdgeId>>,
456    sigma: Vec<f64>,
457}
458
459fn shortest_path_for_edge_centrality<G>(
460    graph: G,
461    node_s: &G::NodeId,
462) -> ShortestPathDataWithEdges<G>
463where
464    G: NodeIndexable
465        + IntoNodeIdentifiers
466        + IntoNeighborsDirected
467        + NodeCount
468        + GraphBase
469        + IntoEdges,
470    G::NodeId: Eq,
471    G::EdgeId: Eq,
472{
473    let mut verts_sorted_by_distance: Vec<G::NodeId> = Vec::new(); // a stack
474    let c = graph.node_bound();
475    let mut predecessors = vec![Vec::new(); c];
476    let mut predecessor_edges = vec![Vec::new(); c];
477    let mut sigma = vec![0.0; c];
478    let mut distance: Vec<Option<usize>> = vec![None; c];
479    #[allow(non_snake_case)]
480    let mut Q: VecDeque<G::NodeId> = VecDeque::with_capacity(c);
481
482    sigma[graph.to_index(*node_s)] = 1.;
483    distance[graph.to_index(*node_s)] = Some(0);
484    Q.push_back(*node_s);
485    while let Some(v) = Q.pop_front() {
486        verts_sorted_by_distance.push(v);
487        let v_index = graph.to_index(v);
488        let distance_v = distance[v_index].unwrap();
489        for edge in graph.edges(v) {
490            let w = edge.target();
491            let w_index = graph.to_index(w);
492            if distance[w_index].is_none() {
493                Q.push_back(w);
494                distance[w_index] = Some(distance_v + 1);
495            }
496            if distance[w_index] == Some(distance_v + 1) {
497                sigma[w_index] += sigma[v_index];
498                let e_p = predecessors.get_mut(w_index).unwrap();
499                e_p.push(v);
500                predecessor_edges.get_mut(w_index).unwrap().push(edge.id());
501            }
502        }
503    }
504    verts_sorted_by_distance.reverse(); // will be effectively popping from the stack
505    ShortestPathDataWithEdges {
506        verts_sorted_by_distance,
507        predecessors,
508        predecessor_edges,
509        sigma,
510    }
511}
512
#[cfg(test)]
mod test_edge_betweenness_centrality {
    use crate::centrality::edge_betweenness_centrality;
    use petgraph::Undirected;
    use petgraph::graph::edge_index;
    use petgraph::prelude::StableGraph;

    /// Assert that `result` has exactly as many entries as `expected`, that every
    /// entry is `Some`, and that each one is within `delta` of its expected value.
    ///
    /// The comparison is written as `diff < delta` so that a NaN score fails the
    /// assertion instead of slipping through a `diff >= delta` check.
    fn assert_all_close(result: &[Option<f64>], expected: &[f64], delta: f64) {
        assert_eq!(
            result.len(),
            expected.len(),
            "result length differs from expected"
        );
        for (i, (got, want)) in result.iter().zip(expected).enumerate() {
            let got = got.expect("every edge of these graphs should have a score");
            assert!(
                (got - want).abs() < delta,
                "edge {}: {} != {} within delta of {}",
                i,
                got,
                want,
                delta
            );
        }
    }

    #[test]
    fn test_undirected_graph_normalized() {
        let graph = petgraph::graph::UnGraph::<(), ()>::from_edges([
            (0, 6),
            (0, 4),
            (0, 1),
            (0, 5),
            (1, 6),
            (1, 7),
            (1, 3),
            (1, 4),
            (2, 6),
            (2, 3),
            (3, 5),
            (3, 7),
            (3, 6),
            (4, 5),
            (5, 6),
        ]);
        let output = edge_betweenness_centrality(&graph, true, 50);
        let expected_values = [
            0.1023809, 0.0547619, 0.0922619, 0.05654762, 0.09940476, 0.125, 0.09940476, 0.12440476,
            0.12857143, 0.12142857, 0.13511905, 0.125, 0.06547619, 0.08869048, 0.08154762,
        ];
        assert_all_close(&output, &expected_values, 1e-4);
    }

    #[test]
    fn test_undirected_graph_unnormalized() {
        let graph = petgraph::graph::UnGraph::<(), ()>::from_edges([
            (0, 2),
            (0, 4),
            (0, 1),
            (1, 3),
            (1, 5),
            (1, 7),
            (2, 7),
            (2, 3),
            (3, 5),
            (3, 6),
            (4, 6),
            (5, 7),
        ]);
        let output = edge_betweenness_centrality(&graph, false, 50);
        let expected_values = [
            3.83333, 5.5, 5.33333, 3.5, 2.5, 3.0, 3.5, 4.0, 3.66667, 6.5, 3.5, 2.16667,
        ];
        assert_all_close(&output, &expected_values, 1e-4);
    }

    #[test]
    fn test_directed_graph_normalized() {
        let graph = petgraph::graph::DiGraph::<(), ()>::from_edges([
            (0, 1),
            (1, 0),
            (1, 3),
            (1, 2),
            (1, 4),
            (2, 3),
            (2, 4),
            (2, 1),
            (3, 2),
            (4, 3),
        ]);
        let output = edge_betweenness_centrality(&graph, true, 50);
        let expected_values = [0.2, 0.2, 0.1, 0.1, 0.1, 0.05, 0.1, 0.3, 0.35, 0.2];
        assert_all_close(&output, &expected_values, 1e-4);
    }

    #[test]
    fn test_directed_graph_unnormalized() {
        let graph = petgraph::graph::DiGraph::<(), ()>::from_edges([
            (0, 4),
            (1, 0),
            (1, 3),
            (2, 3),
            (2, 4),
            (2, 0),
            (3, 4),
            (3, 2),
            (3, 1),
            (4, 1),
        ]);
        let output = edge_betweenness_centrality(&graph, false, 50);
        let expected_values = [4.5, 3.0, 6.5, 1.5, 1.5, 1.5, 1.5, 4.5, 2.0, 7.5];
        assert_all_close(&output, &expected_values, 1e-4);
    }

    #[test]
    fn test_stable_graph_with_removed_edges() {
        let mut graph: StableGraph<(), (), Undirected> =
            StableGraph::from_edges([(0, 1), (1, 2), (2, 3), (3, 0)]);
        graph.remove_edge(edge_index(1));
        let result = edge_betweenness_centrality(&graph, false, 50);
        let expected_values = vec![Some(3.0), None, Some(3.0), Some(4.0)];
        assert_eq!(result, expected_values);
    }

    #[test]
    fn test_stable_graph_with_removed_nodes_and_edges() {
        let mut graph: StableGraph<(), (), Undirected> = StableGraph::default();
        let n0 = graph.add_node(());
        let d0 = graph.add_node(());
        let n1 = graph.add_node(());
        let d1 = graph.add_node(());
        let n2 = graph.add_node(());
        let d2 = graph.add_node(());
        let n3 = graph.add_node(());

        graph.remove_node(d0);
        graph.remove_node(d1);
        graph.remove_node(d2);

        graph.add_edge(n0, n1, ());
        graph.add_edge(n1, n2, ());
        graph.add_edge(n2, n3, ());
        graph.add_edge(n3, n0, ());

        graph.remove_edge(edge_index(1));
        let result = edge_betweenness_centrality(&graph, false, 50);
        let expected_values = vec![Some(3.0), None, Some(3.0), Some(4.0)];
        assert_eq!(result, expected_values);
    }
}
664
#[cfg(test)]
mod test_betweenness_centrality {
    use crate::centrality::betweenness_centrality;
    use petgraph::Undirected;
    use petgraph::graph::edge_index;
    use petgraph::prelude::StableGraph;

    #[test]
    fn test_stable_graph_with_removed_nodes_and_edges() {
        // Seven nodes are added; the ones at indices 1, 3 and 5 are placeholders
        // that are removed again, leaving holes in the index space.
        let mut graph: StableGraph<(), (), Undirected> = StableGraph::default();
        let ids: Vec<_> = (0..7).map(|_| graph.add_node(())).collect();
        for &placeholder in &[ids[1], ids[3], ids[5]] {
            graph.remove_node(placeholder);
        }

        // Build a 4-cycle over the surviving nodes, then cut the edge n1-n2 so
        // the remaining shape is the path n1 - n0 - n3 - n2.
        let (n0, n1, n2, n3) = (ids[0], ids[2], ids[4], ids[6]);
        for &(a, b) in &[(n0, n1), (n1, n2), (n2, n3), (n3, n0)] {
            graph.add_edge(a, b, ());
        }
        graph.remove_edge(edge_index(1));

        let result = betweenness_centrality(&graph, false, false, 50);
        let expected_values = vec![Some(2.0), None, Some(0.0), None, Some(0.0), None, Some(2.0)];
        assert_eq!(result, expected_values);
    }
}
698
699/// Compute the eigenvector centrality of a graph
700///
701/// For details on the eigenvector centrality refer to:
702///
703/// Phillip Bonacich. “Power and Centrality: A Family of Measures.”
704/// American Journal of Sociology 92(5):1170–1182, 1986
705/// <https://doi.org/10.1086/228631>
706///
707/// This function uses a power iteration method to compute the eigenvector
708/// and convergence is not guaranteed. The function will stop when `max_iter`
709/// iterations is reached or when the computed vector between two iterations
710/// is smaller than the error tolerance multiplied by the number of nodes.
711/// The implementation of this algorithm is based on the NetworkX
712/// [`eigenvector_centrality()`](https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.centrality.eigenvector_centrality.html)
713/// function.
714///
715/// In the case of multigraphs the weights of any parallel edges will be
716/// summed when computing the eigenvector centrality.
717///
718/// Arguments:
719///
720/// * `graph` - The graph object to run the algorithm on
721/// * `weight_fn` - An input callable that will be passed the `EdgeRef` for
722///   an edge in the graph and is expected to return a `Result<f64>` of
723///   the weight of that edge.
724/// * `max_iter` - The maximum number of iterations in the power method. If
725///   set to `None` a default value of 100 is used.
726/// * `tol` - The error tolerance used when checking for convergence in the
727///   power method. If set to `None` a default value of 1e-6 is used.
728///
729/// # Example
730/// ```rust
731/// use rustworkx_core::Result;
732/// use rustworkx_core::petgraph;
733/// use rustworkx_core::petgraph::visit::{IntoEdges, IntoNodeIdentifiers};
734/// use rustworkx_core::centrality::eigenvector_centrality;
735///
736/// let g = petgraph::graph::UnGraph::<i32, ()>::from_edges(&[
737///     (0, 1), (1, 2)
738/// ]);
739/// // Calculate the eigenvector centrality
740/// let output: Result<Option<Vec<f64>>> = eigenvector_centrality(&g, |_| {Ok(1.)}, None, None);
741/// ```
742pub fn eigenvector_centrality<G, F, E>(
743    graph: G,
744    mut weight_fn: F,
745    max_iter: Option<usize>,
746    tol: Option<f64>,
747) -> Result<Option<Vec<f64>>, E>
748where
749    G: NodeIndexable + IntoNodeIdentifiers + IntoNeighbors + IntoEdges + NodeCount,
750    G::NodeId: Eq,
751    F: FnMut(G::EdgeRef) -> Result<f64, E>,
752{
753    let tol: f64 = tol.unwrap_or(1e-6);
754    let max_iter = max_iter.unwrap_or(100);
755    let mut x: Vec<f64> = vec![1.; graph.node_bound()];
756    let node_count = graph.node_count();
757    for _ in 0..max_iter {
758        let x_last = x.clone();
759        for node_index in graph.node_identifiers() {
760            let node = graph.to_index(node_index);
761            for edge in graph.edges(node_index) {
762                let w = weight_fn(edge)?;
763                let neighbor = edge.target();
764                x[graph.to_index(neighbor)] += x_last[node] * w;
765            }
766        }
767        let norm: f64 = x.iter().map(|val| val.powi(2)).sum::<f64>().sqrt();
768        if norm == 0. {
769            return Ok(None);
770        }
771        for v in x.iter_mut() {
772            *v /= norm;
773        }
774        if (0..x.len())
775            .map(|node| (x[node] - x_last[node]).abs())
776            .sum::<f64>()
777            < node_count as f64 * tol
778        {
779            return Ok(Some(x));
780        }
781    }
782    Ok(None)
783}
784
785/// Compute the Katz centrality of a graph
786///
787/// For details on the Katz centrality refer to:
788///
789/// Leo Katz. “A New Status Index Derived from Sociometric Index.”
790/// Psychometrika 18(1):39–43, 1953
791/// <https://link.springer.com/content/pdf/10.1007/BF02289026.pdf>
792///
793/// This function uses a power iteration method to compute the eigenvector
794/// and convergence is not guaranteed. The function will stop when `max_iter`
795/// iterations is reached or when the computed vector between two iterations
796/// is smaller than the error tolerance multiplied by the number of nodes.
797/// The implementation of this algorithm is based on the NetworkX
798/// [`katz_centrality()`](https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.centrality.katz_centrality.html)
799/// function.
800///
801/// In the case of multigraphs the weights of any parallel edges will be
802/// summed when computing the eigenvector centrality.
803///
804/// Arguments:
805///
806/// * `graph` - The graph object to run the algorithm on
807/// * `weight_fn` - An input callable that will be passed the `EdgeRef` for
808///   an edge in the graph and is expected to return a `Result<f64>` of
809///   the weight of that edge.
810/// * `alpha` - Attenuation factor. If set to `None`, a default value of 0.1 is used.
811/// * `beta_map` - Immediate neighbourhood weights. Must contain all node indices or be `None`.
812/// * `beta_scalar` - Immediate neighbourhood scalar that replaces `beta_map` in case `beta_map` is None.
813///   Defaults to 1.0 in case `None` is provided.
814/// * `max_iter` - The maximum number of iterations in the power method. If
815///   set to `None` a default value of 100 is used.
816/// * `tol` - The error tolerance used when checking for convergence in the
817///   power method. If set to `None` a default value of 1e-6 is used.
818///
819/// # Example
820/// ```rust
821/// use rustworkx_core::Result;
822/// use rustworkx_core::petgraph;
823/// use rustworkx_core::petgraph::visit::{IntoEdges, IntoNodeIdentifiers};
824/// use rustworkx_core::centrality::katz_centrality;
825///
826/// let g = petgraph::graph::UnGraph::<i32, ()>::from_edges(&[
827///     (0, 1), (1, 2)
828/// ]);
829/// // Calculate the eigenvector centrality
830/// let output: Result<Option<Vec<f64>>> = katz_centrality(&g, |_| {Ok(1.)}, None, None, None, None, None);
831/// let centralities = output.unwrap().unwrap();
832/// assert!(centralities[1] > centralities[0], "Node 1 is more central than node 0");
833/// assert!(centralities[1] > centralities[2], "Node 1 is more central than node 2");
834/// ```
835pub fn katz_centrality<G, F, E>(
836    graph: G,
837    mut weight_fn: F,
838    alpha: Option<f64>,
839    beta_map: Option<HashMap<usize, f64>>,
840    beta_scalar: Option<f64>,
841    max_iter: Option<usize>,
842    tol: Option<f64>,
843) -> Result<Option<Vec<f64>>, E>
844where
845    G: NodeIndexable + IntoNodeIdentifiers + IntoNeighbors + IntoEdges + NodeCount,
846    G::NodeId: Eq,
847    F: FnMut(G::EdgeRef) -> Result<f64, E>,
848{
849    let alpha: f64 = alpha.unwrap_or(0.1);
850
851    let beta: HashMap<usize, f64> = beta_map.unwrap_or_default();
852    //Initialize the beta vector in case a beta map was not provided
853    let mut beta_v = vec![beta_scalar.unwrap_or(1.0); graph.node_bound()];
854
855    if !beta.is_empty() {
856        // Check if beta contains all node indices
857        for node_index in graph.node_identifiers() {
858            let node = graph.to_index(node_index);
859            if !beta.contains_key(&node) {
860                return Ok(None); // beta_map was provided but did not include all nodes
861            }
862            beta_v[node] = *beta.get(&node).unwrap(); //Initialize the beta vector with the provided values
863        }
864    }
865
866    let tol: f64 = tol.unwrap_or(1e-6);
867    let max_iter = max_iter.unwrap_or(1000);
868
869    let mut x: Vec<f64> = vec![0.; graph.node_bound()];
870    let node_count = graph.node_count();
871    for _ in 0..max_iter {
872        let x_last = x.clone();
873        x = vec![0.; graph.node_bound()];
874        for node_index in graph.node_identifiers() {
875            let node = graph.to_index(node_index);
876            for edge in graph.edges(node_index) {
877                let w = weight_fn(edge)?;
878                let neighbor = edge.target();
879                x[graph.to_index(neighbor)] += x_last[node] * w;
880            }
881        }
882        for node_index in graph.node_identifiers() {
883            let node = graph.to_index(node_index);
884            x[node] = alpha * x[node] + beta_v[node];
885        }
886        if (0..x.len())
887            .map(|node| (x[node] - x_last[node]).abs())
888            .sum::<f64>()
889            < node_count as f64 * tol
890        {
891            // Normalize vector
892            let norm: f64 = x.iter().map(|val| val.powi(2)).sum::<f64>().sqrt();
893            if norm == 0. {
894                return Ok(None);
895            }
896            for v in x.iter_mut() {
897                *v /= norm;
898            }
899
900            return Ok(Some(x));
901        }
902    }
903
904    Ok(None)
905}
906
#[cfg(test)]
mod test_eigenvector_centrality {

    use crate::Result;
    use crate::centrality::eigenvector_centrality;
    use crate::petgraph;

    /// Panics unless every value in `expected` is strictly within `delta` of
    /// the value at the same position in `actual`.
    fn assert_all_close(expected: &[f64], actual: &[f64], delta: f64) {
        for (i, &want) in expected.iter().enumerate() {
            let got = actual[i];
            if (want - got).abs() >= delta {
                panic!("{} != {} within delta of {}", want, got, delta);
            }
        }
    }

    /// Edge list of the complete graph on `n` nodes, in lexicographic order.
    fn complete_graph_edges(n: u32) -> Vec<(u32, u32)> {
        (0..n)
            .flat_map(|a| (a + 1..n).map(move |b| (a, b)))
            .collect()
    }

    #[test]
    fn test_no_convergence() {
        let graph = petgraph::graph::UnGraph::<i32, ()>::from_edges([(0, 1), (1, 2)]);
        // With zero iterations the power method can never converge.
        let output: Result<Option<Vec<f64>>> =
            eigenvector_centrality(&graph, |_| Ok(1.), Some(0), None);
        assert_eq!(None, output.unwrap());
    }

    #[test]
    fn test_undirected_complete_graph() {
        let graph = petgraph::graph::UnGraph::<i32, ()>::from_edges(complete_graph_edges(5));
        let output: Result<Option<Vec<f64>>> =
            eigenvector_centrality(&graph, |_| Ok(1.), None, None);
        // All nodes of K5 are interchangeable, so each scores 1/sqrt(5).
        let expected = [(1_f64 / 5_f64).sqrt(); 5];
        assert_all_close(&expected, &output.unwrap().unwrap(), 1e-4);
    }

    #[test]
    fn test_undirected_path_graph() {
        let graph = petgraph::graph::UnGraph::<i32, ()>::from_edges([(0, 1), (1, 2)]);
        let output: Result<Option<Vec<f64>>> =
            eigenvector_centrality(&graph, |_| Ok(1.), None, None);
        let expected = [0.5, std::f64::consts::FRAC_1_SQRT_2, 0.5];
        assert_all_close(&expected, &output.unwrap().unwrap(), 1e-4);
    }

    #[test]
    fn test_directed_graph() {
        let graph = petgraph::graph::DiGraph::<i32, ()>::from_edges([
            (0, 1),
            (0, 2),
            (1, 3),
            (2, 1),
            (2, 4),
            (3, 1),
            (3, 4),
            (3, 5),
            (4, 5),
            (4, 6),
            (4, 7),
            (5, 7),
            (6, 0),
            (6, 4),
            (6, 7),
            (7, 5),
            (7, 6),
        ]);
        let output: Result<Option<Vec<f64>>> =
            eigenvector_centrality(&graph, |_| Ok(2.), None, None);
        let expected = [
            0.2140437, 0.2009269, 0.1036383, 0.0972886, 0.3113323, 0.4891686, 0.4420605, 0.6016448,
        ];
        assert_all_close(&expected, &output.unwrap().unwrap(), 1e-4);
    }
}
995
#[cfg(test)]
mod test_katz_centrality {

    use crate::Result;
    use crate::centrality::katz_centrality;
    use crate::petgraph;
    use hashbrown::HashMap;

    /// Panics unless every value in `expected` is strictly within `delta` of
    /// the value at the same position in `actual`.
    fn assert_all_close(expected: &[f64], actual: &[f64], delta: f64) {
        for (i, &want) in expected.iter().enumerate() {
            let got = actual[i];
            if (want - got).abs() >= delta {
                panic!("{} != {} within delta of {}", want, got, delta);
            }
        }
    }

    /// Edge list of the complete graph on `n` nodes, in lexicographic order.
    fn complete_graph_edges(n: u32) -> Vec<(u32, u32)> {
        (0..n)
            .flat_map(|a| (a + 1..n).map(move |b| (a, b)))
            .collect()
    }

    #[test]
    fn test_no_convergence() {
        let graph = petgraph::graph::UnGraph::<i32, ()>::from_edges([(0, 1), (1, 2)]);
        // With zero iterations the power method can never converge.
        let output: Result<Option<Vec<f64>>> =
            katz_centrality(&graph, |_| Ok(1.), None, None, None, Some(0), None);
        assert_eq!(None, output.unwrap());
    }

    #[test]
    fn test_incomplete_beta() {
        let graph = petgraph::graph::UnGraph::<i32, ()>::from_edges([(0, 1), (1, 2)]);
        // Nodes 1 and 2 are missing from the map, so no result is produced.
        let beta_map: HashMap<usize, f64> = [(0, 1.0)].into_iter().collect();
        let output: Result<Option<Vec<f64>>> =
            katz_centrality(&graph, |_| Ok(1.), None, Some(beta_map), None, None, None);
        assert_eq!(None, output.unwrap());
    }

    #[test]
    fn test_complete_beta() {
        let graph = petgraph::graph::UnGraph::<i32, ()>::from_edges([(0, 1), (1, 2)]);
        let beta_map: HashMap<usize, f64> =
            [(0, 0.5), (1, 1.0), (2, 0.5)].into_iter().collect();
        let output: Result<Option<Vec<f64>>> =
            katz_centrality(&graph, |_| Ok(1.), None, Some(beta_map), None, None, None);
        let expected = [0.4318894504492167, 0.791797325823564, 0.4318894504492167];
        assert_all_close(&expected, &output.unwrap().unwrap(), 1e-4);
    }

    #[test]
    fn test_undirected_complete_graph() {
        let graph = petgraph::graph::UnGraph::<i32, ()>::from_edges(complete_graph_edges(5));
        let output: Result<Option<Vec<f64>>> =
            katz_centrality(&graph, |_| Ok(1.), Some(0.2), None, Some(1.1), None, None);
        // All nodes of K5 are interchangeable, so each scores 1/sqrt(5).
        let expected = [(1_f64 / 5_f64).sqrt(); 5];
        assert_all_close(&expected, &output.unwrap().unwrap(), 1e-4);
    }

    #[test]
    fn test_directed_graph() {
        let graph = petgraph::graph::DiGraph::<i32, ()>::from_edges([
            (0, 1),
            (0, 2),
            (1, 3),
            (2, 1),
            (2, 4),
            (3, 1),
            (3, 4),
            (3, 5),
            (4, 5),
            (4, 6),
            (4, 7),
            (5, 7),
            (6, 0),
            (6, 4),
            (6, 7),
            (7, 5),
            (7, 6),
        ]);
        let output: Result<Option<Vec<f64>>> =
            katz_centrality(&graph, |_| Ok(1.), None, None, None, None, None);
        let expected = [
            0.3135463087489011,
            0.3719056758615039,
            0.3094350787809586,
            0.31527101632646026,
            0.3760169058294464,
            0.38618584417917906,
            0.35465874858087904,
            0.38976653416801743,
        ];
        assert_all_close(&expected, &output.unwrap().unwrap(), 1e-4);
    }
}
1109
1110/// Compute the closeness centrality of each node in the graph.
1111///
1112/// The closeness centrality of a node `u` is the reciprocal of the average
1113/// shortest path distance to `u` over all `n-1` reachable nodes.
1114///
1115/// In the case of a graphs with more than one connected component there is
1116/// an alternative improved formula that calculates the closeness centrality
1117/// as "a ratio of the fraction of actors in the group who are reachable, to
1118/// the average distance".[^WF]
1119/// You can enable this by setting `wf_improved` to `true`.
1120///
1121/// [^WF]: Wasserman, S., & Faust, K. (1994). Social Network Analysis:
1122///     Methods and Applications (Structural Analysis in the Social Sciences).
1123///     Cambridge: Cambridge University Press.
1124///     <https://doi.org/10.1017/CBO9780511815478>
1125///
1126/// This function is multithreaded and will run in parallel if the number
1127/// of nodes in the graph is above the value of ``parallel_threshold``. If the
1128/// function will be running in parallel the env var ``RAYON_NUM_THREADS`` can
1129/// be used to adjust how many threads will be used.
1130///
1131/// # Arguments
1132///
1133/// * `graph` - The graph object to run the algorithm on
1134/// * `wf_improved` - If `true`, scale by the fraction of nodes reachable.
1135/// * `parallel_threshold` - The number of nodes to calculate the betweenness
1136///   centrality in parallel at, if the number of nodes in `graph` is less
1137///   than this value it will run in a single thread. The suggested default to use
1138///   here is `50`.
1139///
1140/// # Example
1141///
1142/// ```rust
1143/// use rustworkx_core::petgraph;
1144/// use rustworkx_core::centrality::closeness_centrality;
1145///
1146/// // Calculate the closeness centrality of Graph
1147/// let g = petgraph::graph::UnGraph::<i32, ()>::from_edges(&[
1148///     (0, 4), (1, 2), (2, 3), (3, 4), (1, 4)
1149/// ]);
1150/// let output = closeness_centrality(&g, true, 200);
1151/// assert_eq!(
1152///     vec![Some(1./2.), Some(2./3.), Some(4./7.), Some(2./3.), Some(4./5.)],
1153///     output
1154/// );
1155///
1156/// // Calculate the closeness centrality of DiGraph
1157/// let dg = petgraph::graph::DiGraph::<i32, ()>::from_edges(&[
1158///     (0, 4), (1, 2), (2, 3), (3, 4), (1, 4)
1159/// ]);
1160/// let output = closeness_centrality(&dg, true, 200);
1161/// assert_eq!(
1162///     vec![Some(0.), Some(0.), Some(1./4.), Some(1./3.), Some(4./5.)],
1163///     output
1164/// );
1165/// ```
1166pub fn closeness_centrality<G>(
1167    graph: G,
1168    wf_improved: bool,
1169    parallel_threshold: usize,
1170) -> Vec<Option<f64>>
1171where
1172    G: NodeIndexable
1173        + IntoNodeIdentifiers
1174        + GraphBase
1175        + IntoEdges
1176        + Visitable
1177        + NodeCount
1178        + IntoEdgesDirected
1179        + std::marker::Sync,
1180    G::NodeId: Eq + Hash + Send,
1181    G::EdgeId: Eq + Hash + Send,
1182{
1183    let max_index = graph.node_bound();
1184    let mut node_indices: Vec<Option<G::NodeId>> = vec![None; max_index];
1185    graph.node_identifiers().for_each(|node| {
1186        let index = graph.to_index(node);
1187        node_indices[index] = Some(node);
1188    });
1189
1190    let unweighted_shortest_path = |g: Reversed<&G>, s: G::NodeId| -> HashMap<G::NodeId, usize> {
1191        let mut distances = HashMap::new();
1192        let mut bfs = Bfs::new(g, s);
1193        distances.insert(s, 0);
1194        while let Some(node) = bfs.next(g) {
1195            let distance = distances[&node];
1196            for edge in g.edges(node) {
1197                let target = edge.target();
1198                distances.entry(target).or_insert(distance + 1);
1199            }
1200        }
1201        distances
1202    };
1203
1204    let closeness: Vec<Option<f64>> =
1205        CondIterator::new(node_indices, graph.node_count() >= parallel_threshold)
1206            .map(|node_s| {
1207                let node_s = node_s?;
1208                let map = unweighted_shortest_path(Reversed(&graph), node_s);
1209                let reachable_nodes_count = map.len();
1210                let dists_sum: usize = map.into_values().sum();
1211                if reachable_nodes_count == 1 {
1212                    return Some(0.0);
1213                }
1214                let mut centrality_s = Some((reachable_nodes_count - 1) as f64 / dists_sum as f64);
1215                if wf_improved {
1216                    let node_count = graph.node_count();
1217                    centrality_s = centrality_s
1218                        .map(|c| c * (reachable_nodes_count - 1) as f64 / (node_count - 1) as f64);
1219                }
1220                centrality_s
1221            })
1222            .collect();
1223    closeness
1224}
1225
1226/// Compute the weighted closeness centrality of each node in the graph.
1227///
1228/// The weighted closeness centrality is an extension of the standard closeness
1229/// centrality measure where edge weights represent connection strength rather
1230/// than distance. To properly compute shortest paths, weights are inverted
1231/// so that stronger connections correspond to shorter effective distances.
1232/// The algorithm follows the method described by Newman (2001) in analyzing
1233/// weighted graphs.[^Newman]
1234///
1235/// The edges originally represent connection strength between nodes.
1236/// The idea is that if two nodes have a strong connection, the computed
1237/// distance between them should be small (shorter), and vice versa.
1238/// Note that this assume that the graph is modelling a measure of
1239/// connection strength (e.g. trust, collaboration, or similarity).
1240/// If the graph is not modelling a measure of connection strength,
1241/// the function `weight_fn` should invert the weights before calling this
1242/// function, if not it is considered as a logical error.
1243///
1244/// In the case of a graphs with more than one connected component there is
1245/// an alternative improved formula that calculates the closeness centrality
1246/// as "a ratio of the fraction of actors in the group who are reachable, to
1247/// the average distance".[^WF]
1248/// You can enable this by setting `wf_improved` to `true`.
1249///
1250/// [^Newman]: Newman, M. E. J. (2001). Scientific collaboration networks.
1251///     II. Shortest paths, weighted networks, and centrality.
1252///     Physical Review E, 64(1), 016132.
1253///
1254/// [^WF]: Wasserman, S., & Faust, K. (1994). Social Network Analysis:
1255///     Methods and Applications (Structural Analysis in the Social Sciences).
1256///     Cambridge: Cambridge University Press.
1257///     <https://doi.org/10.1017/CBO9780511815478>
1258///
1259/// This function is multithreaded and will run in parallel if the number
1260/// of nodes in the graph is above the value of ``parallel_threshold``. If the
1261/// function will be running in parallel the env var ``RAYON_NUM_THREADS`` can
1262/// be used to adjust how many threads will be used.
1263///
1264/// # Arguments
1265/// * `graph` - The graph object to run the algorithm on
1266/// * `wf_improved` - If `true`, scale by the fraction of nodes reachable.
1267/// * `weight_fn` - An input callable that will be passed the
1268///   `ReversedEdgeReference<<G as IntoEdgeReferences>::EdgeRef>` for
1269///   an edge in the graph and is expected to return a `f64` of
1270///   the weight of that edge.
1271/// * `parallel_threshold` - The number of nodes to calculate the betweenness
1272///   centrality in parallel at, if the number of nodes in `graph` is less
1273///   than this value it will run in a single thread. The suggested default to use
1274///   here is `50`.
1275///
1276/// # Example
1277///
1278/// ```rust
1279/// use rustworkx_core::petgraph;
1280/// use rustworkx_core::centrality::newman_weighted_closeness_centrality;
1281/// use crate::rustworkx_core::petgraph::visit::EdgeRef;
1282///
1283/// // Calculate the closeness centrality of Graph
1284/// let g = petgraph::graph::UnGraph::<i32, f64>::from_edges(&[
1285///     (0, 1, 0.7), (1, 2, 0.2), (2, 3, 0.5),
1286/// ]);
1287/// let output = newman_weighted_closeness_centrality(&g, false, |x| *x.weight(), 200);
1288/// assert!(output[1] > output[3]);
1289///
1290/// // Calculate the closeness centrality of DiGraph
1291/// let g = petgraph::graph::DiGraph::<i32, f64>::from_edges(&[
1292///     (0, 1, 0.7), (1, 2, 0.2), (2, 3, 0.5),
1293/// ]);
1294/// let output = newman_weighted_closeness_centrality(&g, false, |x| *x.weight(), 200);
1295/// assert!(output[1] > output[3]);
1296/// ```
1297pub fn newman_weighted_closeness_centrality<G, F>(
1298    graph: G,
1299    wf_improved: bool,
1300    weight_fn: F,
1301    parallel_threshold: usize,
1302) -> Vec<Option<f64>>
1303where
1304    G: NodeIndexable
1305        + IntoNodeIdentifiers
1306        + GraphBase
1307        + IntoEdges
1308        + Visitable
1309        + NodeCount
1310        + IntoEdgesDirected
1311        + std::marker::Sync,
1312    G::NodeId: Eq + Hash + Send,
1313    G::EdgeId: Eq + Hash + Send,
1314    F: Fn(ReversedEdgeReference<<G as IntoEdgeReferences>::EdgeRef>) -> f64 + Sync,
1315{
1316    // The edges originally represent `connection strength` between nodes.
1317    // As shown in the paper, the weight of the edges should be inverted to
1318    // ensure that stronger ties correspond to shorter effective distances.
1319    // The idea is that if two nodes have a strong connection, the computed
1320    // distance between them should be small (shorter), and vice versa.
1321    //
1322    // Note that this assume that the graph is modelling a measure of
1323    // connection strength (e.g. trust, collaboration, or similarity).
1324    // If the graph is not modelling a measure of connection strength,
1325    // the user should invert the weights before calling this function,
1326    // if not it is considered as a logical error.
1327    let inverted_weight_fn =
1328        |x: ReversedEdgeReference<<G as IntoEdgeReferences>::EdgeRef>| 1.0 / weight_fn(x);
1329
1330    let max_index = graph.node_bound();
1331    let mut node_indices: Vec<Option<G::NodeId>> = vec![None; max_index];
1332    graph.node_identifiers().for_each(|node| {
1333        let index = graph.to_index(node);
1334        node_indices[index] = Some(node);
1335    });
1336
1337    let closeness: Vec<Option<f64>> =
1338        CondIterator::new(node_indices, graph.node_count() >= parallel_threshold)
1339            .map(|node_s| {
1340                let node_s = node_s?;
1341                let map = dijkstra(Reversed(&graph), node_s, None, &inverted_weight_fn);
1342                let reachable_nodes_count = map.len();
1343                let dists_sum: f64 = map.into_values().sum();
1344                if reachable_nodes_count == 1 {
1345                    return Some(0.0);
1346                }
1347                let mut centrality_s = Some((reachable_nodes_count - 1) as f64 / dists_sum as f64);
1348                if wf_improved {
1349                    let node_count = graph.node_count();
1350                    centrality_s = centrality_s
1351                        .map(|c| c * (reachable_nodes_count - 1) as f64 / (node_count - 1) as f64);
1352                }
1353                centrality_s
1354            })
1355            .collect();
1356    closeness
1357}
1358
1359#[cfg(test)]
1360mod test_newman_weighted_closeness_centrality {
1361    use crate::centrality::closeness_centrality;
1362
1363    use super::newman_weighted_closeness_centrality;
1364    use petgraph::visit::EdgeRef;
1365
1366    macro_rules! assert_almost_equal {
1367        ($x:expr, $y:expr, $d:expr) => {
1368            if ($x - $y).abs() >= $d {
1369                panic!("{} != {} within delta of {}", $x, $y, $d);
1370            }
1371        };
1372    }
1373
1374    macro_rules! assert_almost_equal_iter {
1375        ($expected:expr, $computed:expr, $tolerance:expr) => {
1376            for (&expected, &computed) in $expected.iter().zip($computed.iter()) {
1377                assert_almost_equal!(expected.unwrap(), computed.unwrap(), $tolerance);
1378            }
1379        };
1380    }
1381
1382    #[test]
1383    fn test_weighted_closeness_graph() {
1384        let test_case = |parallel_threshold: usize| {
1385            let g = petgraph::graph::UnGraph::<u32, f64>::from_edges([
1386                (0, 1, 1.0),
1387                (1, 2, 1.0),
1388                (2, 3, 1.0),
1389                (3, 4, 1.0),
1390                (4, 5, 1.0),
1391                (5, 6, 1.0),
1392            ]);
1393            let classic_closeness = closeness_centrality(&g, false, parallel_threshold);
1394            let weighted_closeness = newman_weighted_closeness_centrality(
1395                &g,
1396                false,
1397                |x| *x.weight(),
1398                parallel_threshold,
1399            );
1400
1401            assert_eq!(classic_closeness, weighted_closeness);
1402        };
1403        test_case(200); // sequential
1404        test_case(1); // parallel
1405    }
1406
1407    #[test]
1408    fn test_the_same_as_closeness_centrality_when_weights_are_1_not_improved_digraph() {
1409        let test_case = |parallel_threshold: usize| {
1410            let g = petgraph::graph::DiGraph::<u32, f64>::from_edges([
1411                (0, 1, 1.0),
1412                (1, 2, 1.0),
1413                (2, 3, 1.0),
1414                (3, 4, 1.0),
1415                (4, 5, 1.0),
1416                (5, 6, 1.0),
1417            ]);
1418            let classic_closeness = closeness_centrality(&g, false, parallel_threshold);
1419            let weighted_closeness = newman_weighted_closeness_centrality(
1420                &g,
1421                false,
1422                |x| *x.weight(),
1423                parallel_threshold,
1424            );
1425
1426            assert_eq!(classic_closeness, weighted_closeness);
1427        };
1428        test_case(200); // sequential
1429        test_case(1); // parallel
1430    }
1431
1432    #[test]
1433    fn test_the_same_as_closeness_centrality_when_weights_are_1_improved_digraph() {
1434        let test_case = |parallel_threshold: usize| {
1435            let g = petgraph::graph::DiGraph::<u32, f64>::from_edges([
1436                (0, 1, 1.0),
1437                (1, 2, 1.0),
1438                (2, 3, 1.0),
1439                (3, 4, 1.0),
1440                (4, 5, 1.0),
1441                (5, 6, 1.0),
1442            ]);
1443            let classic_closeness = closeness_centrality(&g, true, parallel_threshold);
1444            let weighted_closeness =
1445                newman_weighted_closeness_centrality(&g, true, |x| *x.weight(), parallel_threshold);
1446
1447            assert_eq!(classic_closeness, weighted_closeness);
1448        };
1449        test_case(200); // sequential
1450        test_case(1); // parallel
1451    }
1452
1453    #[test]
1454    fn test_weighted_closeness_two_connected_components_not_improved_digraph() {
1455        let test_case = |parallel_threshold: usize| {
1456            let g = petgraph::graph::DiGraph::<u32, f64>::from_edges([
1457                (0, 1, 1.0),
1458                (1, 2, 0.5),
1459                (2, 3, 0.25),
1460                (4, 5, 1.0),
1461                (5, 6, 0.5),
1462                (6, 7, 0.25),
1463            ]);
1464            let c = newman_weighted_closeness_centrality(
1465                &g,
1466                false,
1467                |x| *x.weight(),
1468                parallel_threshold,
1469            );
1470            let result = [
1471                Some(0.0),
1472                Some(1.0),
1473                Some(0.4),
1474                Some(0.176470),
1475                Some(0.0),
1476                Some(1.0),
1477                Some(0.4),
1478                Some(0.176470),
1479            ];
1480
1481            assert_almost_equal_iter!(result, c, 1e-4);
1482        };
1483        test_case(200); // sequential
1484        test_case(1); // parallel
1485    }
1486
1487    #[test]
1488    fn test_weighted_closeness_two_connected_components_improved_digraph() {
1489        let test_case = |parallel_threshold: usize| {
1490            let g = petgraph::graph::DiGraph::<u32, f64>::from_edges([
1491                (0, 1, 1.0),
1492                (1, 2, 0.5),
1493                (2, 3, 0.25),
1494                (4, 5, 1.0),
1495                (5, 6, 0.5),
1496                (6, 7, 0.25),
1497            ]);
1498            let c =
1499                newman_weighted_closeness_centrality(&g, true, |x| *x.weight(), parallel_threshold);
1500            let result = [
1501                Some(0.0),
1502                Some(0.14285714),
1503                Some(0.11428571),
1504                Some(0.07563025),
1505                Some(0.0),
1506                Some(0.14285714),
1507                Some(0.11428571),
1508                Some(0.07563025),
1509            ];
1510
1511            assert_almost_equal_iter!(result, c, 1e-4);
1512        };
1513        test_case(200); // sequential
1514        test_case(1); // parallel
1515    }
1516
1517    #[test]
1518    fn test_weighted_closeness_two_connected_components_improved_different_cardinality_digraph() {
1519        let test_case = |parallel_threshold: usize| {
1520            let g = petgraph::graph::DiGraph::<u32, f64>::from_edges([
1521                (0, 1, 1.0),
1522                (1, 2, 0.5),
1523                (2, 3, 0.25),
1524                (4, 5, 1.0),
1525                (5, 6, 0.5),
1526                (6, 7, 0.25),
1527                (7, 8, 0.125),
1528            ]);
1529            let c =
1530                newman_weighted_closeness_centrality(&g, true, |x| *x.weight(), parallel_threshold);
1531            let result = [
1532                Some(0.0),
1533                Some(0.125),
1534                Some(0.1),
1535                Some(0.06617647),
1536                Some(0.0),
1537                Some(0.125),
1538                Some(0.1),
1539                Some(0.06617647),
1540                Some(0.04081632),
1541            ];
1542
1543            assert_almost_equal_iter!(result, c, 1e-4);
1544        };
1545        test_case(200); // sequential
1546        test_case(1); // parallel
1547    }
1548
1549    #[test]
1550    fn test_weighted_closeness_small_ungraph() {
1551        let test_case = |parallel_threshold: usize| {
1552            let g = petgraph::graph::UnGraph::<u32, f64>::from_edges([
1553                (0, 1, 0.7),
1554                (1, 2, 0.2),
1555                (2, 3, 0.5),
1556            ]);
1557            let c = newman_weighted_closeness_centrality(
1558                &g,
1559                false,
1560                |x| *x.weight(),
1561                parallel_threshold,
1562            );
1563            let result = [
1564                Some(0.1842105),
1565                Some(0.2234042),
1566                Some(0.2234042),
1567                Some(0.1721311),
1568            ];
1569
1570            assert_almost_equal_iter!(result, c, 1e-4);
1571        };
1572        test_case(200); // sequential
1573        test_case(1); // parallel
1574    }
1575    #[test]
1576    fn test_weighted_closeness_small_digraph() {
1577        let test_case = |parallel_threshold: usize| {
1578            let g = petgraph::graph::DiGraph::<u32, f64>::from_edges([
1579                (0, 1, 0.7),
1580                (1, 2, 0.2),
1581                (2, 3, 0.5),
1582            ]);
1583            let c = newman_weighted_closeness_centrality(
1584                &g,
1585                false,
1586                |x| *x.weight(),
1587                parallel_threshold,
1588            );
1589            let result = [Some(0.0), Some(0.7), Some(0.175), Some(0.172131)];
1590
1591            assert_almost_equal_iter!(result, c, 1e-4);
1592        };
1593        test_case(200); // sequential
1594        test_case(1); // parallel
1595    }
1596
1597    #[test]
1598    fn test_weighted_closeness_many_to_one_connected_digraph() {
1599        let test_case = |parallel_threshold: usize| {
1600            let g = petgraph::graph::DiGraph::<u32, f64>::from_edges([
1601                (1, 0, 0.1),
1602                (2, 0, 0.1),
1603                (3, 0, 0.1),
1604                (4, 0, 0.1),
1605                (5, 0, 0.1),
1606                (6, 0, 0.1),
1607                (7, 0, 0.1),
1608                (0, 8, 1.0),
1609            ]);
1610            let c = newman_weighted_closeness_centrality(
1611                &g,
1612                false,
1613                |x| *x.weight(),
1614                parallel_threshold,
1615            );
1616            let result = [
1617                Some(0.1),
1618                Some(0.0),
1619                Some(0.0),
1620                Some(0.0),
1621                Some(0.0),
1622                Some(0.0),
1623                Some(0.0),
1624                Some(0.0),
1625                Some(0.10256),
1626            ];
1627
1628            assert_almost_equal_iter!(result, c, 1e-4);
1629        };
1630        test_case(200); // sequential
1631        test_case(1); // parallel
1632    }
1633
1634    #[test]
1635    fn test_weighted_closeness_many_to_one_connected_ungraph() {
1636        let test_case = |parallel_threshold: usize| {
1637            let g = petgraph::graph::UnGraph::<u32, f64>::from_edges([
1638                (1, 0, 0.1),
1639                (2, 0, 0.1),
1640                (3, 0, 0.1),
1641                (4, 0, 0.1),
1642                (5, 0, 0.1),
1643                (6, 0, 0.1),
1644                (7, 0, 0.1),
1645                (0, 8, 1.0),
1646            ]);
1647            let c = newman_weighted_closeness_centrality(
1648                &g,
1649                false,
1650                |x| *x.weight(),
1651                parallel_threshold,
1652            );
1653            let result = [
1654                Some(0.112676056),
1655                Some(0.056737588),
1656                Some(0.056737588),
1657                Some(0.056737588),
1658                Some(0.056737588),
1659                Some(0.056737588),
1660                Some(0.056737588),
1661                Some(0.056737588),
1662                Some(0.102564102),
1663            ];
1664
1665            assert_almost_equal_iter!(result, c, 1e-4);
1666        };
1667        test_case(200); // sequential
1668        test_case(1); // parallel
1669    }
1670
1671    #[test]
1672    fn test_weighted_closeness_many_to_one_not_connected_2_digraph() {
1673        let test_case = |parallel_threshold: usize| {
1674            let g = petgraph::graph::DiGraph::<u32, f64>::from_edges([
1675                (1, 0, 0.1),
1676                (2, 0, 0.1),
1677                (3, 0, 0.1),
1678                (4, 0, 0.1),
1679                (5, 0, 0.1),
1680                (6, 0, 0.1),
1681                (7, 0, 0.1),
1682                (1, 7, 1.0),
1683            ]);
1684            let c = newman_weighted_closeness_centrality(
1685                &g,
1686                false,
1687                |x| *x.weight(),
1688                parallel_threshold,
1689            );
1690            let result = [
1691                Some(0.1),
1692                Some(0.0),
1693                Some(0.0),
1694                Some(0.0),
1695                Some(0.0),
1696                Some(0.0),
1697                Some(0.0),
1698                Some(1.0),
1699            ];
1700
1701            assert_eq!(result, *c);
1702        };
1703        test_case(200); // sequential
1704        test_case(1); // parallel
1705    }
1706
1707    #[test]
1708    fn test_weighted_closeness_many_to_one_not_connected_1_digraph() {
1709        let test_case = |parallel_threshold: usize| {
1710            let g = petgraph::graph::DiGraph::<u32, f64>::from_edges([
1711                (1, 0, 0.1),
1712                (2, 0, 0.1),
1713                (3, 0, 0.1),
1714                (4, 0, 0.1),
1715                (5, 0, 0.1),
1716                (6, 0, 0.1),
1717                (7, 0, 0.1),
1718                (8, 7, 1.0),
1719            ]);
1720            let c = newman_weighted_closeness_centrality(
1721                &g,
1722                false,
1723                |x| *x.weight(),
1724                parallel_threshold,
1725            );
1726            let result = [
1727                Some(0.098765),
1728                Some(0.0),
1729                Some(0.0),
1730                Some(0.0),
1731                Some(0.0),
1732                Some(0.0),
1733                Some(0.0),
1734                Some(1.0),
1735                Some(0.0),
1736            ];
1737
1738            assert_almost_equal_iter!(result, c, 1e-4);
1739        };
1740        test_case(200); // sequential
1741        test_case(1); // parallel
1742    }
1743}
1744
1745/// Compute the group degree centrality of a set of nodes.
1746///
1747/// Group degree centrality measures the fraction of non-group nodes that are
1748/// connected to at least one member of the group. It is defined as:
1749///
1750/// C_D(S) = |N(S) \ S| / (|V| - |S|)
1751///
1752/// where N(S) is the union of neighborhoods of all nodes in S.
1753///
1754/// Based on: Everett, M. G., & Borgatti, S. P. (1999).
1755/// The centrality of groups and classes.
1756/// Journal of Mathematical Sociology, 23(3), 181-201.
1757///
1758/// Arguments:
1759///
1760/// * `graph` - The graph object to run the algorithm on
1761/// * `group` - A slice of node indices representing the group
1762/// * `direction` - Optional direction for directed graphs:
1763///     - `None` uses outgoing edges (default)
1764///     - `Some(Incoming)` counts nodes with edges into the group
1765///     - `Some(Outgoing)` counts nodes reachable from the group
1766///
1767/// # Example
1768/// ```rust
1769/// use rustworkx_core::petgraph;
1770/// use rustworkx_core::centrality::group_degree_centrality;
1771///
1772/// let g = petgraph::graph::UnGraph::<i32, ()>::from_edges([
1773///     (0, 1), (1, 2), (2, 3), (3, 4)
1774/// ]);
1775/// let output = group_degree_centrality(&g, &[0, 1], None);
1776/// // Nodes 0,1 are the group. Neighbors of {0,1} outside the group = {2}.
1777/// // So centrality = 1 / (5 - 2) = 1/3.
1778/// assert!((output - 1.0 / 3.0).abs() < 1e-10);
1779/// ```
1780pub fn group_degree_centrality<G>(
1781    graph: G,
1782    group: &[usize],
1783    direction: Option<petgraph::Direction>,
1784) -> f64
1785where
1786    G: NodeIndexable + IntoNodeIdentifiers + IntoNeighborsDirected + NodeCount + GraphProp,
1787    G::NodeId: Eq + Hash,
1788{
1789    let node_count = graph.node_count();
1790    let group_set: HashSet<usize> = group.iter().copied().collect();
1791    let group_size = group_set.len();
1792    if group_size >= node_count {
1793        return 0.0;
1794    }
1795
1796    let mut reached: HashSet<usize> = HashSet::new();
1797
1798    for &node_idx in &group_set {
1799        let node_id = graph.from_index(node_idx);
1800        let mut process_neighbor = |neighbor: G::NodeId| {
1801            let neighbor_idx = graph.to_index(neighbor);
1802            if !group_set.contains(&neighbor_idx) {
1803                reached.insert(neighbor_idx);
1804            }
1805        };
1806        match direction {
1807            Some(dir) => {
1808                for neighbor in graph.neighbors_directed(node_id, dir) {
1809                    process_neighbor(neighbor);
1810                }
1811            }
1812            None => {
1813                for neighbor in graph.neighbors(node_id) {
1814                    process_neighbor(neighbor);
1815                }
1816            }
1817        }
1818    }
1819
1820    reached.len() as f64 / (node_count - group_size) as f64
1821}
1822
1823/// Compute the group closeness centrality of a set of nodes.
1824///
1825/// Group closeness centrality measures how close a group of nodes is to
1826/// all non-group nodes. It is defined as:
1827///
1828/// C_close(S) = |V \ S| / sum_{v in V\S} d(S, v)
1829///
1830/// where d(S, v) = min_{u in S} d(u, v) is the minimum distance from any
1831/// group member to node v.
1832///
1833/// Note: For disconnected graphs, unreachable nodes do not contribute to
1834/// the distance sum but are still counted in |V \ S|. This can produce
1835/// values greater than 1.0, matching the convention used by NetworkX.
1836///
1837/// Based on: Everett, M. G., & Borgatti, S. P. (1999).
1838/// The centrality of groups and classes.
1839/// Journal of Mathematical Sociology, 23(3), 181-201.
1840///
1841/// Arguments:
1842///
1843/// * `graph` - The graph object to run the algorithm on
1844/// * `group` - A slice of node indices representing the group
1845///
1846/// # Example
1847/// ```rust
1848/// use rustworkx_core::petgraph;
1849/// use rustworkx_core::centrality::group_closeness_centrality;
1850///
1851/// let g = petgraph::graph::UnGraph::<i32, ()>::from_edges([
1852///     (0, 1), (1, 2), (2, 3), (3, 4)
1853/// ]);
1854/// let output = group_closeness_centrality(&g, &[0, 1]);
1855/// // Group = {0, 1}. Non-group = {2, 3, 4}.
1856/// // d({0,1}, 2) = 1, d({0,1}, 3) = 2, d({0,1}, 4) = 3. Sum = 6.
1857/// // Closeness = 3 / 6 = 0.5
1858/// assert!((output - 0.5).abs() < 1e-10);
1859/// ```
1860pub fn group_closeness_centrality<G>(graph: G, group: &[usize]) -> f64
1861where
1862    G: NodeIndexable + IntoNodeIdentifiers + GraphBase + IntoEdges + IntoEdgesDirected + NodeCount,
1863    G::NodeId: Eq + Hash,
1864    G::EdgeId: Eq,
1865{
1866    let node_count = graph.node_count();
1867    let group_set: HashSet<usize> = group.iter().copied().collect();
1868    let group_size = group_set.len();
1869    if group_size >= node_count {
1870        return 0.0;
1871    }
1872
1873    // Multi-source BFS on Reversed graph (incoming edges), matching the
1874    // convention used by NX and by per-node closeness_centrality: d(S,v)
1875    // is the distance from v to the nearest group member.
1876    let reversed = Reversed(&graph);
1877    let max_index = graph.node_bound();
1878    let mut distance: Vec<Option<usize>> = vec![None; max_index];
1879    let mut queue: VecDeque<G::NodeId> = VecDeque::new();
1880
1881    for &node_idx in &group_set {
1882        let node_id = graph.from_index(node_idx);
1883        distance[node_idx] = Some(0);
1884        queue.push_back(node_id);
1885    }
1886
1887    while let Some(v) = queue.pop_front() {
1888        let v_idx = graph.to_index(v);
1889        let dist_v = distance[v_idx].unwrap();
1890        for edge in reversed.edges(v) {
1891            let w = edge.target();
1892            let w_idx = graph.to_index(w);
1893            if distance[w_idx].is_none() {
1894                distance[w_idx] = Some(dist_v + 1);
1895                queue.push_back(w);
1896            }
1897        }
1898    }
1899
1900    let mut dist_sum: usize = 0;
1901    for node in graph.node_identifiers() {
1902        let idx = graph.to_index(node);
1903        if group_set.contains(&idx) {
1904            continue;
1905        }
1906        if let Some(d) = distance[idx] {
1907            dist_sum += d;
1908        }
1909    }
1910
1911    if dist_sum == 0 {
1912        return 0.0;
1913    }
1914
1915    (node_count - group_size) as f64 / dist_sum as f64
1916}
1917
1918/// Compute the group betweenness centrality of a set of nodes.
1919///
1920/// Group betweenness centrality measures the fraction of shortest paths
1921/// between non-group node pairs that pass through at least one group member.
1922/// It is defined as:
1923///
1924/// C_B(S) = sum_{s,t in V\S} sigma(s,t|S) / sigma(s,t)
1925///
1926/// where sigma(s,t) is the number of shortest paths from s to t, and
1927/// sigma(s,t|S) is the number of those paths passing through at least
1928/// one node in S.
1929///
1930/// Based on: Everett, M. G., & Borgatti, S. P. (1999).
1931/// The centrality of groups and classes.
1932/// Journal of Mathematical Sociology, 23(3), 181-201.
1933///
1934/// Arguments:
1935///
1936/// * `graph` - The graph object to run the algorithm on
1937/// * `group` - A slice of node indices representing the group
1938/// * `normalized` - Whether to normalize the result
1939/// * `parallel_threshold` - The number of nodes to calculate the group
1940///   betweenness centrality in parallel at, if the number of nodes in `graph`
1941///   is less than this value it will run in a single thread. A good default
1942///   to use here if you're not sure is `50` as that was found to be roughly
1943///   the number of nodes where parallelism improves performance for the
1944///   standard betweenness centrality function.
1945///
1946/// This function uses multiple threads for per-source shortest path searches
1947/// when the graph has at least `parallel_threshold` nodes. If the function
1948/// will be running in parallel the env var `RAYON_NUM_THREADS` can be used to
1949/// adjust how many threads will be used.
1950///
1951/// # Example
1952/// ```rust
1953/// use rustworkx_core::petgraph;
1954/// use rustworkx_core::centrality::group_betweenness_centrality;
1955///
1956/// let g = petgraph::graph::UnGraph::<i32, ()>::from_edges([
1957///     (0, 1), (1, 2), (2, 3), (3, 4)
1958/// ]);
1959/// let output = group_betweenness_centrality(&g, &[2], true, 50);
1960/// // Node 2 is on every shortest path between {0,1} and {3,4}.
1961/// assert!(output > 0.0);
1962/// ```
1963pub fn group_betweenness_centrality<G>(
1964    graph: G,
1965    group: &[usize],
1966    normalized: bool,
1967    parallel_threshold: usize,
1968) -> f64
1969where
1970    G: NodeIndexable
1971        + IntoNodeIdentifiers
1972        + IntoNeighborsDirected
1973        + NodeCount
1974        + GraphProp
1975        + GraphBase
1976        + Sync,
1977    G::NodeId: Eq + Hash,
1978{
1979    let node_count = graph.node_count();
1980    let group_set: HashSet<usize> = group.iter().copied().collect();
1981    let group_size = group_set.len();
1982
1983    if group_size == 0 || node_count <= 1 {
1984        return 0.0;
1985    }
1986
1987    let max_index = graph.node_bound();
1988
1989    // For each non-group source, run BFS on the full graph and on the graph
1990    // with group nodes removed. The difference in path counts gives us the
1991    // fraction of shortest paths passing through the group.
1992    let node_indices: Vec<usize> = graph
1993        .node_identifiers()
1994        .map(|node| graph.to_index(node))
1995        .collect();
1996
1997    let mut group_betweenness: f64 =
1998        CondIterator::new(node_indices.clone(), node_count >= parallel_threshold)
1999            .filter_map(|source_idx| {
2000                if group_set.contains(&source_idx) {
2001                    return None;
2002                }
2003                let source_id = graph.from_index(source_idx);
2004
2005                // BFS on full graph from source
2006                let mut sigma_full = vec![0.0_f64; max_index];
2007                let mut dist_full: Vec<Option<usize>> = vec![None; max_index];
2008                let mut queue: VecDeque<G::NodeId> = VecDeque::new();
2009
2010                sigma_full[source_idx] = 1.0;
2011                dist_full[source_idx] = Some(0);
2012                queue.push_back(source_id);
2013
2014                while let Some(v) = queue.pop_front() {
2015                    let v_idx = graph.to_index(v);
2016                    let dist_v = dist_full[v_idx].unwrap();
2017                    for w in graph.neighbors(v) {
2018                        let w_idx = graph.to_index(w);
2019                        if dist_full[w_idx].is_none() {
2020                            dist_full[w_idx] = Some(dist_v + 1);
2021                            queue.push_back(w);
2022                        }
2023                        if dist_full[w_idx] == Some(dist_v + 1) {
2024                            sigma_full[w_idx] += sigma_full[v_idx];
2025                        }
2026                    }
2027                }
2028
2029                // BFS on graph with group nodes removed
2030                let mut sigma_no_group = vec![0.0_f64; max_index];
2031                let mut dist_no_group: Vec<Option<usize>> = vec![None; max_index];
2032                let mut queue2: VecDeque<G::NodeId> = VecDeque::new();
2033
2034                sigma_no_group[source_idx] = 1.0;
2035                dist_no_group[source_idx] = Some(0);
2036                queue2.push_back(source_id);
2037
2038                while let Some(v) = queue2.pop_front() {
2039                    let v_idx = graph.to_index(v);
2040                    let dist_v = dist_no_group[v_idx].unwrap();
2041                    for w in graph.neighbors(v) {
2042                        let w_idx = graph.to_index(w);
2043                        if group_set.contains(&w_idx) {
2044                            continue;
2045                        }
2046                        if dist_no_group[w_idx].is_none() {
2047                            dist_no_group[w_idx] = Some(dist_v + 1);
2048                            queue2.push_back(w);
2049                        }
2050                        if dist_no_group[w_idx] == Some(dist_v + 1) {
2051                            sigma_no_group[w_idx] += sigma_no_group[v_idx];
2052                        }
2053                    }
2054                }
2055
2056                // For each non-group target, accumulate the fraction of shortest paths
2057                // that pass through at least one group member.
2058                let mut source_group_betweenness = 0.0;
2059                for &target_idx in &node_indices {
2060                    if target_idx == source_idx || group_set.contains(&target_idx) {
2061                        continue;
2062                    }
2063                    if sigma_full[target_idx] == 0.0 {
2064                        continue;
2065                    }
2066
2067                    // Paths through group = total - paths avoiding group,
2068                    // but only if the shortest path length is the same. If it differs,
2069                    // none of the shortest paths avoid the group.
2070                    let paths_avoiding = if dist_no_group[target_idx] == dist_full[target_idx] {
2071                        sigma_no_group[target_idx]
2072                    } else {
2073                        0.0
2074                    };
2075
2076                    let fraction_through_group =
2077                        (sigma_full[target_idx] - paths_avoiding) / sigma_full[target_idx];
2078                    source_group_betweenness += fraction_through_group;
2079                }
2080                Some(source_group_betweenness)
2081            })
2082            .sum();
2083
2084    if !graph.is_directed() {
2085        group_betweenness /= 2.0;
2086    }
2087
2088    if normalized {
2089        let non_group = node_count - group_size;
2090        if non_group > 1 {
2091            let norm = if graph.is_directed() {
2092                (non_group * (non_group - 1)) as f64
2093            } else {
2094                ((non_group * (non_group - 1)) / 2) as f64
2095            };
2096            group_betweenness /= norm;
2097        }
2098    }
2099
2100    group_betweenness
2101}
2102
#[cfg(test)]
mod test_group_degree_centrality {
    use crate::centrality::group_degree_centrality;
    use crate::petgraph;

    #[test]
    fn test_undirected_path() {
        let g = petgraph::graph::UnGraph::<(), ()>::from_edges([(0, 1), (1, 2), (2, 3), (3, 4)]);
        let result = group_degree_centrality(&g, &[0, 1], None);
        // Neighbors of {0,1} outside group = {2}. Centrality = 1/3.
        assert!((result - 1.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_undirected_complete() {
        let g = petgraph::graph::UnGraph::<(), ()>::from_edges([
            (0, 1),
            (0, 2),
            (0, 3),
            (1, 2),
            (1, 3),
            (2, 3),
        ]);
        let result = group_degree_centrality(&g, &[0], None);
        assert!((result - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_directed_out() {
        let g = petgraph::graph::DiGraph::<(), ()>::from_edges([(0, 1), (1, 2), (2, 3)]);
        let result = group_degree_centrality(&g, &[0, 1], Some(petgraph::Direction::Outgoing));
        // Out-neighbors of {0,1} outside group = {2}. Centrality = 1/2.
        assert!((result - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_directed_in() {
        let g = petgraph::graph::DiGraph::<(), ()>::from_edges([(0, 1), (1, 2), (2, 3)]);
        let result = group_degree_centrality(&g, &[2, 3], Some(petgraph::Direction::Incoming));
        // In-neighbors of {2,3} outside group = {1}. Centrality = 1/2.
        assert!((result - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_directed_default_direction_is_outgoing() {
        let g = petgraph::graph::DiGraph::<(), ()>::from_edges([(0, 1), (1, 2), (2, 3)]);
        let result = group_degree_centrality(&g, &[1], None);
        // Out-neighbors of {1} = {2}; node 0 only points into the group.
        // Centrality = 1/3.
        assert!((result - 1.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_duplicate_group_nodes_are_counted_once() {
        let g = petgraph::graph::UnGraph::<(), ()>::from_edges([(0, 1), (1, 2)]);
        let result = group_degree_centrality(&g, &[1, 1], None);
        assert!((result - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_group_covering_every_node() {
        let g = petgraph::graph::UnGraph::<(), ()>::from_edges([(0, 1), (1, 2)]);
        let result = group_degree_centrality(&g, &[0, 1, 2], None);
        assert!(result.abs() < 1e-10);
    }
}
2153
#[cfg(test)]
mod test_group_closeness_centrality {
    use crate::centrality::group_closeness_centrality;
    use crate::petgraph;

    macro_rules! assert_almost_equal {
        ($x:expr, $y:expr, $d:expr) => {
            if ($x - $y).abs() >= $d {
                panic!("{} != {} within delta of {}", $x, $y, $d);
            }
        };
    }

    #[test]
    fn test_undirected_path() {
        let g = petgraph::graph::UnGraph::<(), ()>::from_edges([(0, 1), (1, 2), (2, 3), (3, 4)]);
        let result = group_closeness_centrality(&g, &[0, 1]);
        // Non-group = {2,3,4}. d(S,2)=1, d(S,3)=2, d(S,4)=3. Sum=6.
        // Closeness = 3/6 = 0.5
        assert_almost_equal!(result, 0.5, 1e-10);
    }

    #[test]
    fn test_undirected_center_node() {
        let g = petgraph::graph::UnGraph::<(), ()>::from_edges([(0, 2), (1, 2), (2, 3), (2, 4)]);
        let result = group_closeness_centrality(&g, &[2]);
        // Non-group = {0,1,3,4}. All at distance 1. Sum=4.
        // Closeness = 4/4 = 1.0
        assert_almost_equal!(result, 1.0, 1e-10);
    }

    #[test]
    fn test_directed_path_measures_distance_to_group() {
        // 0 -> 1 -> 2 -> 3
        let g = petgraph::graph::DiGraph::<(), ()>::from_edges([(0, 1), (1, 2), (2, 3)]);
        // Every node reaches 3: d(0,S)=3, d(1,S)=2, d(2,S)=1. Sum=6.
        // Closeness = 3/6 = 0.5
        let result = group_closeness_centrality(&g, &[3]);
        assert_almost_equal!(result, 0.5, 1e-10);
        // No node can reach 0 along edge direction, so the distance sum is 0.
        let result = group_closeness_centrality(&g, &[0]);
        assert_almost_equal!(result, 0.0, 1e-10);
    }

    #[test]
    fn test_disconnected() {
        // Two disconnected components
        let mut g = petgraph::graph::UnGraph::<(), ()>::new_undirected();
        g.add_node(());
        g.add_node(());
        g.add_node(());
        g.add_edge(
            petgraph::graph::NodeIndex::new(0),
            petgraph::graph::NodeIndex::new(1),
            (),
        );
        // Node 2 is disconnected
        let result = group_closeness_centrality(&g, &[0]);
        // |V-S|=2, only node 1 reachable at distance 1. Node 2 unreachable.
        // dist_sum=1. closeness = 2/1 = 2.0
        assert_almost_equal!(result, 2.0, 1e-10);
    }

    #[test]
    fn test_duplicate_group_nodes_are_counted_once() {
        let g = petgraph::graph::UnGraph::<(), ()>::from_edges([(0, 1), (1, 2)]);
        let result = group_closeness_centrality(&g, &[1, 1]);
        assert_almost_equal!(result, 1.0, 1e-10);
    }

    #[test]
    fn test_group_covering_every_node() {
        let g = petgraph::graph::UnGraph::<(), ()>::from_edges([(0, 1), (1, 2)]);
        let result = group_closeness_centrality(&g, &[0, 1, 2]);
        assert_almost_equal!(result, 0.0, 1e-10);
    }
}
2211
#[cfg(test)]
mod test_group_betweenness_centrality {
    use crate::centrality::{betweenness_centrality, group_betweenness_centrality};
    use crate::petgraph;

    macro_rules! assert_almost_equal {
        ($x:expr, $y:expr, $d:expr) => {
            if ($x - $y).abs() >= $d {
                panic!("{} != {} within delta of {}", $x, $y, $d);
            }
        };
    }

    #[test]
    fn test_undirected_path_center() {
        // Path: 0-1-2-3-4
        let g = petgraph::graph::UnGraph::<(), ()>::from_edges([(0, 1), (1, 2), (2, 3), (3, 4)]);
        // Group = {2}. Node 2 is on all shortest paths between {0,1} and {3,4}.
        let result = group_betweenness_centrality(&g, &[2], false, 200);
        // Pairs through node 2: (0,3), (0,4), (1,3), (1,4) = 4 paths
        assert_almost_equal!(result, 4.0, 1e-10);
    }

    #[test]
    fn test_undirected_path_center_parallel() {
        // Path: 0-1-2-3-4
        let g = petgraph::graph::UnGraph::<(), ()>::from_edges([(0, 1), (1, 2), (2, 3), (3, 4)]);
        // Group = {2}. Node 2 is on all shortest paths between {0,1} and {3,4}.
        let result = group_betweenness_centrality(&g, &[2], false, 1);
        // Pairs through node 2: (0,3), (0,4), (1,3), (1,4) = 4 paths
        assert_almost_equal!(result, 4.0, 1e-10);
    }

    #[test]
    fn test_undirected_path_center_normalized() {
        let g = petgraph::graph::UnGraph::<(), ()>::from_edges([(0, 1), (1, 2), (2, 3), (3, 4)]);
        let result = group_betweenness_centrality(&g, &[2], true, 200);
        // Non-group size = 4. Normalization = C(4,2) = 6.
        // Normalized = 4/6 = 2/3
        assert_almost_equal!(result, 2.0 / 3.0, 1e-10);
    }

    #[test]
    fn test_undirected_cycle_splits_paths() {
        // 4-cycle: 0-1-2-3-0. The pair (0, 2) has two shortest paths,
        // 0-1-2 and 0-3-2, and only one of them passes through node 1.
        // Pairs (0, 3) and (2, 3) are adjacent and avoid node 1.
        let g = petgraph::graph::UnGraph::<(), ()>::from_edges([(0, 1), (1, 2), (2, 3), (3, 0)]);
        let result = group_betweenness_centrality(&g, &[1], false, 200);
        assert_almost_equal!(result, 0.5, 1e-10);
        let result = group_betweenness_centrality(&g, &[1], false, 1);
        assert_almost_equal!(result, 0.5, 1e-10);
    }

    #[test]
    fn test_empty_group() {
        let g = petgraph::graph::UnGraph::<(), ()>::from_edges([(0, 1), (1, 2)]);
        let result = group_betweenness_centrality(&g, &[], false, 200);
        assert_almost_equal!(result, 0.0, 1e-10);
    }

    #[test]
    fn test_group_covering_every_node() {
        let g = petgraph::graph::UnGraph::<(), ()>::from_edges([(0, 1), (1, 2)]);
        let result = group_betweenness_centrality(&g, &[0, 1, 2], true, 200);
        assert_almost_equal!(result, 0.0, 1e-10);
    }

    #[test]
    fn test_single_node_group() {
        // Star graph: center=0, leaves=1,2,3,4
        let g = petgraph::graph::UnGraph::<(), ()>::from_edges([(0, 1), (0, 2), (0, 3), (0, 4)]);
        let result = group_betweenness_centrality(&g, &[0], false, 200);
        // Node 0 is on all 6 shortest paths between leaf pairs
        assert_almost_equal!(result, 6.0, 1e-10);
    }

    #[test]
    fn test_directed_path() {
        // Directed path: 0->1->2->3->4
        let g = petgraph::graph::DiGraph::<(), ()>::from_edges([(0, 1), (1, 2), (2, 3), (3, 4)]);
        // Group = {2}. Directed shortest paths through node 2:
        // (0,3), (0,4), (1,3), (1,4) = 4 pairs
        let result = group_betweenness_centrality(&g, &[2], false, 200);
        assert_almost_equal!(result, 4.0, 1e-10);
    }

    #[test]
    fn test_directed_path_normalized() {
        let g = petgraph::graph::DiGraph::<(), ()>::from_edges([(0, 1), (1, 2), (2, 3), (3, 4)]);
        let result = group_betweenness_centrality(&g, &[2], true, 200);
        // Non-group = 4 nodes. Directed normalization = 4 * 3 = 12.
        // Normalized = 4 / 12 = 1/3
        assert_almost_equal!(result, 1.0 / 3.0, 1e-10);
    }

    #[test]
    fn test_directed_star() {
        // Directed star: 0->1, 0->2, 0->3, 0->4
        let g = petgraph::graph::DiGraph::<(), ()>::from_edges([(0, 1), (0, 2), (0, 3), (0, 4)]);
        // Group = {0}. No directed shortest paths between leaf pairs
        // pass through 0 (leaves are unreachable from each other).
        let result = group_betweenness_centrality(&g, &[0], false, 200);
        assert_almost_equal!(result, 0.0, 1e-10);
    }

    #[test]
    fn test_directed_bidirectional_star() {
        // Bidirectional star: edges in both directions between center and leaves
        let g = petgraph::graph::DiGraph::<(), ()>::from_edges([
            (0, 1),
            (1, 0),
            (0, 2),
            (2, 0),
            (0, 3),
            (3, 0),
            (0, 4),
            (4, 0),
        ]);
        // Group = {0}. All 12 directed shortest paths between leaf pairs
        // go through node 0 (e.g. 1->0->2, 2->0->1, etc.).
        // 4 leaves, 4*3 = 12 ordered pairs.
        let result = group_betweenness_centrality(&g, &[0], false, 200);
        assert_almost_equal!(result, 12.0, 1e-10);
    }

    #[test]
    fn test_duplicate_group_nodes_are_counted_once() {
        let g = petgraph::graph::UnGraph::<(), ()>::from_edges([(0, 1), (1, 2), (2, 3), (3, 4)]);
        let result = group_betweenness_centrality(&g, &[2, 2], true, 200);
        assert_almost_equal!(result, 2.0 / 3.0, 1e-10);
    }

    #[test]
    fn test_singleton_group_matches_betweenness_on_directed_path() {
        let g = petgraph::graph::DiGraph::<(), ()>::from_edges([(0, 1), (1, 2), (2, 3)]);
        let per_node = betweenness_centrality(&g, false, false, 200);

        for (node, expected) in per_node.iter().enumerate().take(4) {
            let group_result = group_betweenness_centrality(&g, &[node], false, 200);
            assert_almost_equal!(group_result, expected.unwrap(), 1e-10);
        }
    }

    #[test]
    fn test_singleton_group_matches_normalized_betweenness_on_directed_path() {
        let g = petgraph::graph::DiGraph::<(), ()>::from_edges([(0, 1), (1, 2), (2, 3)]);
        let per_node = betweenness_centrality(&g, false, true, 200);

        for (node, expected) in per_node.iter().enumerate().take(4) {
            let group_result = group_betweenness_centrality(&g, &[node], true, 200);
            assert_almost_equal!(group_result, expected.unwrap(), 1e-10);
        }
    }
}
2347}