webgraph_algo/distances/
hyperball.rs

1/*
2 * SPDX-FileCopyrightText: 2024 Matteo Dell'Acqua
3 * SPDX-FileCopyrightText: 2025 Sebastiano Vigna
4 *
5 * SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
6 */
7
8use anyhow::{bail, ensure, Context, Result};
9use card_est_array::impls::{HyperLogLog, HyperLogLogBuilder, SliceEstimatorArray};
10use card_est_array::traits::{
11    AsSyncArray, EstimationLogic, EstimatorArray, EstimatorArrayMut, EstimatorMut,
12    MergeEstimationLogic, SyncEstimatorArray,
13};
14use crossbeam_utils::CachePadded;
15use dsi_progress_logger::ConcurrentProgressLog;
16use kahan::KahanSum;
17use rayon::prelude::*;
18use std::hash::{BuildHasherDefault, DefaultHasher};
19use std::sync::{atomic::*, Mutex};
20use sux::traits::AtomicBitVecOps;
21use sux::{bits::AtomicBitVec, traits::Succ};
22use sync_cell_slice::{SyncCell, SyncSlice};
23use webgraph::traits::{RandomAccessGraph, SequentialLabeling};
24use webgraph::utils::Granularity;
25
26/// A builder for [`HyperBall`].
27///
28/// After creating a builder with [`HyperBallBuilder::new`] you can configure it
29/// using setters such as [`HyperBallBuilder`] its methods, then call
30/// [`HyperBallBuilder::build`] on it to create a [`HyperBall`] instance.
31pub struct HyperBallBuilder<
32    'a,
33    G1: RandomAccessGraph + Sync,
34    G2: RandomAccessGraph + Sync,
35    D: for<'b> Succ<Input = usize, Output<'b> = usize>,
36    L: MergeEstimationLogic<Item = G1::Label>,
37    A: EstimatorArrayMut<L>,
38> {
39    /// A graph.
40    graph: &'a G1,
41    /// The transpose of `graph`, if any.
42    transpose: Option<&'a G2>,
43    /// The outdegree cumulative function of the graph.
44    cumul_outdegree: &'a D,
45    /// Whether to compute the sum of distances (e.g., for closeness centrality).
46    do_sum_of_dists: bool,
47    /// Whether to compute the sum of inverse distances (e.g., for harmonic centrality).
48    do_sum_of_inv_dists: bool,
49    /// Custom discount functions whose sum should be computed.
50    discount_functions: Vec<Box<dyn Fn(usize) -> f64 + Send + Sync + 'a>>,
51    /// The arc granularity.
52    arc_granularity: usize,
53    /// Integer weights for the nodes, if any.
54    weights: Option<&'a [usize]>,
55    /// A first array of estimators.
56    array_0: A,
57    /// A second array of estimators of the same length and with the same logic of
58    /// `array_0`.
59    array_1: A,
60    _marker: std::marker::PhantomData<L>,
61}
62
63impl<
64        'a,
65        G1: RandomAccessGraph + Sync,
66        G2: RandomAccessGraph + Sync,
67        D: for<'b> Succ<Input = usize, Output<'b> = usize>,
68    >
69    HyperBallBuilder<
70        'a,
71        G1,
72        G2,
73        D,
74        HyperLogLog<G1::Label, BuildHasherDefault<DefaultHasher>, usize>,
75        SliceEstimatorArray<
76            HyperLogLog<G1::Label, BuildHasherDefault<DefaultHasher>, usize>,
77            usize,
78            Box<[usize]>,
79        >,
80    >
81{
82    /// A builder for [`HyperBall`] using a specified [`EstimationLogic`].
83    ///
84    /// # Arguments
85    /// * `graph`: the graph to analyze.
86    /// * `transpose`: optionally, the transpose of `graph`. If [`None`], no
87    ///   systolic iterations will be performed by the resulting [`HyperBall`].
88    /// * `cumul_outdeg`: the outdegree cumulative function of the graph.
89    /// * `log2m`: the base-2 logarithm of the number *m* of register per
90    ///   HyperLogLog cardinality estimator.
91    /// * `weights`: the weights to use. If [`None`] every node is assumed to be
92    ///   of weight equal to 1.
93    pub fn with_hyper_log_log(
94        graph: &'a G1,
95        transposed: Option<&'a G2>,
96        cumul_outdeg: &'a D,
97        log2m: usize,
98        weights: Option<&'a [usize]>,
99    ) -> Result<Self> {
100        let num_elements = if let Some(w) = weights {
101            ensure!(
102                w.len() == graph.num_nodes(),
103                "weights should have length equal to the graph's number of nodes"
104            );
105            w.iter().sum()
106        } else {
107            graph.num_nodes()
108        };
109
110        let logic = HyperLogLogBuilder::new(num_elements)
111            .log_2_num_reg(log2m)
112            .build()
113            .with_context(|| "Could not build HyperLogLog logic")?;
114
115        let array_0 = SliceEstimatorArray::new(logic.clone(), graph.num_nodes());
116        let array_1 = SliceEstimatorArray::new(logic, graph.num_nodes());
117
118        Ok(Self {
119            graph,
120            transpose: transposed,
121            cumul_outdegree: cumul_outdeg,
122            do_sum_of_dists: false,
123            do_sum_of_inv_dists: false,
124            discount_functions: Vec::new(),
125            arc_granularity: Self::DEFAULT_GRANULARITY,
126            weights,
127            array_0,
128            array_1,
129            _marker: std::marker::PhantomData,
130        })
131    }
132}
133
134impl<
135        'a,
136        D: for<'b> Succ<Input = usize, Output<'b> = usize>,
137        G: RandomAccessGraph + Sync,
138        L: MergeEstimationLogic<Item = G::Label> + PartialEq,
139        A: EstimatorArrayMut<L>,
140    > HyperBallBuilder<'a, G, G, D, L, A>
141{
142    /// Creates a new builder with default parameters.
143    ///
144    /// # Arguments
145    /// * `graph`: the graph to analyze.
146    /// * `cumul_outdeg`: the outdegree cumulative function of the graph.
147    /// * `array_0`: a first array of estimators.
148    /// * `array_1`: A second array of estimators of the same length and with the same logic of
149    ///   `array_0`.
150    pub fn new(graph: &'a G, cumul_outdeg: &'a D, array_0: A, array_1: A) -> Self {
151        assert!(array_0.logic() == array_1.logic(), "Incompatible logic");
152        assert_eq!(
153            graph.num_nodes(),
154            array_0.len(),
155            "array_0 should have length {}. Got {}",
156            graph.num_nodes(),
157            array_0.len()
158        );
159        assert_eq!(
160            graph.num_nodes(),
161            array_1.len(),
162            "array_1 should have length {}. Got {}",
163            graph.num_nodes(),
164            array_1.len()
165        );
166        Self {
167            graph,
168            transpose: None,
169            cumul_outdegree: cumul_outdeg,
170            do_sum_of_dists: false,
171            do_sum_of_inv_dists: false,
172            discount_functions: Vec::new(),
173            arc_granularity: Self::DEFAULT_GRANULARITY,
174            weights: None,
175            array_0,
176            array_1,
177            _marker: std::marker::PhantomData,
178        }
179    }
180}
181
182impl<
183        'a,
184        G1: RandomAccessGraph + Sync,
185        G2: RandomAccessGraph + Sync,
186        D: for<'b> Succ<Input = usize, Output<'b> = usize>,
187        L: MergeEstimationLogic<Item = G1::Label>,
188        A: EstimatorArrayMut<L>,
189    > HyperBallBuilder<'a, G1, G2, D, L, A>
190{
191    const DEFAULT_GRANULARITY: usize = 16 * 1024;
192
193    /// Creates a new builder with default parameters using also the transpose.
194    ///
195    /// * `graph`: the graph to analyze.
196    /// * `transpose`: optionally, the transpose of `graph`. If [`None`], no
197    ///   systolic iterations will be performed by the resulting [`HyperBall`].
198    /// * `cumul_outdeg`: the outdegree cumulative function of the graph.
199    /// * `array_0`: a first array of estimators.
200    /// * `array_1`: A second array of estimators of the same length and with the same logic of
201    ///   `array_0`.
202    pub fn with_transpose(
203        graph: &'a G1,
204        transpose: &'a G2,
205        cumul_outdeg: &'a D,
206        array_0: A,
207        array_1: A,
208    ) -> Self {
209        assert_eq!(
210            graph.num_nodes(),
211            array_0.len(),
212            "array_0 should have have len {}. Got {}",
213            graph.num_nodes(),
214            array_0.len()
215        );
216        assert_eq!(
217            graph.num_nodes(),
218            array_1.len(),
219            "array_1 should have have len {}. Got {}",
220            graph.num_nodes(),
221            array_1.len()
222        );
223        assert_eq!(
224            transpose.num_nodes(),
225            graph.num_nodes(),
226            "the transpose should have same number of nodes of the graph ({}). Got {}.",
227            graph.num_nodes(),
228            transpose.num_nodes()
229        );
230        assert_eq!(
231            transpose.num_arcs(),
232            graph.num_arcs(),
233            "the transpose should have same number of nodes of the graph ({}). Got {}.",
234            graph.num_arcs(),
235            transpose.num_arcs()
236        );
237        /* TODO debug_assert!(
238            check_transposed(graph, transpose),
239            "the transpose should be the transpose of the graph"
240        );*/
241        Self {
242            graph,
243            transpose: Some(transpose),
244            cumul_outdegree: cumul_outdeg,
245            do_sum_of_dists: false,
246            do_sum_of_inv_dists: false,
247            discount_functions: Vec::new(),
248            arc_granularity: Self::DEFAULT_GRANULARITY,
249            weights: None,
250            array_0,
251            array_1,
252            _marker: std::marker::PhantomData,
253        }
254    }
255
256    /// Sets whether to compute the sum of distances.
257    pub fn sum_of_distances(mut self, do_sum_of_distances: bool) -> Self {
258        self.do_sum_of_dists = do_sum_of_distances;
259        self
260    }
261
262    /// Sets whether to compute the sum of inverse distances.
263    pub fn sum_of_inverse_distances(mut self, do_sum_of_inverse_distances: bool) -> Self {
264        self.do_sum_of_inv_dists = do_sum_of_inverse_distances;
265        self
266    }
267
268    /// Sets the base granularity used in the parallel phases of the iterations.
269    pub fn granularity(mut self, granularity: Granularity) -> Self {
270        self.arc_granularity =
271            granularity.arc_granularity(self.graph.num_nodes(), Some(self.graph.num_arcs()));
272        self
273    }
274
275    /// Sets optional weights for the nodes of the graph.
276    ///
277    /// # Arguments
278    /// * `weights`: weights to use for the nodes. If [`None`], every node is
279    ///   assumed to be of weight equal to 1.
280    pub fn weights(mut self, weights: Option<&'a [usize]>) -> Self {
281        if let Some(w) = weights {
282            assert_eq!(w.len(), self.graph.num_nodes());
283        }
284        self.weights = weights;
285        self
286    }
287
288    /// Adds a new discount function whose sum over all spheres should be
289    /// computed.
290    pub fn discount_function(
291        mut self,
292        discount_function: impl Fn(usize) -> f64 + Send + Sync + 'a,
293    ) -> Self {
294        self.discount_functions.push(Box::new(discount_function));
295        self
296    }
297
298    /// Removes all custom discount functions.
299    pub fn no_discount_function(mut self) -> Self {
300        self.discount_functions.clear();
301        self
302    }
303}
304
305impl<
306        'a,
307        G1: RandomAccessGraph + Sync,
308        G2: RandomAccessGraph + Sync,
309        D: for<'b> Succ<Input = usize, Output<'b> = usize>,
310        L: MergeEstimationLogic<Item = G1::Label> + Sync + std::fmt::Display,
311        A: EstimatorArrayMut<L>,
312    > HyperBallBuilder<'a, G1, G2, D, L, A>
313{
314    /// Builds a [`HyperBall`] instance.
315    ///
316    /// # Arguments
317    ///
318    /// * `pl`: A progress logger.
319    pub fn build(self, pl: &mut impl ConcurrentProgressLog) -> HyperBall<'a, G1, G2, D, L, A> {
320        let num_nodes = self.graph.num_nodes();
321
322        let sum_of_distances = if self.do_sum_of_dists {
323            pl.debug(format_args!("Initializing sum of distances"));
324            Some(vec![0.0; num_nodes])
325        } else {
326            pl.debug(format_args!("Skipping sum of distances"));
327            None
328        };
329        let sum_of_inverse_distances = if self.do_sum_of_inv_dists {
330            pl.debug(format_args!("Initializing sum of inverse distances"));
331            Some(vec![0.0; num_nodes])
332        } else {
333            pl.debug(format_args!("Skipping sum of inverse distances"));
334            None
335        };
336
337        let mut discounted_centralities = Vec::new();
338        pl.debug(format_args!(
339            "Initializing {} discount functions",
340            self.discount_functions.len()
341        ));
342        for _ in self.discount_functions.iter() {
343            discounted_centralities.push(vec![0.0; num_nodes]);
344        }
345
346        pl.info(format_args!("Initializing bit vectors"));
347        let estimator_modified = AtomicBitVec::new(num_nodes);
348        let modified_result_estimator = AtomicBitVec::new(num_nodes);
349        let must_be_checked = AtomicBitVec::new(num_nodes);
350        let next_must_be_checked = AtomicBitVec::new(num_nodes);
351
352        pl.info(format_args!(
353            "Using estimation logic: {}",
354            self.array_0.logic()
355        ));
356
357        HyperBall {
358            graph: self.graph,
359            transposed: self.transpose,
360            weight: self.weights,
361            granularity: self.arc_granularity,
362            curr_state: self.array_0,
363            next_state: self.array_1,
364            completed: false,
365            neighborhood_function: Vec::new(),
366            last: 0.0,
367            relative_increment: 0.0,
368            sum_of_dists: sum_of_distances,
369            sum_of_inv_dists: sum_of_inverse_distances,
370            discounted_centralities,
371            iteration_context: IterationContext {
372                cumul_outdeg: self.cumul_outdegree,
373                iteration: 0,
374                current_nf: Mutex::new(0.0),
375                arc_granularity: 0,
376                node_cursor: AtomicUsize::new(0).into(),
377                arc_cursor: Mutex::new((0, 0)),
378                visited_arcs: AtomicU64::new(0).into(),
379                modified_estimators: AtomicU64::new(0).into(),
380                systolic: false,
381                local: false,
382                pre_local: false,
383                local_checklist: Vec::new(),
384                local_next_must_be_checked: Mutex::new(Vec::new()),
385                must_be_checked,
386                next_must_be_checked,
387                curr_modified: estimator_modified,
388                next_modified: modified_result_estimator,
389                discount_functions: self.discount_functions,
390            },
391            _marker: std::marker::PhantomData,
392        }
393    }
394}
395
396/// Data used by [`parallel_task`](HyperBall::parallel_task).
397///
398/// These variables are used by the threads running
399/// [`parallel_task`](HyperBall::parallel_task). They must be isolated in a
400/// field because we need to be able to borrow exclusively
401/// [`HyperBall::next_state`], while sharing references to the data contained
402/// here and to the [`HyperBall::curr_state`].
403struct IterationContext<'a, G1: SequentialLabeling, D> {
404    /// The cumulative list of outdegrees.
405    cumul_outdeg: &'a D,
406    /// The number of the current iteration.
407    iteration: usize,
408    /// The value of the neighborhood function computed during the current iteration.
409    current_nf: Mutex<f64>,
410    /// The arc granularity: each task will try to process at least this number
411    /// of arcs.
412    arc_granularity: usize,
413    /// A cursor scanning the nodes to process during local computations.
414    node_cursor: CachePadded<AtomicUsize>,
415    /// A cursor scanning the nodes and arcs to process during non-local
416    /// computations.
417    arc_cursor: Mutex<(usize, usize)>,
418    /// The number of arcs visited during the current iteration.
419    visited_arcs: CachePadded<AtomicU64>,
420    /// The number of estimators modified during the current iteration.
421    modified_estimators: CachePadded<AtomicU64>,
422    /// `true` if we started a systolic computation.
423    systolic: bool,
424    /// `true` if we started a local computation.
425    local: bool,
426    /// `true` if we are preparing a local computation (systolic is `true` and less than 1% nodes were modified).
427    pre_local: bool,
428    /// If [`local`](Self::local) is `true`, the sorted list of nodes that
429    /// should be scanned.
430    local_checklist: Vec<G1::Label>,
431    /// If [`pre_local`](Self::pre_local) is `true`, the set of nodes that
432    /// should be scanned on the next iteration.
433    local_next_must_be_checked: Mutex<Vec<G1::Label>>,
434    /// Used in systolic iterations to keep track of nodes to check.
435    must_be_checked: AtomicBitVec,
436    /// Used in systolic iterations to keep track of nodes to check in the next
437    /// iteration.
438    next_must_be_checked: AtomicBitVec,
439    /// Whether each estimator has been modified during the previous iteration.
440    curr_modified: AtomicBitVec,
441    /// Whether each estimator has been modified during the current iteration.
442    next_modified: AtomicBitVec,
443    /// Custom discount functions whose sum should be computed.
444    discount_functions: Vec<Box<dyn Fn(usize) -> f64 + Send + Sync + 'a>>,
445}
446
447impl<G1: SequentialLabeling, D> IterationContext<'_, G1, D> {
448    /// Resets the iteration context
449    fn reset(&mut self, granularity: usize) {
450        self.arc_granularity = granularity;
451        self.node_cursor.store(0, Ordering::Relaxed);
452        *self.arc_cursor.lock().unwrap() = (0, 0);
453        self.visited_arcs.store(0, Ordering::Relaxed);
454        self.modified_estimators.store(0, Ordering::Relaxed);
455    }
456}
457
458/// An algorithm that computes an approximation of the neighborhood function,
459/// of the size of the reachable sets, and of (discounted) positive geometric
460/// centralities of a graph.
461pub struct HyperBall<
462    'a,
463    G1: RandomAccessGraph + Sync,
464    G2: RandomAccessGraph + Sync,
465    D: for<'b> Succ<Input = usize, Output<'b> = usize>,
466    L: MergeEstimationLogic<Item = G1::Label> + Sync,
467    A: EstimatorArrayMut<L>,
468> {
469    /// The graph to analyze.
470    graph: &'a G1,
471    /// The transpose of [`Self::graph`], if any.
472    transposed: Option<&'a G2>,
473    /// An optional slice of nonnegative node weights.
474    weight: Option<&'a [usize]>,
475    /// The base number of nodes per task. TODO.
476    granularity: usize,
477    /// The previous state.
478    curr_state: A,
479    /// The next state.
480    next_state: A,
481    /// `true` if the computation is over.
482    completed: bool,
483    /// The neighborhood function.
484    neighborhood_function: Vec<f64>,
485    /// The value computed by the last iteration.
486    last: f64,
487    /// The relative increment of the neighborhood function for the last
488    /// iteration.
489    relative_increment: f64,
490    /// The sum of the distances from every given node, if requested.
491    sum_of_dists: Option<Vec<f32>>,
492    /// The sum of inverse distances from each given node, if requested.
493    sum_of_inv_dists: Option<Vec<f32>>,
494    /// The overall discount centrality for every discount function.
495    discounted_centralities: Vec<Vec<f32>>,
496    /// Context used in a single iteration.
497    iteration_context: IterationContext<'a, G1, D>,
498    _marker: std::marker::PhantomData<L>,
499}
500
501impl<
502        G1: RandomAccessGraph + Sync,
503        G2: RandomAccessGraph + Sync,
504        D: for<'b> Succ<Input = usize, Output<'b> = usize> + Sync,
505        L: MergeEstimationLogic<Item = usize> + Sync,
506        A: EstimatorArrayMut<L> + Sync + AsSyncArray<L>,
507    > HyperBall<'_, G1, G2, D, L, A>
508where
509    L::Backend: PartialEq,
510{
511    /// Runs HyperBall.
512    ///
513    /// # Arguments
514    ///
515    /// * `upper_bound`: an upper bound to the number of iterations.
516    ///
517    /// * `threshold`: a value that will be used to stop the computation by
518    ///   relative increment if the neighborhood function is being computed. If
519    ///   [`None`] the computation will stop when no estimators are modified.
520    ///
521    /// * `pl`: A progress logger.
522    pub fn run(
523        &mut self,
524        upper_bound: usize,
525        threshold: Option<f64>,
526        rng: impl rand::Rng,
527        pl: &mut impl ConcurrentProgressLog,
528    ) -> Result<()> {
529        let upper_bound = std::cmp::min(upper_bound, self.graph.num_nodes());
530
531        self.init(rng, pl)
532            .with_context(|| "Could not initialize estimator")?;
533
534        pl.item_name("iteration");
535        pl.expected_updates(None);
536        pl.start(format!(
537            "Running HyperBall for a maximum of {} iterations and a threshold of {:?}",
538            upper_bound, threshold
539        ));
540
541        for i in 0..upper_bound {
542            self.iterate(pl)
543                .with_context(|| format!("Could not perform iteration {}", i + 1))?;
544
545            pl.update_and_display();
546
547            if self
548                .iteration_context
549                .modified_estimators
550                .load(Ordering::Relaxed)
551                == 0
552            {
553                pl.info(format_args!(
554                    "Terminating HyperBall after {} iteration(s) by stabilization",
555                    i + 1
556                ));
557                break;
558            }
559
560            if let Some(t) = threshold {
561                if i > 3 && self.relative_increment < (1.0 + t) {
562                    pl.info(format_args!("Terminating HyperBall after {} iteration(s) by relative bound on the neighborhood function", i + 1));
563                    break;
564                }
565            }
566        }
567
568        pl.done();
569
570        Ok(())
571    }
572
573    /// Runs HyperBall until no estimators are modified.
574    ///
575    /// # Arguments
576    ///
577    /// * `upper_bound`: an upper bound to the number of iterations.
578    ///
579    /// * `pl`: A progress logger.
580    #[inline(always)]
581    pub fn run_until_stable(
582        &mut self,
583        upper_bound: usize,
584        rng: impl rand::Rng,
585        pl: &mut impl ConcurrentProgressLog,
586    ) -> Result<()> {
587        self.run(upper_bound, None, rng, pl)
588            .with_context(|| "Could not complete run_until_stable")
589    }
590
591    /// Runs HyperBall until no estimators are modified with no upper bound on the
592    /// number of iterations.
593    ///
594    /// # Arguments
595    ///
596    /// * `pl`: A progress logger.
597    #[inline(always)]
598    pub fn run_until_done(
599        &mut self,
600        rng: impl rand::Rng,
601        pl: &mut impl ConcurrentProgressLog,
602    ) -> Result<()> {
603        self.run_until_stable(usize::MAX, rng, pl)
604            .with_context(|| "Could not complete run_until_done")
605    }
606
607    #[inline(always)]
608    fn ensure_iteration(&self) -> Result<()> {
609        ensure!(
610            self.iteration_context.iteration > 0,
611            "HyperBall was not run. Please call HyperBall::run before accessing computed fields."
612        );
613        Ok(())
614    }
615
616    /// Returns the neighborhood function computed by this instance.
617    pub fn neighborhood_function(&self) -> Result<&[f64]> {
618        self.ensure_iteration()?;
619        Ok(&self.neighborhood_function)
620    }
621
622    /// Returns the sum of distances computed by this instance if requested.
623    pub fn sum_of_distances(&self) -> Result<&[f32]> {
624        self.ensure_iteration()?;
625        if let Some(distances) = &self.sum_of_dists {
626            // TODO these are COPIES
627            Ok(distances)
628        } else {
629            bail!("Sum of distances were not requested: use builder.with_sum_of_distances(true) while building HyperBall to compute them")
630        }
631    }
632
633    /// Returns the harmonic centralities (sum of inverse distances) computed by this instance if requested.
634    pub fn harmonic_centralities(&self) -> Result<&[f32]> {
635        self.ensure_iteration()?;
636        if let Some(distances) = &self.sum_of_inv_dists {
637            Ok(distances)
638        } else {
639            bail!("Sum of inverse distances were not requested: use builder.with_sum_of_inverse_distances(true) while building HyperBall to compute them")
640        }
641    }
642
643    /// Returns the discounted centralities of the specified index computed by this instance.
644    ///
645    /// # Arguments
646    /// * `index`: the index of the requested discounted centrality.
647    pub fn discounted_centrality(&self, index: usize) -> Result<&[f32]> {
648        self.ensure_iteration()?;
649        let d = self.discounted_centralities.get(index);
650        if let Some(distances) = d {
651            Ok(distances)
652        } else {
653            bail!("Discount centrality of index {} does not exist", index)
654        }
655    }
656
657    /// Computes and returns the closeness centralities from the sum of distances computed by this instance.
658    pub fn closeness_centrality(&self) -> Result<Box<[f32]>> {
659        self.ensure_iteration()?;
660        if let Some(distances) = &self.sum_of_dists {
661            Ok(distances
662                .iter()
663                .map(|&d| if d == 0.0 { 0.0 } else { d.recip() })
664                .collect())
665        } else {
666            bail!("Sum of distances were not requested: use builder.with_sum_of_distances(true) while building HyperBall to compute closeness centrality")
667        }
668    }
669
670    /// Computes and returns the lin centralities from the sum of distances computed by this instance.
671    ///
672    /// Note that lin's index for isolated nodes is by (our) definition one (it's smaller than any other node).
673    pub fn lin_centrality(&self) -> Result<Box<[f32]>> {
674        self.ensure_iteration()?;
675        if let Some(distances) = &self.sum_of_dists {
676            let logic = self.curr_state.logic();
677            Ok(distances
678                .iter()
679                .enumerate()
680                .map(|(node, &d)| {
681                    if d == 0.0 {
682                        1.0
683                    } else {
684                        let count = logic.estimate(self.curr_state.get_backend(node));
685                        (count * count / d as f64) as f32
686                    }
687                })
688                .collect())
689        } else {
690            bail!("Sum of distances were not requested: use builder.with_sum_of_distances(true) while building HyperBall to compute lin centrality")
691        }
692    }
693
694    /// Computes and returns the Nieminen centralities from the sum of distances computed by this instance.
695    pub fn nieminen_centrality(&self) -> Result<Box<[f32]>> {
696        self.ensure_iteration()?;
697        if let Some(distances) = &self.sum_of_dists {
698            let logic = self.curr_state.logic();
699            Ok(distances
700                .iter()
701                .enumerate()
702                .map(|(node, &d)| {
703                    let count = logic.estimate(self.curr_state.get_backend(node));
704                    ((count * count) - d as f64) as f32
705                })
706                .collect())
707        } else {
708            bail!("Sum of distances were not requested: use builder.with_sum_of_distances(true) while building HyperBall to compute lin centrality")
709        }
710    }
711
712    /// Reads from the internal estimator array and estimates the number of nodes
713    /// reachable from the specified node.
714    ///
715    /// # Arguments
716    /// * `node`: the index of the node to compute reachable nodes from.
717    pub fn reachable_nodes_from(&self, node: usize) -> Result<f64> {
718        self.ensure_iteration()?;
719        Ok(self
720            .curr_state
721            .logic()
722            .estimate(self.curr_state.get_backend(node)))
723    }
724
725    /// Reads from the internal estimator array and estimates the number of nodes reachable
726    /// from every node of the graph.
727    ///
728    /// `hyperball.reachable_nodes().unwrap()[i]` is equal to `hyperball.reachable_nodes_from(i).unwrap()`.
729    pub fn reachable_nodes(&self) -> Result<Box<[f32]>> {
730        self.ensure_iteration()?;
731        let logic = self.curr_state.logic();
732        Ok((0..self.graph.num_nodes())
733            .map(|n| logic.estimate(self.curr_state.get_backend(n)) as f32)
734            .collect())
735    }
736}
737
738impl<
739        G1: RandomAccessGraph + Sync,
740        G2: RandomAccessGraph + Sync,
741        D: for<'b> Succ<Input = usize, Output<'b> = usize> + Sync,
742        L: EstimationLogic<Item = usize> + MergeEstimationLogic + Sync,
743        A: EstimatorArrayMut<L> + Sync + AsSyncArray<L>,
744    > HyperBall<'_, G1, G2, D, L, A>
745where
746    L::Backend: PartialEq,
747{
748    /// Performs a new iteration of HyperBall.
749    ///
750    /// # Arguments
751    /// * `pl`: A progress logger.
752    fn iterate(&mut self, pl: &mut impl ConcurrentProgressLog) -> Result<()> {
753        let ic = &mut self.iteration_context;
754
755        pl.info(format_args!("Performing iteration {}", ic.iteration + 1));
756
757        // Alias the number of modified estimators, nodes and arcs
758        let num_nodes = self.graph.num_nodes() as u64;
759        let num_arcs = self.graph.num_arcs();
760        let modified_estimators = ic.modified_estimators.load(Ordering::Relaxed);
761
762        // Let us record whether the previous computation was systolic or local
763        let prev_was_systolic = ic.systolic;
764        let prev_was_local = ic.local;
765
766        // If less than one fourth of the nodes have been modified, and we have
767        // the transpose, it is time to pass to a systolic computation
768        ic.systolic =
769            self.transposed.is_some() && ic.iteration > 0 && modified_estimators < num_nodes / 4;
770
771        // Non-systolic computations add up the values of all estimators.
772        //
773        // Systolic computations modify the last value by compensating for each
774        // modified estimators.
775        *ic.current_nf.lock().unwrap() = if ic.systolic { self.last } else { 0.0 };
776
777        // If we completed the last iteration in pre-local mode, we MUST run in
778        // local mode
779        ic.local = ic.pre_local;
780
781        // We run in pre-local mode if we are systolic and few nodes where
782        // modified.
783        ic.pre_local =
784            ic.systolic && modified_estimators < (num_nodes * num_nodes) / (num_arcs * 10);
785
786        if ic.systolic {
787            pl.info(format_args!(
788                "Starting systolic iteration (local: {}, pre_local: {})",
789                ic.local, ic.pre_local
790            ));
791        } else {
792            pl.info(format_args!("Starting standard iteration"));
793        }
794
795        if prev_was_local {
796            for &node in ic.local_checklist.iter() {
797                ic.next_modified.set(node, false, Ordering::Relaxed);
798            }
799        } else {
800            ic.next_modified.fill(false, Ordering::Relaxed);
801        }
802
803        if ic.local {
804            // In case of a local computation, we convert the set of
805            // must-be-checked for the next iteration into a check list
806            rayon::join(
807                || ic.local_checklist.clear(),
808                || {
809                    let mut local_next_must_be_checked =
810                        ic.local_next_must_be_checked.lock().unwrap();
811                    local_next_must_be_checked.par_sort_unstable();
812                    local_next_must_be_checked.dedup();
813                },
814            );
815            std::mem::swap(
816                &mut ic.local_checklist,
817                &mut ic.local_next_must_be_checked.lock().unwrap(),
818            );
819        } else if ic.systolic {
820            rayon::join(
821                || {
822                    // Systolic, non-local computations store the could-be-modified set implicitly into Self::next_must_be_checked.
823                    ic.next_must_be_checked.fill(false, Ordering::Relaxed);
824                },
825                || {
826                    // If the previous computation wasn't systolic, we must assume that all registers could have changed.
827                    if !prev_was_systolic {
828                        ic.must_be_checked.fill(true, Ordering::Relaxed);
829                    }
830                },
831            );
832        }
833
834        let mut granularity = ic.arc_granularity;
835        let num_threads = rayon::current_num_threads();
836
837        if num_threads > 1 && !ic.local {
838            if ic.iteration > 0 {
839                granularity = f64::min(
840                    std::cmp::max(1, num_nodes as usize / num_threads) as _,
841                    granularity as f64
842                        * (num_nodes as f64 / std::cmp::max(1, modified_estimators) as f64),
843                ) as usize;
844            }
845            pl.info(format_args!(
846                "Adaptive granularity for this iteration: {}",
847                granularity
848            ));
849        }
850
851        ic.reset(granularity);
852
853        pl.item_name("arc");
854        pl.expected_updates(if ic.local { None } else { Some(num_arcs as _) });
855        pl.start("Starting parallel execution");
856        {
857            let next_state_sync = self.next_state.as_sync_array();
858            let sum_of_dists = self.sum_of_dists.as_mut().map(|x| x.as_sync_slice());
859            let sum_of_inv_dists = self.sum_of_inv_dists.as_mut().map(|x| x.as_sync_slice());
860
861            let discounted_centralities = &self
862                .discounted_centralities
863                .iter_mut()
864                .map(|s| s.as_sync_slice())
865                .collect::<Vec<_>>();
866            rayon::broadcast(|c| {
867                Self::parallel_task(
868                    self.graph,
869                    self.transposed,
870                    &self.curr_state,
871                    &next_state_sync,
872                    ic,
873                    sum_of_dists,
874                    sum_of_inv_dists,
875                    discounted_centralities,
876                    c,
877                )
878            });
879        }
880
881        pl.done_with_count(ic.visited_arcs.load(Ordering::Relaxed) as usize);
882        let modified_estimators = ic.modified_estimators.load(Ordering::Relaxed);
883
884        pl.info(format_args!(
885            "Modified estimators: {}/{} ({:.3}%)",
886            modified_estimators,
887            self.graph.num_nodes(),
888            (modified_estimators as f64 / self.graph.num_nodes() as f64) * 100.0
889        ));
890
891        std::mem::swap(&mut self.curr_state, &mut self.next_state);
892        std::mem::swap(&mut ic.curr_modified, &mut ic.next_modified);
893
894        if ic.systolic {
895            std::mem::swap(&mut ic.must_be_checked, &mut ic.next_must_be_checked);
896        }
897
898        let mut current_nf_mut = ic.current_nf.lock().unwrap();
899        self.last = *current_nf_mut;
900        // We enforce monotonicity--non-monotonicity can only be caused by
901        // approximation errors
902        let &last_output = self
903            .neighborhood_function
904            .as_slice()
905            .last()
906            .expect("Should always have at least 1 element");
907        if *current_nf_mut < last_output {
908            *current_nf_mut = last_output;
909        }
910        self.relative_increment = *current_nf_mut / last_output;
911
912        pl.info(format_args!(
913            "Pairs: {} ({}%)",
914            *current_nf_mut,
915            (*current_nf_mut * 100.0) / (num_nodes * num_nodes) as f64
916        ));
917        pl.info(format_args!(
918            "Absolute increment: {}",
919            *current_nf_mut - last_output
920        ));
921        pl.info(format_args!(
922            "Relative increment: {}",
923            self.relative_increment
924        ));
925
926        self.neighborhood_function.push(*current_nf_mut);
927
928        ic.iteration += 1;
929
930        Ok(())
931    }
932
933    /// The parallel operations to be performed each iteration.
934    ///
935    /// # Arguments:
936    /// * `graph`: the graph to analyze.
937    /// * `transpose`: optionally, the transpose of `graph`. If [`None`], no
938    ///   systolic iterations will be performed.
939    /// * `curr_state`: the current state of the estimators.
940    /// * `next_state`: the next state of the estimators (to be computed).
941    /// * `ic`: the iteration context.
942    #[allow(clippy::too_many_arguments)]
943    fn parallel_task(
944        graph: &(impl RandomAccessGraph + Sync),
945        transpose: Option<&(impl RandomAccessGraph + Sync)>,
946        curr_state: &impl EstimatorArray<L>,
947        next_state: &impl SyncEstimatorArray<L>,
948        ic: &IterationContext<'_, G1, D>,
949        sum_of_dists: Option<&[SyncCell<f32>]>,
950        sum_of_inv_dists: Option<&[SyncCell<f32>]>,
951        discounted_centralities: &[&[SyncCell<f32>]],
952        _broadcast_context: rayon::BroadcastContext,
953    ) {
954        let node_granularity = ic.arc_granularity;
955        let arc_granularity = ((graph.num_arcs() as f64 * node_granularity as f64)
956            / graph.num_nodes() as f64)
957            .ceil() as usize;
958        let do_centrality = sum_of_dists.is_some()
959            || sum_of_inv_dists.is_some()
960            || !ic.discount_functions.is_empty();
961        let node_upper_limit = if ic.local {
962            ic.local_checklist.len()
963        } else {
964            graph.num_nodes()
965        };
966        let mut visited_arcs = 0;
967        let mut modified_estimators = 0;
968        let arc_upper_limit = graph.num_arcs();
969
970        // During standard iterations, cumulates the neighborhood function for the nodes scanned
971        // by this thread. During systolic iterations, cumulates the *increase* of the
972        // neighborhood function for the nodes scanned by this thread.
973        let mut neighborhood_function_delta = KahanSum::new_with_value(0.0);
974        let mut helper = curr_state.logic().new_helper();
975        let logic = curr_state.logic();
976        let mut next_estimator = logic.new_estimator();
977
978        loop {
979            // Get work
980            let (start, end) = if ic.local {
981                let start = std::cmp::min(
982                    ic.node_cursor.fetch_add(1, Ordering::Relaxed),
983                    node_upper_limit,
984                );
985                let end = std::cmp::min(start + 1, node_upper_limit);
986                (start, end)
987            } else {
988                let mut arc_balanced_cursor = ic.arc_cursor.lock().unwrap();
989                let (mut next_node, mut next_arc) = *arc_balanced_cursor;
990                if next_node >= node_upper_limit {
991                    (node_upper_limit, node_upper_limit)
992                } else {
993                    let start = next_node;
994                    let target = next_arc + arc_granularity;
995                    if target as u64 >= arc_upper_limit {
996                        next_node = node_upper_limit;
997                    } else {
998                        (next_node, next_arc) = ic.cumul_outdeg.succ(target).unwrap();
999                    }
1000                    let end = next_node;
1001                    *arc_balanced_cursor = (next_node, next_arc);
1002                    (start, end)
1003                }
1004            };
1005
1006            if start == node_upper_limit {
1007                break;
1008            }
1009
1010            // Do work
1011            for i in start..end {
1012                let node = if ic.local { ic.local_checklist[i] } else { i };
1013
1014                let prev_estimator = curr_state.get_backend(node);
1015
1016                // The three cases in which we enumerate successors:
1017                // 1) A non-systolic computation (we don't know anything, so we enumerate).
1018                // 2) A systolic, local computation (the node is by definition to be checked, as it comes from the local check list).
1019                // 3) A systolic, non-local computation in which the node should be checked.
1020                if !ic.systolic || ic.local || ic.must_be_checked[node] {
1021                    next_estimator.set(prev_estimator);
1022                    let mut modified = false;
1023                    for succ in graph.successors(node) {
1024                        if succ != node && ic.curr_modified[succ] {
1025                            visited_arcs += 1;
1026                            if !modified {
1027                                modified = true;
1028                            }
1029                            logic.merge_with_helper(
1030                                next_estimator.as_mut(),
1031                                curr_state.get_backend(succ),
1032                                &mut helper,
1033                            );
1034                        }
1035                    }
1036
1037                    let mut post = f64::NAN;
1038                    let estimator_modified = modified && next_estimator.as_ref() != prev_estimator;
1039
1040                    // We need the estimator value only if the iteration is standard (as we're going to
1041                    // compute the neighborhood function cumulating actual values, and not deltas) or
1042                    // if the estimator was actually modified (as we're going to cumulate the neighborhood
1043                    // function delta, or at least some centrality).
1044                    if !ic.systolic || estimator_modified {
1045                        post = logic.estimate(next_estimator.as_ref())
1046                    }
1047                    if !ic.systolic {
1048                        neighborhood_function_delta += post;
1049                    }
1050
1051                    if estimator_modified && (ic.systolic || do_centrality) {
1052                        let pre = logic.estimate(prev_estimator);
1053                        if ic.systolic {
1054                            neighborhood_function_delta += -pre;
1055                            neighborhood_function_delta += post;
1056                        }
1057
1058                        if do_centrality {
1059                            let delta = post - pre;
1060                            // Note that this code is executed only for distances > 0
1061                            if delta > 0.0 {
1062                                if let Some(distances) = sum_of_dists {
1063                                    let new_value = delta * (ic.iteration + 1) as f64;
1064                                    unsafe {
1065                                        distances[node]
1066                                            .set((distances[node].get() as f64 + new_value) as f32)
1067                                    };
1068                                }
1069                                if let Some(distances) = sum_of_inv_dists {
1070                                    let new_value = delta / (ic.iteration + 1) as f64;
1071                                    unsafe {
1072                                        distances[node]
1073                                            .set((distances[node].get() as f64 + new_value) as f32)
1074                                    };
1075                                }
1076                                for (func, distances) in ic
1077                                    .discount_functions
1078                                    .iter()
1079                                    .zip(discounted_centralities.iter())
1080                                {
1081                                    let new_value = delta * func(ic.iteration + 1);
1082                                    unsafe {
1083                                        distances[node]
1084                                            .set((distances[node].get() as f64 + new_value) as f32)
1085                                    };
1086                                }
1087                            }
1088                        }
1089                    }
1090
1091                    if estimator_modified {
1092                        // We keep track of modified estimators in the result. Note that we must
1093                        // add the current node to the must-be-checked set for the next
1094                        // local iteration if it is modified, as it might need a copy to
1095                        // the result array at the next iteration.
1096                        if ic.pre_local {
1097                            ic.local_next_must_be_checked.lock().unwrap().push(node);
1098                        }
1099                        ic.next_modified.set(node, true, Ordering::Relaxed);
1100
1101                        if ic.systolic {
1102                            debug_assert!(transpose.is_some());
1103                            // In systolic computations we must keep track of
1104                            // which estimators must be checked on the next
1105                            // iteration. If we are preparing a local
1106                            // computation, we do this explicitly, by adding the
1107                            // predecessors of the current node to a set.
1108                            // Otherwise, we do this implicitly, by setting the
1109                            // corresponding entry in an array.
1110
1111                            // SAFETY: ic.systolic is true, so transpose is Some
1112                            let transpose = unsafe { transpose.unwrap_unchecked() };
1113                            if ic.pre_local {
1114                                let mut local_next_must_be_checked =
1115                                    ic.local_next_must_be_checked.lock().unwrap();
1116                                for succ in transpose.successors(node) {
1117                                    local_next_must_be_checked.push(succ);
1118                                }
1119                            } else {
1120                                for succ in transpose.successors(node) {
1121                                    ic.next_must_be_checked.set(succ, true, Ordering::Relaxed);
1122                                }
1123                            }
1124                        }
1125
1126                        modified_estimators += 1;
1127                    }
1128
1129                    unsafe {
1130                        next_state.set(node, next_estimator.as_ref());
1131                    }
1132                } else {
1133                    // Even if we cannot possibly have changed our value, still our copy
1134                    // in the result vector might need to be updated because it does not
1135                    // reflect our current value.
1136                    if ic.curr_modified[node] {
1137                        unsafe {
1138                            next_state.set(node, prev_estimator);
1139                        }
1140                    }
1141                }
1142            }
1143        }
1144
1145        *ic.current_nf.lock().unwrap() += neighborhood_function_delta.sum();
1146        ic.visited_arcs.fetch_add(visited_arcs, Ordering::Relaxed);
1147        ic.modified_estimators
1148            .fetch_add(modified_estimators, Ordering::Relaxed);
1149    }
1150
1151    /// Initializes HyperBall.
1152    fn init(&mut self, mut rng: impl rand::Rng, pl: &mut impl ConcurrentProgressLog) -> Result<()> {
1153        pl.start("Initializing estimators");
1154        pl.info(format_args!("Clearing all registers"));
1155
1156        self.curr_state.clear();
1157        self.next_state.clear();
1158
1159        pl.info(format_args!("Initializing registers"));
1160        if let Some(w) = &self.weight {
1161            pl.info(format_args!("Loading weights"));
1162            for (i, &node_weight) in w.iter().enumerate() {
1163                let mut estimator = self.curr_state.get_estimator_mut(i);
1164                for _ in 0..node_weight {
1165                    estimator.add(&(rng.random::<u64>() as usize));
1166                }
1167            }
1168        } else {
1169            (0..self.graph.num_nodes()).for_each(|i| {
1170                self.curr_state.get_estimator_mut(i).add(i);
1171            });
1172        }
1173
1174        self.completed = false;
1175
1176        let ic = &mut self.iteration_context;
1177        ic.iteration = 0;
1178        ic.systolic = false;
1179        ic.local = false;
1180        ic.pre_local = false;
1181        ic.reset(self.granularity);
1182
1183        pl.debug(format_args!("Initializing distances"));
1184        if let Some(distances) = &mut self.sum_of_dists {
1185            distances.fill(0.0);
1186        }
1187        if let Some(distances) = &mut self.sum_of_inv_dists {
1188            distances.fill(0.0);
1189        }
1190        pl.debug(format_args!("Initializing centralities"));
1191        for centralities in self.discounted_centralities.iter_mut() {
1192            centralities.fill(0.0);
1193        }
1194
1195        self.last = self.graph.num_nodes() as f64;
1196        pl.debug(format_args!("Initializing neighborhood function"));
1197        self.neighborhood_function.clear();
1198        self.neighborhood_function.push(self.last);
1199
1200        pl.debug(format_args!("Initializing modified estimators"));
1201        ic.curr_modified.fill(true, Ordering::Relaxed);
1202
1203        pl.done();
1204
1205        Ok(())
1206    }
1207}
1208
#[cfg(test)]
mod test {
    use std::hash::{BuildHasherDefault, DefaultHasher};

    use super::*;
    use card_est_array::traits::{EstimatorArray, MergeEstimator};
    use dsi_progress_logger::no_logging;
    use epserde::deser::{Deserialize, Flags};
    use rand::SeedableRng;
    use webgraph::{
        prelude::{BvGraph, DCF},
        traits::SequentialLabeling,
    };

    /// The estimator array used by the tests: a slice-backed array of
    /// HyperLogLog estimators over the labels of `G`.
    type HyperBallArray<G> = SliceEstimatorArray<
        HyperLogLog<<G as SequentialLabeling>::Label, BuildHasherDefault<DefaultHasher>, usize>,
        usize,
        Box<[usize]>,
    >;

    /// A straightforward sequential implementation of the HyperBall
    /// iteration, used as a reference for the parallel one.
    struct SeqHyperBall<'a, G: RandomAccessGraph> {
        graph: &'a G,
        curr_state: HyperBallArray<G>,
        next_state: HyperBallArray<G>,
    }

    impl<G: RandomAccessGraph> SeqHyperBall<'_, G> {
        /// Adds each node to its own estimator.
        fn init(&mut self) {
            let num_nodes = self.graph.num_nodes();
            for node in 0..num_nodes {
                self.curr_state.get_estimator_mut(node).add(node);
            }
        }

        /// Performs one iteration: the next estimator of each node is its
        /// current estimator merged with the current estimators of all its
        /// successors. Finally, the two states are swapped.
        fn iterate(&mut self) {
            let num_nodes = self.graph.num_nodes();
            for node in 0..num_nodes {
                let mut merged = self.next_state.get_estimator_mut(node);
                merged.set(self.curr_state.get_backend(node));
                for successor in self.graph.successors(node) {
                    merged.merge(self.curr_state.get_backend(successor));
                }
            }
            std::mem::swap(&mut self.curr_state, &mut self.next_state);
        }
    }

    /// Runs the parallel and the sequential implementations side by side
    /// on cnr-2000 until no estimator is modified, checking after every
    /// iteration that the two agree on both states.
    #[cfg_attr(feature = "slow_tests", test)]
    #[cfg_attr(not(feature = "slow_tests"), allow(dead_code))]
    fn test_cnr_2000() -> Result<()> {
        let basename = "../data/cnr-2000";

        let graph = BvGraph::with_basename(basename).load()?;
        let transpose = BvGraph::with_basename(format!("{basename}-t")).load()?;
        let cumulative = unsafe { DCF::load_mmap(format!("{basename}.dcf"), Flags::empty()) }?;

        let num_nodes = graph.num_nodes();

        let logic = HyperLogLogBuilder::new(num_nodes)
            .log_2_num_reg(6)
            .build()?;

        // Two arrays (current and next state) for each implementation.
        let seq_curr = SliceEstimatorArray::new(logic.clone(), num_nodes);
        let seq_next = SliceEstimatorArray::new(logic.clone(), num_nodes);
        let par_curr = SliceEstimatorArray::new(logic.clone(), num_nodes);
        let par_next = SliceEstimatorArray::new(logic.clone(), num_nodes);

        let mut parallel = HyperBallBuilder::with_transpose(
            &graph,
            &transpose,
            cumulative.uncase(),
            par_curr,
            par_next,
        )
        .build(no_logging![]);
        let mut sequential = SeqHyperBall {
            graph: &graph,
            curr_state: seq_curr,
            next_state: seq_next,
        };

        let mut rng = rand::rngs::SmallRng::seed_from_u64(42);
        parallel.init(&mut rng, no_logging![])?;
        sequential.init();

        // We stop as soon as an iteration modifies no estimator.
        let mut modified = num_nodes as u64;
        while modified != 0 {
            parallel.iterate(no_logging![])?;
            sequential.iterate();

            modified = parallel
                .iteration_context
                .modified_estimators
                .load(Ordering::Relaxed);

            assert_eq!(
                parallel.next_state.as_ref(),
                sequential.next_state.as_ref()
            );
            assert_eq!(
                parallel.curr_state.as_ref(),
                sequential.curr_state.as_ref()
            );
        }

        Ok(())
    }
}