webgraph_algo/distances/
hyperball.rs

1/*
2 * SPDX-FileCopyrightText: 2024 Matteo Dell'Acqua
3 * SPDX-FileCopyrightText: 2025 Sebastiano Vigna
4 *
5 * SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
6 */
7
8use anyhow::{Context, Result, bail, ensure};
9use card_est_array::impls::{HyperLogLog, HyperLogLogBuilder, SliceEstimatorArray};
10use card_est_array::traits::{
11    AsSyncArray, EstimationLogic, EstimatorArray, EstimatorArrayMut, EstimatorMut,
12    MergeEstimationLogic, SyncEstimatorArray,
13};
14use crossbeam_utils::CachePadded;
15use dsi_progress_logger::ConcurrentProgressLog;
16use kahan::KahanSum;
17use rayon::prelude::*;
18use std::hash::{BuildHasherDefault, DefaultHasher};
19use std::sync::{Mutex, atomic::*};
20use sux::traits::AtomicBitVecOps;
21use sux::{bits::AtomicBitVec, traits::Succ};
22use sync_cell_slice::{SyncCell, SyncSlice};
23use webgraph::traits::{RandomAccessGraph, SequentialLabeling};
24use webgraph::utils::Granularity;
25
26/// A builder for [`HyperBall`].
27///
28/// After creating a builder with [`HyperBallBuilder::new`] you can configure it
29/// using setters such as [`HyperBallBuilder`] its methods, then call
30/// [`HyperBallBuilder::build`] on it to create a [`HyperBall`] instance.
31pub struct HyperBallBuilder<
32    'a,
33    G1: RandomAccessGraph + Sync,
34    G2: RandomAccessGraph + Sync,
35    D: for<'b> Succ<Input = usize, Output<'b> = usize>,
36    L: MergeEstimationLogic<Item = G1::Label>,
37    A: EstimatorArrayMut<L>,
38> {
39    /// A graph.
40    graph: &'a G1,
41    /// The transpose of `graph`, if any.
42    transpose: Option<&'a G2>,
43    /// The outdegree cumulative function of the graph.
44    cumul_outdegree: &'a D,
45    /// Whether to compute the sum of distances (e.g., for closeness centrality).
46    do_sum_of_dists: bool,
47    /// Whether to compute the sum of inverse distances (e.g., for harmonic centrality).
48    do_sum_of_inv_dists: bool,
49    /// Custom discount functions whose sum should be computed.
50    discount_functions: Vec<Box<dyn Fn(usize) -> f64 + Send + Sync + 'a>>,
51    /// The arc granularity.
52    arc_granularity: usize,
53    /// Integer weights for the nodes, if any.
54    weights: Option<&'a [usize]>,
55    /// A first array of estimators.
56    array_0: A,
57    /// A second array of estimators of the same length and with the same logic of
58    /// `array_0`.
59    array_1: A,
60    _marker: std::marker::PhantomData<L>,
61}
62
63impl<
64    'a,
65    G1: RandomAccessGraph + Sync,
66    G2: RandomAccessGraph + Sync,
67    D: for<'b> Succ<Input = usize, Output<'b> = usize>,
68>
69    HyperBallBuilder<
70        'a,
71        G1,
72        G2,
73        D,
74        HyperLogLog<G1::Label, BuildHasherDefault<DefaultHasher>, usize>,
75        SliceEstimatorArray<
76            HyperLogLog<G1::Label, BuildHasherDefault<DefaultHasher>, usize>,
77            usize,
78            Box<[usize]>,
79        >,
80    >
81{
82    /// A builder for [`HyperBall`] using a specified [`EstimationLogic`].
83    ///
84    /// # Arguments
85    /// * `graph`: the graph to analyze.
86    /// * `transpose`: optionally, the transpose of `graph`. If [`None`], no
87    ///   systolic iterations will be performed by the resulting [`HyperBall`].
88    /// * `cumul_outdeg`: the outdegree cumulative function of the graph.
89    /// * `log2m`: the base-2 logarithm of the number *m* of register per
90    ///   HyperLogLog cardinality estimator.
91    /// * `weights`: the weights to use. If [`None`] every node is assumed to be
92    ///   of weight equal to 1.
93    pub fn with_hyper_log_log(
94        graph: &'a G1,
95        transposed: Option<&'a G2>,
96        cumul_outdeg: &'a D,
97        log2m: usize,
98        weights: Option<&'a [usize]>,
99    ) -> Result<Self> {
100        let num_elements = if let Some(w) = weights {
101            ensure!(
102                w.len() == graph.num_nodes(),
103                "weights should have length equal to the graph's number of nodes"
104            );
105            w.iter().sum()
106        } else {
107            graph.num_nodes()
108        };
109
110        let logic = HyperLogLogBuilder::new(num_elements)
111            .log_2_num_reg(log2m)
112            .build()
113            .with_context(|| "Could not build HyperLogLog logic")?;
114
115        let array_0 = SliceEstimatorArray::new(logic.clone(), graph.num_nodes());
116        let array_1 = SliceEstimatorArray::new(logic, graph.num_nodes());
117
118        Ok(Self {
119            graph,
120            transpose: transposed,
121            cumul_outdegree: cumul_outdeg,
122            do_sum_of_dists: false,
123            do_sum_of_inv_dists: false,
124            discount_functions: Vec::new(),
125            arc_granularity: Self::DEFAULT_GRANULARITY,
126            weights,
127            array_0,
128            array_1,
129            _marker: std::marker::PhantomData,
130        })
131    }
132}
133
134impl<
135    'a,
136    D: for<'b> Succ<Input = usize, Output<'b> = usize>,
137    G: RandomAccessGraph + Sync,
138    L: MergeEstimationLogic<Item = G::Label> + PartialEq,
139    A: EstimatorArrayMut<L>,
140> HyperBallBuilder<'a, G, G, D, L, A>
141{
142    /// Creates a new builder with default parameters.
143    ///
144    /// # Arguments
145    /// * `graph`: the graph to analyze.
146    /// * `cumul_outdeg`: the outdegree cumulative function of the graph.
147    /// * `array_0`: a first array of estimators.
148    /// * `array_1`: A second array of estimators of the same length and with the same logic of
149    ///   `array_0`.
150    pub fn new(graph: &'a G, cumul_outdeg: &'a D, array_0: A, array_1: A) -> Self {
151        assert!(array_0.logic() == array_1.logic(), "Incompatible logic");
152        assert_eq!(
153            graph.num_nodes(),
154            array_0.len(),
155            "array_0 should have length {}. Got {}",
156            graph.num_nodes(),
157            array_0.len()
158        );
159        assert_eq!(
160            graph.num_nodes(),
161            array_1.len(),
162            "array_1 should have length {}. Got {}",
163            graph.num_nodes(),
164            array_1.len()
165        );
166        Self {
167            graph,
168            transpose: None,
169            cumul_outdegree: cumul_outdeg,
170            do_sum_of_dists: false,
171            do_sum_of_inv_dists: false,
172            discount_functions: Vec::new(),
173            arc_granularity: Self::DEFAULT_GRANULARITY,
174            weights: None,
175            array_0,
176            array_1,
177            _marker: std::marker::PhantomData,
178        }
179    }
180}
181
182impl<
183    'a,
184    G1: RandomAccessGraph + Sync,
185    G2: RandomAccessGraph + Sync,
186    D: for<'b> Succ<Input = usize, Output<'b> = usize>,
187    L: MergeEstimationLogic<Item = G1::Label>,
188    A: EstimatorArrayMut<L>,
189> HyperBallBuilder<'a, G1, G2, D, L, A>
190{
191    const DEFAULT_GRANULARITY: usize = 16 * 1024;
192
193    /// Creates a new builder with default parameters using also the transpose.
194    ///
195    /// * `graph`: the graph to analyze.
196    /// * `transpose`: optionally, the transpose of `graph`. If [`None`], no
197    ///   systolic iterations will be performed by the resulting [`HyperBall`].
198    /// * `cumul_outdeg`: the outdegree cumulative function of the graph.
199    /// * `array_0`: a first array of estimators.
200    /// * `array_1`: A second array of estimators of the same length and with the same logic of
201    ///   `array_0`.
202    pub fn with_transpose(
203        graph: &'a G1,
204        transpose: &'a G2,
205        cumul_outdeg: &'a D,
206        array_0: A,
207        array_1: A,
208    ) -> Self {
209        assert_eq!(
210            graph.num_nodes(),
211            array_0.len(),
212            "array_0 should have have len {}. Got {}",
213            graph.num_nodes(),
214            array_0.len()
215        );
216        assert_eq!(
217            graph.num_nodes(),
218            array_1.len(),
219            "array_1 should have have len {}. Got {}",
220            graph.num_nodes(),
221            array_1.len()
222        );
223        assert_eq!(
224            transpose.num_nodes(),
225            graph.num_nodes(),
226            "the transpose should have same number of nodes of the graph ({}). Got {}.",
227            graph.num_nodes(),
228            transpose.num_nodes()
229        );
230        assert_eq!(
231            transpose.num_arcs(),
232            graph.num_arcs(),
233            "the transpose should have same number of nodes of the graph ({}). Got {}.",
234            graph.num_arcs(),
235            transpose.num_arcs()
236        );
237        /* TODO debug_assert!(
238            check_transposed(graph, transpose),
239            "the transpose should be the transpose of the graph"
240        );*/
241        Self {
242            graph,
243            transpose: Some(transpose),
244            cumul_outdegree: cumul_outdeg,
245            do_sum_of_dists: false,
246            do_sum_of_inv_dists: false,
247            discount_functions: Vec::new(),
248            arc_granularity: Self::DEFAULT_GRANULARITY,
249            weights: None,
250            array_0,
251            array_1,
252            _marker: std::marker::PhantomData,
253        }
254    }
255
256    /// Sets whether to compute the sum of distances.
257    pub fn sum_of_distances(mut self, do_sum_of_distances: bool) -> Self {
258        self.do_sum_of_dists = do_sum_of_distances;
259        self
260    }
261
262    /// Sets whether to compute the sum of inverse distances.
263    pub fn sum_of_inverse_distances(mut self, do_sum_of_inverse_distances: bool) -> Self {
264        self.do_sum_of_inv_dists = do_sum_of_inverse_distances;
265        self
266    }
267
268    /// Sets the base granularity used in the parallel phases of the iterations.
269    pub fn granularity(mut self, granularity: Granularity) -> Self {
270        self.arc_granularity =
271            granularity.arc_granularity(self.graph.num_nodes(), Some(self.graph.num_arcs()));
272        self
273    }
274
275    /// Sets optional weights for the nodes of the graph.
276    ///
277    /// # Arguments
278    /// * `weights`: weights to use for the nodes. If [`None`], every node is
279    ///   assumed to be of weight equal to 1.
280    pub fn weights(mut self, weights: Option<&'a [usize]>) -> Self {
281        if let Some(w) = weights {
282            assert_eq!(w.len(), self.graph.num_nodes());
283        }
284        self.weights = weights;
285        self
286    }
287
288    /// Adds a new discount function whose sum over all spheres should be
289    /// computed.
290    pub fn discount_function(
291        mut self,
292        discount_function: impl Fn(usize) -> f64 + Send + Sync + 'a,
293    ) -> Self {
294        self.discount_functions.push(Box::new(discount_function));
295        self
296    }
297
298    /// Removes all custom discount functions.
299    pub fn no_discount_function(mut self) -> Self {
300        self.discount_functions.clear();
301        self
302    }
303}
304
305impl<
306    'a,
307    G1: RandomAccessGraph + Sync,
308    G2: RandomAccessGraph + Sync,
309    D: for<'b> Succ<Input = usize, Output<'b> = usize>,
310    L: MergeEstimationLogic<Item = G1::Label> + Sync + std::fmt::Display,
311    A: EstimatorArrayMut<L>,
312> HyperBallBuilder<'a, G1, G2, D, L, A>
313{
314    /// Builds a [`HyperBall`] instance.
315    ///
316    /// # Arguments
317    ///
318    /// * `pl`: A progress logger.
319    pub fn build(self, pl: &mut impl ConcurrentProgressLog) -> HyperBall<'a, G1, G2, D, L, A> {
320        let num_nodes = self.graph.num_nodes();
321
322        let sum_of_distances = if self.do_sum_of_dists {
323            pl.debug(format_args!("Initializing sum of distances"));
324            Some(vec![0.0; num_nodes])
325        } else {
326            pl.debug(format_args!("Skipping sum of distances"));
327            None
328        };
329        let sum_of_inverse_distances = if self.do_sum_of_inv_dists {
330            pl.debug(format_args!("Initializing sum of inverse distances"));
331            Some(vec![0.0; num_nodes])
332        } else {
333            pl.debug(format_args!("Skipping sum of inverse distances"));
334            None
335        };
336
337        let mut discounted_centralities = Vec::new();
338        pl.debug(format_args!(
339            "Initializing {} discount functions",
340            self.discount_functions.len()
341        ));
342        for _ in self.discount_functions.iter() {
343            discounted_centralities.push(vec![0.0; num_nodes]);
344        }
345
346        pl.info(format_args!("Initializing bit vectors"));
347        let estimator_modified = AtomicBitVec::new(num_nodes);
348        let modified_result_estimator = AtomicBitVec::new(num_nodes);
349        let must_be_checked = AtomicBitVec::new(num_nodes);
350        let next_must_be_checked = AtomicBitVec::new(num_nodes);
351
352        pl.info(format_args!(
353            "Using estimation logic: {}",
354            self.array_0.logic()
355        ));
356
357        HyperBall {
358            graph: self.graph,
359            transposed: self.transpose,
360            weight: self.weights,
361            granularity: self.arc_granularity,
362            curr_state: self.array_0,
363            next_state: self.array_1,
364            completed: false,
365            neighborhood_function: Vec::new(),
366            last: 0.0,
367            relative_increment: 0.0,
368            sum_of_dists: sum_of_distances,
369            sum_of_inv_dists: sum_of_inverse_distances,
370            discounted_centralities,
371            iteration_context: IterationContext {
372                cumul_outdeg: self.cumul_outdegree,
373                iteration: 0,
374                current_nf: Mutex::new(0.0),
375                arc_granularity: 0,
376                node_cursor: AtomicUsize::new(0).into(),
377                arc_cursor: Mutex::new((0, 0)),
378                visited_arcs: AtomicU64::new(0).into(),
379                modified_estimators: AtomicU64::new(0).into(),
380                systolic: false,
381                local: false,
382                pre_local: false,
383                local_checklist: Vec::new(),
384                local_next_must_be_checked: Mutex::new(Vec::new()),
385                must_be_checked,
386                next_must_be_checked,
387                curr_modified: estimator_modified,
388                next_modified: modified_result_estimator,
389                discount_functions: self.discount_functions,
390            },
391            _marker: std::marker::PhantomData,
392        }
393    }
394}
395
396/// Data used by [`parallel_task`](HyperBall::parallel_task).
397///
398/// These variables are used by the threads running
399/// [`parallel_task`](HyperBall::parallel_task). They must be isolated in a
400/// field because we need to be able to borrow exclusively
401/// [`HyperBall::next_state`], while sharing references to the data contained
402/// here and to the [`HyperBall::curr_state`].
403struct IterationContext<'a, G1: SequentialLabeling, D> {
404    /// The cumulative list of outdegrees.
405    cumul_outdeg: &'a D,
406    /// The number of the current iteration.
407    iteration: usize,
408    /// The value of the neighborhood function computed during the current iteration.
409    current_nf: Mutex<f64>,
410    /// The arc granularity: each task will try to process at least this number
411    /// of arcs.
412    arc_granularity: usize,
413    /// A cursor scanning the nodes to process during local computations.
414    node_cursor: CachePadded<AtomicUsize>,
415    /// A cursor scanning the nodes and arcs to process during non-local
416    /// computations.
417    arc_cursor: Mutex<(usize, usize)>,
418    /// The number of arcs visited during the current iteration.
419    visited_arcs: CachePadded<AtomicU64>,
420    /// The number of estimators modified during the current iteration.
421    modified_estimators: CachePadded<AtomicU64>,
422    /// `true` if we started a systolic computation.
423    systolic: bool,
424    /// `true` if we started a local computation.
425    local: bool,
426    /// `true` if we are preparing a local computation (systolic is `true` and less than 1% nodes were modified).
427    pre_local: bool,
428    /// If [`local`](Self::local) is `true`, the sorted list of nodes that
429    /// should be scanned.
430    local_checklist: Vec<G1::Label>,
431    /// If [`pre_local`](Self::pre_local) is `true`, the set of nodes that
432    /// should be scanned on the next iteration.
433    local_next_must_be_checked: Mutex<Vec<G1::Label>>,
434    /// Used in systolic iterations to keep track of nodes to check.
435    must_be_checked: AtomicBitVec,
436    /// Used in systolic iterations to keep track of nodes to check in the next
437    /// iteration.
438    next_must_be_checked: AtomicBitVec,
439    /// Whether each estimator has been modified during the previous iteration.
440    curr_modified: AtomicBitVec,
441    /// Whether each estimator has been modified during the current iteration.
442    next_modified: AtomicBitVec,
443    /// Custom discount functions whose sum should be computed.
444    discount_functions: Vec<Box<dyn Fn(usize) -> f64 + Send + Sync + 'a>>,
445}
446
447impl<G1: SequentialLabeling, D> IterationContext<'_, G1, D> {
448    /// Resets the iteration context
449    fn reset(&mut self, granularity: usize) {
450        self.arc_granularity = granularity;
451        self.node_cursor.store(0, Ordering::Relaxed);
452        *self.arc_cursor.lock().unwrap() = (0, 0);
453        self.visited_arcs.store(0, Ordering::Relaxed);
454        self.modified_estimators.store(0, Ordering::Relaxed);
455    }
456}
457
458/// An algorithm that computes an approximation of the neighborhood function,
459/// of the size of the reachable sets, and of (discounted) positive geometric
460/// centralities of a graph.
461pub struct HyperBall<
462    'a,
463    G1: RandomAccessGraph + Sync,
464    G2: RandomAccessGraph + Sync,
465    D: for<'b> Succ<Input = usize, Output<'b> = usize>,
466    L: MergeEstimationLogic<Item = G1::Label> + Sync,
467    A: EstimatorArrayMut<L>,
468> {
469    /// The graph to analyze.
470    graph: &'a G1,
471    /// The transpose of [`Self::graph`], if any.
472    transposed: Option<&'a G2>,
473    /// An optional slice of nonnegative node weights.
474    weight: Option<&'a [usize]>,
475    /// The base number of nodes per task. TODO.
476    granularity: usize,
477    /// The previous state.
478    curr_state: A,
479    /// The next state.
480    next_state: A,
481    /// `true` if the computation is over.
482    completed: bool,
483    /// The neighborhood function.
484    neighborhood_function: Vec<f64>,
485    /// The value computed by the last iteration.
486    last: f64,
487    /// The relative increment of the neighborhood function for the last
488    /// iteration.
489    relative_increment: f64,
490    /// The sum of the distances from every given node, if requested.
491    sum_of_dists: Option<Vec<f32>>,
492    /// The sum of inverse distances from each given node, if requested.
493    sum_of_inv_dists: Option<Vec<f32>>,
494    /// The overall discount centrality for every discount function.
495    discounted_centralities: Vec<Vec<f32>>,
496    /// Context used in a single iteration.
497    iteration_context: IterationContext<'a, G1, D>,
498    _marker: std::marker::PhantomData<L>,
499}
500
501impl<
502    G1: RandomAccessGraph + Sync,
503    G2: RandomAccessGraph + Sync,
504    D: for<'b> Succ<Input = usize, Output<'b> = usize> + Sync,
505    L: MergeEstimationLogic<Item = usize> + Sync,
506    A: EstimatorArrayMut<L> + Sync + AsSyncArray<L>,
507> HyperBall<'_, G1, G2, D, L, A>
508where
509    L::Backend: PartialEq,
510{
511    /// Runs HyperBall.
512    ///
513    /// # Arguments
514    ///
515    /// * `upper_bound`: an upper bound to the number of iterations.
516    ///
517    /// * `threshold`: a value that will be used to stop the computation by
518    ///   relative increment if the neighborhood function is being computed. If
519    ///   [`None`] the computation will stop when no estimators are modified.
520    ///
521    /// * `pl`: A progress logger.
522    pub fn run(
523        &mut self,
524        upper_bound: usize,
525        threshold: Option<f64>,
526        rng: impl rand::Rng,
527        pl: &mut impl ConcurrentProgressLog,
528    ) -> Result<()> {
529        let upper_bound = std::cmp::min(upper_bound, self.graph.num_nodes());
530
531        self.init(rng, pl)
532            .with_context(|| "Could not initialize estimator")?;
533
534        pl.item_name("iteration");
535        pl.expected_updates(None);
536        pl.start(format!(
537            "Running HyperBall for a maximum of {} iterations and a threshold of {:?}",
538            upper_bound, threshold
539        ));
540
541        for i in 0..upper_bound {
542            self.iterate(pl)
543                .with_context(|| format!("Could not perform iteration {}", i + 1))?;
544
545            pl.update_and_display();
546
547            if self
548                .iteration_context
549                .modified_estimators
550                .load(Ordering::Relaxed)
551                == 0
552            {
553                pl.info(format_args!(
554                    "Terminating HyperBall after {} iteration(s) by stabilization",
555                    i + 1
556                ));
557                break;
558            }
559
560            if let Some(t) = threshold {
561                if i > 3 && self.relative_increment < (1.0 + t) {
562                    pl.info(format_args!("Terminating HyperBall after {} iteration(s) by relative bound on the neighborhood function", i + 1));
563                    break;
564                }
565            }
566        }
567
568        pl.done();
569
570        Ok(())
571    }
572
573    /// Runs HyperBall until no estimators are modified.
574    ///
575    /// # Arguments
576    ///
577    /// * `upper_bound`: an upper bound to the number of iterations.
578    ///
579    /// * `pl`: A progress logger.
580    #[inline(always)]
581    pub fn run_until_stable(
582        &mut self,
583        upper_bound: usize,
584        rng: impl rand::Rng,
585        pl: &mut impl ConcurrentProgressLog,
586    ) -> Result<()> {
587        self.run(upper_bound, None, rng, pl)
588            .with_context(|| "Could not complete run_until_stable")
589    }
590
591    /// Runs HyperBall until no estimators are modified with no upper bound on the
592    /// number of iterations.
593    ///
594    /// # Arguments
595    ///
596    /// * `pl`: A progress logger.
597    #[inline(always)]
598    pub fn run_until_done(
599        &mut self,
600        rng: impl rand::Rng,
601        pl: &mut impl ConcurrentProgressLog,
602    ) -> Result<()> {
603        self.run_until_stable(usize::MAX, rng, pl)
604            .with_context(|| "Could not complete run_until_done")
605    }
606
607    #[inline(always)]
608    fn ensure_iteration(&self) -> Result<()> {
609        ensure!(
610            self.iteration_context.iteration > 0,
611            "HyperBall was not run. Please call HyperBall::run before accessing computed fields."
612        );
613        Ok(())
614    }
615
616    /// Returns the neighborhood function computed by this instance.
617    pub fn neighborhood_function(&self) -> Result<&[f64]> {
618        self.ensure_iteration()?;
619        Ok(&self.neighborhood_function)
620    }
621
622    /// Returns the sum of distances computed by this instance if requested.
623    pub fn sum_of_distances(&self) -> Result<&[f32]> {
624        self.ensure_iteration()?;
625        if let Some(distances) = &self.sum_of_dists {
626            // TODO these are COPIES
627            Ok(distances)
628        } else {
629            bail!(
630                "Sum of distances were not requested: use builder.with_sum_of_distances(true) while building HyperBall to compute them"
631            )
632        }
633    }
634
635    /// Returns the harmonic centralities (sum of inverse distances) computed by this instance if requested.
636    pub fn harmonic_centralities(&self) -> Result<&[f32]> {
637        self.ensure_iteration()?;
638        if let Some(distances) = &self.sum_of_inv_dists {
639            Ok(distances)
640        } else {
641            bail!(
642                "Sum of inverse distances were not requested: use builder.with_sum_of_inverse_distances(true) while building HyperBall to compute them"
643            )
644        }
645    }
646
647    /// Returns the discounted centralities of the specified index computed by this instance.
648    ///
649    /// # Arguments
650    /// * `index`: the index of the requested discounted centrality.
651    pub fn discounted_centrality(&self, index: usize) -> Result<&[f32]> {
652        self.ensure_iteration()?;
653        let d = self.discounted_centralities.get(index);
654        if let Some(distances) = d {
655            Ok(distances)
656        } else {
657            bail!("Discount centrality of index {} does not exist", index)
658        }
659    }
660
661    /// Computes and returns the closeness centralities from the sum of distances computed by this instance.
662    pub fn closeness_centrality(&self) -> Result<Box<[f32]>> {
663        self.ensure_iteration()?;
664        if let Some(distances) = &self.sum_of_dists {
665            Ok(distances
666                .iter()
667                .map(|&d| if d == 0.0 { 0.0 } else { d.recip() })
668                .collect())
669        } else {
670            bail!(
671                "Sum of distances were not requested: use builder.with_sum_of_distances(true) while building HyperBall to compute closeness centrality"
672            )
673        }
674    }
675
676    /// Computes and returns the lin centralities from the sum of distances computed by this instance.
677    ///
678    /// Note that lin's index for isolated nodes is by (our) definition one (it's smaller than any other node).
679    pub fn lin_centrality(&self) -> Result<Box<[f32]>> {
680        self.ensure_iteration()?;
681        if let Some(distances) = &self.sum_of_dists {
682            let logic = self.curr_state.logic();
683            Ok(distances
684                .iter()
685                .enumerate()
686                .map(|(node, &d)| {
687                    if d == 0.0 {
688                        1.0
689                    } else {
690                        let count = logic.estimate(self.curr_state.get_backend(node));
691                        (count * count / d as f64) as f32
692                    }
693                })
694                .collect())
695        } else {
696            bail!(
697                "Sum of distances were not requested: use builder.with_sum_of_distances(true) while building HyperBall to compute lin centrality"
698            )
699        }
700    }
701
702    /// Computes and returns the Nieminen centralities from the sum of distances computed by this instance.
703    pub fn nieminen_centrality(&self) -> Result<Box<[f32]>> {
704        self.ensure_iteration()?;
705        if let Some(distances) = &self.sum_of_dists {
706            let logic = self.curr_state.logic();
707            Ok(distances
708                .iter()
709                .enumerate()
710                .map(|(node, &d)| {
711                    let count = logic.estimate(self.curr_state.get_backend(node));
712                    ((count * count) - d as f64) as f32
713                })
714                .collect())
715        } else {
716            bail!(
717                "Sum of distances were not requested: use builder.with_sum_of_distances(true) while building HyperBall to compute lin centrality"
718            )
719        }
720    }
721
722    /// Reads from the internal estimator array and estimates the number of nodes
723    /// reachable from the specified node.
724    ///
725    /// # Arguments
726    /// * `node`: the index of the node to compute reachable nodes from.
727    pub fn reachable_nodes_from(&self, node: usize) -> Result<f64> {
728        self.ensure_iteration()?;
729        Ok(self
730            .curr_state
731            .logic()
732            .estimate(self.curr_state.get_backend(node)))
733    }
734
735    /// Reads from the internal estimator array and estimates the number of nodes reachable
736    /// from every node of the graph.
737    ///
738    /// `hyperball.reachable_nodes().unwrap()[i]` is equal to `hyperball.reachable_nodes_from(i).unwrap()`.
739    pub fn reachable_nodes(&self) -> Result<Box<[f32]>> {
740        self.ensure_iteration()?;
741        let logic = self.curr_state.logic();
742        Ok((0..self.graph.num_nodes())
743            .map(|n| logic.estimate(self.curr_state.get_backend(n)) as f32)
744            .collect())
745    }
746}
747
748impl<
749    G1: RandomAccessGraph + Sync,
750    G2: RandomAccessGraph + Sync,
751    D: for<'b> Succ<Input = usize, Output<'b> = usize> + Sync,
752    L: EstimationLogic<Item = usize> + MergeEstimationLogic + Sync,
753    A: EstimatorArrayMut<L> + Sync + AsSyncArray<L>,
754> HyperBall<'_, G1, G2, D, L, A>
755where
756    L::Backend: PartialEq,
757{
758    /// Performs a new iteration of HyperBall.
759    ///
760    /// # Arguments
761    /// * `pl`: A progress logger.
762    fn iterate(&mut self, pl: &mut impl ConcurrentProgressLog) -> Result<()> {
763        let ic = &mut self.iteration_context;
764
765        pl.info(format_args!("Performing iteration {}", ic.iteration + 1));
766
767        // Alias the number of modified estimators, nodes and arcs
768        let num_nodes = self.graph.num_nodes() as u64;
769        let num_arcs = self.graph.num_arcs();
770        let modified_estimators = ic.modified_estimators.load(Ordering::Relaxed);
771
772        // Let us record whether the previous computation was systolic or local
773        let prev_was_systolic = ic.systolic;
774        let prev_was_local = ic.local;
775
776        // If less than one fourth of the nodes have been modified, and we have
777        // the transpose, it is time to pass to a systolic computation
778        ic.systolic =
779            self.transposed.is_some() && ic.iteration > 0 && modified_estimators < num_nodes / 4;
780
781        // Non-systolic computations add up the values of all estimators.
782        //
783        // Systolic computations modify the last value by compensating for each
784        // modified estimators.
785        *ic.current_nf.lock().unwrap() = if ic.systolic { self.last } else { 0.0 };
786
787        // If we completed the last iteration in pre-local mode, we MUST run in
788        // local mode
789        ic.local = ic.pre_local;
790
791        // We run in pre-local mode if we are systolic and few nodes where
792        // modified.
793        ic.pre_local =
794            ic.systolic && modified_estimators < (num_nodes * num_nodes) / (num_arcs * 10);
795
796        if ic.systolic {
797            pl.info(format_args!(
798                "Starting systolic iteration (local: {}, pre_local: {})",
799                ic.local, ic.pre_local
800            ));
801        } else {
802            pl.info(format_args!("Starting standard iteration"));
803        }
804
805        if prev_was_local {
806            for &node in ic.local_checklist.iter() {
807                ic.next_modified.set(node, false, Ordering::Relaxed);
808            }
809        } else {
810            ic.next_modified.fill(false, Ordering::Relaxed);
811        }
812
813        if ic.local {
814            // In case of a local computation, we convert the set of
815            // must-be-checked for the next iteration into a check list
816            rayon::join(
817                || ic.local_checklist.clear(),
818                || {
819                    let mut local_next_must_be_checked =
820                        ic.local_next_must_be_checked.lock().unwrap();
821                    local_next_must_be_checked.par_sort_unstable();
822                    local_next_must_be_checked.dedup();
823                },
824            );
825            std::mem::swap(
826                &mut ic.local_checklist,
827                &mut ic.local_next_must_be_checked.lock().unwrap(),
828            );
829        } else if ic.systolic {
830            rayon::join(
831                || {
832                    // Systolic, non-local computations store the could-be-modified set implicitly into Self::next_must_be_checked.
833                    ic.next_must_be_checked.fill(false, Ordering::Relaxed);
834                },
835                || {
836                    // If the previous computation wasn't systolic, we must assume that all registers could have changed.
837                    if !prev_was_systolic {
838                        ic.must_be_checked.fill(true, Ordering::Relaxed);
839                    }
840                },
841            );
842        }
843
844        let mut granularity = ic.arc_granularity;
845        let num_threads = rayon::current_num_threads();
846
847        if num_threads > 1 && !ic.local {
848            if ic.iteration > 0 {
849                granularity = f64::min(
850                    std::cmp::max(1, num_nodes as usize / num_threads) as _,
851                    granularity as f64
852                        * (num_nodes as f64 / std::cmp::max(1, modified_estimators) as f64),
853                ) as usize;
854            }
855            pl.info(format_args!(
856                "Adaptive granularity for this iteration: {}",
857                granularity
858            ));
859        }
860
861        ic.reset(granularity);
862
863        pl.item_name("arc");
864        pl.expected_updates(if ic.local { None } else { Some(num_arcs as _) });
865        pl.start("Starting parallel execution");
866        {
867            let next_state_sync = self.next_state.as_sync_array();
868            let sum_of_dists = self.sum_of_dists.as_mut().map(|x| x.as_sync_slice());
869            let sum_of_inv_dists = self.sum_of_inv_dists.as_mut().map(|x| x.as_sync_slice());
870
871            let discounted_centralities = &self
872                .discounted_centralities
873                .iter_mut()
874                .map(|s| s.as_sync_slice())
875                .collect::<Vec<_>>();
876            rayon::broadcast(|c| {
877                Self::parallel_task(
878                    self.graph,
879                    self.transposed,
880                    &self.curr_state,
881                    &next_state_sync,
882                    ic,
883                    sum_of_dists,
884                    sum_of_inv_dists,
885                    discounted_centralities,
886                    c,
887                )
888            });
889        }
890
891        pl.done_with_count(ic.visited_arcs.load(Ordering::Relaxed) as usize);
892        let modified_estimators = ic.modified_estimators.load(Ordering::Relaxed);
893
894        pl.info(format_args!(
895            "Modified estimators: {}/{} ({:.3}%)",
896            modified_estimators,
897            self.graph.num_nodes(),
898            (modified_estimators as f64 / self.graph.num_nodes() as f64) * 100.0
899        ));
900
901        std::mem::swap(&mut self.curr_state, &mut self.next_state);
902        std::mem::swap(&mut ic.curr_modified, &mut ic.next_modified);
903
904        if ic.systolic {
905            std::mem::swap(&mut ic.must_be_checked, &mut ic.next_must_be_checked);
906        }
907
908        let mut current_nf_mut = ic.current_nf.lock().unwrap();
909        self.last = *current_nf_mut;
910        // We enforce monotonicity--non-monotonicity can only be caused by
911        // approximation errors
912        let &last_output = self
913            .neighborhood_function
914            .as_slice()
915            .last()
916            .expect("Should always have at least 1 element");
917        if *current_nf_mut < last_output {
918            *current_nf_mut = last_output;
919        }
920        self.relative_increment = *current_nf_mut / last_output;
921
922        pl.info(format_args!(
923            "Pairs: {} ({}%)",
924            *current_nf_mut,
925            (*current_nf_mut * 100.0) / (num_nodes * num_nodes) as f64
926        ));
927        pl.info(format_args!(
928            "Absolute increment: {}",
929            *current_nf_mut - last_output
930        ));
931        pl.info(format_args!(
932            "Relative increment: {}",
933            self.relative_increment
934        ));
935
936        self.neighborhood_function.push(*current_nf_mut);
937
938        ic.iteration += 1;
939
940        Ok(())
941    }
942
943    /// The parallel operations to be performed each iteration.
944    ///
945    /// # Arguments:
946    /// * `graph`: the graph to analyze.
947    /// * `transpose`: optionally, the transpose of `graph`. If [`None`], no
948    ///   systolic iterations will be performed.
949    /// * `curr_state`: the current state of the estimators.
950    /// * `next_state`: the next state of the estimators (to be computed).
951    /// * `ic`: the iteration context.
952    #[allow(clippy::too_many_arguments)]
953    fn parallel_task(
954        graph: &(impl RandomAccessGraph + Sync),
955        transpose: Option<&(impl RandomAccessGraph + Sync)>,
956        curr_state: &impl EstimatorArray<L>,
957        next_state: &impl SyncEstimatorArray<L>,
958        ic: &IterationContext<'_, G1, D>,
959        sum_of_dists: Option<&[SyncCell<f32>]>,
960        sum_of_inv_dists: Option<&[SyncCell<f32>]>,
961        discounted_centralities: &[&[SyncCell<f32>]],
962        _broadcast_context: rayon::BroadcastContext,
963    ) {
964        let node_granularity = ic.arc_granularity;
965        let arc_granularity = ((graph.num_arcs() as f64 * node_granularity as f64)
966            / graph.num_nodes() as f64)
967            .ceil() as usize;
968        let do_centrality = sum_of_dists.is_some()
969            || sum_of_inv_dists.is_some()
970            || !ic.discount_functions.is_empty();
971        let node_upper_limit = if ic.local {
972            ic.local_checklist.len()
973        } else {
974            graph.num_nodes()
975        };
976        let mut visited_arcs = 0;
977        let mut modified_estimators = 0;
978        let arc_upper_limit = graph.num_arcs();
979
980        // During standard iterations, cumulates the neighborhood function for the nodes scanned
981        // by this thread. During systolic iterations, cumulates the *increase* of the
982        // neighborhood function for the nodes scanned by this thread.
983        let mut neighborhood_function_delta = KahanSum::new_with_value(0.0);
984        let mut helper = curr_state.logic().new_helper();
985        let logic = curr_state.logic();
986        let mut next_estimator = logic.new_estimator();
987
988        loop {
989            // Get work
990            let (start, end) = if ic.local {
991                let start = std::cmp::min(
992                    ic.node_cursor.fetch_add(1, Ordering::Relaxed),
993                    node_upper_limit,
994                );
995                let end = std::cmp::min(start + 1, node_upper_limit);
996                (start, end)
997            } else {
998                let mut arc_balanced_cursor = ic.arc_cursor.lock().unwrap();
999                let (mut next_node, mut next_arc) = *arc_balanced_cursor;
1000                if next_node >= node_upper_limit {
1001                    (node_upper_limit, node_upper_limit)
1002                } else {
1003                    let start = next_node;
1004                    let target = next_arc + arc_granularity;
1005                    if target as u64 >= arc_upper_limit {
1006                        next_node = node_upper_limit;
1007                    } else {
1008                        (next_node, next_arc) = ic.cumul_outdeg.succ(target).unwrap();
1009                    }
1010                    let end = next_node;
1011                    *arc_balanced_cursor = (next_node, next_arc);
1012                    (start, end)
1013                }
1014            };
1015
1016            if start == node_upper_limit {
1017                break;
1018            }
1019
1020            // Do work
1021            for i in start..end {
1022                let node = if ic.local { ic.local_checklist[i] } else { i };
1023
1024                let prev_estimator = curr_state.get_backend(node);
1025
1026                // The three cases in which we enumerate successors:
1027                // 1) A non-systolic computation (we don't know anything, so we enumerate).
1028                // 2) A systolic, local computation (the node is by definition to be checked, as it comes from the local check list).
1029                // 3) A systolic, non-local computation in which the node should be checked.
1030                if !ic.systolic || ic.local || ic.must_be_checked[node] {
1031                    next_estimator.set(prev_estimator);
1032                    let mut modified = false;
1033                    for succ in graph.successors(node) {
1034                        if succ != node && ic.curr_modified[succ] {
1035                            visited_arcs += 1;
1036                            if !modified {
1037                                modified = true;
1038                            }
1039                            logic.merge_with_helper(
1040                                next_estimator.as_mut(),
1041                                curr_state.get_backend(succ),
1042                                &mut helper,
1043                            );
1044                        }
1045                    }
1046
1047                    let mut post = f64::NAN;
1048                    let estimator_modified = modified && next_estimator.as_ref() != prev_estimator;
1049
1050                    // We need the estimator value only if the iteration is standard (as we're going to
1051                    // compute the neighborhood function cumulating actual values, and not deltas) or
1052                    // if the estimator was actually modified (as we're going to cumulate the neighborhood
1053                    // function delta, or at least some centrality).
1054                    if !ic.systolic || estimator_modified {
1055                        post = logic.estimate(next_estimator.as_ref())
1056                    }
1057                    if !ic.systolic {
1058                        neighborhood_function_delta += post;
1059                    }
1060
1061                    if estimator_modified && (ic.systolic || do_centrality) {
1062                        let pre = logic.estimate(prev_estimator);
1063                        if ic.systolic {
1064                            neighborhood_function_delta += -pre;
1065                            neighborhood_function_delta += post;
1066                        }
1067
1068                        if do_centrality {
1069                            let delta = post - pre;
1070                            // Note that this code is executed only for distances > 0
1071                            if delta > 0.0 {
1072                                if let Some(distances) = sum_of_dists {
1073                                    let new_value = delta * (ic.iteration + 1) as f64;
1074                                    unsafe {
1075                                        distances[node]
1076                                            .set((distances[node].get() as f64 + new_value) as f32)
1077                                    };
1078                                }
1079                                if let Some(distances) = sum_of_inv_dists {
1080                                    let new_value = delta / (ic.iteration + 1) as f64;
1081                                    unsafe {
1082                                        distances[node]
1083                                            .set((distances[node].get() as f64 + new_value) as f32)
1084                                    };
1085                                }
1086                                for (func, distances) in ic
1087                                    .discount_functions
1088                                    .iter()
1089                                    .zip(discounted_centralities.iter())
1090                                {
1091                                    let new_value = delta * func(ic.iteration + 1);
1092                                    unsafe {
1093                                        distances[node]
1094                                            .set((distances[node].get() as f64 + new_value) as f32)
1095                                    };
1096                                }
1097                            }
1098                        }
1099                    }
1100
1101                    if estimator_modified {
1102                        // We keep track of modified estimators in the result. Note that we must
1103                        // add the current node to the must-be-checked set for the next
1104                        // local iteration if it is modified, as it might need a copy to
1105                        // the result array at the next iteration.
1106                        if ic.pre_local {
1107                            ic.local_next_must_be_checked.lock().unwrap().push(node);
1108                        }
1109                        ic.next_modified.set(node, true, Ordering::Relaxed);
1110
1111                        if ic.systolic {
1112                            debug_assert!(transpose.is_some());
1113                            // In systolic computations we must keep track of
1114                            // which estimators must be checked on the next
1115                            // iteration. If we are preparing a local
1116                            // computation, we do this explicitly, by adding the
1117                            // predecessors of the current node to a set.
1118                            // Otherwise, we do this implicitly, by setting the
1119                            // corresponding entry in an array.
1120
1121                            // SAFETY: ic.systolic is true, so transpose is Some
1122                            let transpose = unsafe { transpose.unwrap_unchecked() };
1123                            if ic.pre_local {
1124                                let mut local_next_must_be_checked =
1125                                    ic.local_next_must_be_checked.lock().unwrap();
1126                                for succ in transpose.successors(node) {
1127                                    local_next_must_be_checked.push(succ);
1128                                }
1129                            } else {
1130                                for succ in transpose.successors(node) {
1131                                    ic.next_must_be_checked.set(succ, true, Ordering::Relaxed);
1132                                }
1133                            }
1134                        }
1135
1136                        modified_estimators += 1;
1137                    }
1138
1139                    unsafe {
1140                        next_state.set(node, next_estimator.as_ref());
1141                    }
1142                } else {
1143                    // Even if we cannot possibly have changed our value, still our copy
1144                    // in the result vector might need to be updated because it does not
1145                    // reflect our current value.
1146                    if ic.curr_modified[node] {
1147                        unsafe {
1148                            next_state.set(node, prev_estimator);
1149                        }
1150                    }
1151                }
1152            }
1153        }
1154
1155        *ic.current_nf.lock().unwrap() += neighborhood_function_delta.sum();
1156        ic.visited_arcs.fetch_add(visited_arcs, Ordering::Relaxed);
1157        ic.modified_estimators
1158            .fetch_add(modified_estimators, Ordering::Relaxed);
1159    }
1160
1161    /// Initializes HyperBall.
1162    fn init(&mut self, mut rng: impl rand::Rng, pl: &mut impl ConcurrentProgressLog) -> Result<()> {
1163        pl.start("Initializing estimators");
1164        pl.info(format_args!("Clearing all registers"));
1165
1166        self.curr_state.clear();
1167        self.next_state.clear();
1168
1169        pl.info(format_args!("Initializing registers"));
1170        if let Some(w) = &self.weight {
1171            pl.info(format_args!("Loading weights"));
1172            for (i, &node_weight) in w.iter().enumerate() {
1173                let mut estimator = self.curr_state.get_estimator_mut(i);
1174                for _ in 0..node_weight {
1175                    estimator.add(&(rng.random::<u64>() as usize));
1176                }
1177            }
1178        } else {
1179            (0..self.graph.num_nodes()).for_each(|i| {
1180                self.curr_state.get_estimator_mut(i).add(i);
1181            });
1182        }
1183
1184        self.completed = false;
1185
1186        let ic = &mut self.iteration_context;
1187        ic.iteration = 0;
1188        ic.systolic = false;
1189        ic.local = false;
1190        ic.pre_local = false;
1191        ic.reset(self.granularity);
1192
1193        pl.debug(format_args!("Initializing distances"));
1194        if let Some(distances) = &mut self.sum_of_dists {
1195            distances.fill(0.0);
1196        }
1197        if let Some(distances) = &mut self.sum_of_inv_dists {
1198            distances.fill(0.0);
1199        }
1200        pl.debug(format_args!("Initializing centralities"));
1201        for centralities in self.discounted_centralities.iter_mut() {
1202            centralities.fill(0.0);
1203        }
1204
1205        self.last = self.graph.num_nodes() as f64;
1206        pl.debug(format_args!("Initializing neighborhood function"));
1207        self.neighborhood_function.clear();
1208        self.neighborhood_function.push(self.last);
1209
1210        pl.debug(format_args!("Initializing modified estimators"));
1211        ic.curr_modified.fill(true, Ordering::Relaxed);
1212
1213        pl.done();
1214
1215        Ok(())
1216    }
1217}
1218
#[cfg(test)]
mod test {
    use std::hash::{BuildHasherDefault, DefaultHasher};

    use super::*;
    use card_est_array::traits::{EstimatorArray, MergeEstimator};
    use dsi_progress_logger::no_logging;
    use epserde::deser::{Deserialize, Flags};
    use rand::SeedableRng;
    use webgraph::{
        prelude::{BvGraph, DCF},
        traits::SequentialLabeling,
    };

    /// Estimator array shared by the parallel and the reference
    /// implementation: one HyperLogLog estimator per node over the labels of
    /// `G`, hashed with the default hasher.
    type HyperBallArray<G> = SliceEstimatorArray<
        HyperLogLog<<G as SequentialLabeling>::Label, BuildHasherDefault<DefaultHasher>, usize>,
        usize,
        Box<[usize]>,
    >;

    /// A naive, single-threaded HyperBall used as a reference. At every
    /// iteration each estimator is recomputed as the union of the node's own
    /// estimator and those of all its successors.
    struct SeqHyperBall<'a, G: RandomAccessGraph> {
        graph: &'a G,
        curr_state: HyperBallArray<G>,
        next_state: HyperBallArray<G>,
    }

    impl<G: RandomAccessGraph> SeqHyperBall<'_, G> {
        /// Adds each node to its own estimator.
        fn init(&mut self) {
            let curr_state = &mut self.curr_state;
            (0..self.graph.num_nodes()).for_each(|node| {
                curr_state.get_estimator_mut(node).add(node);
            });
        }

        /// Performs one iteration, then swaps the current and next states.
        fn iterate(&mut self) {
            let graph = self.graph;
            let (curr, next) = (&self.curr_state, &mut self.next_state);
            for node in 0..graph.num_nodes() {
                let mut merged = next.get_estimator_mut(node);
                merged.set(curr.get_backend(node));
                for succ in graph.successors(node) {
                    merged.merge(curr.get_backend(succ));
                }
            }
            std::mem::swap(&mut self.curr_state, &mut self.next_state);
        }
    }

    #[cfg_attr(feature = "slow_tests", test)]
    #[cfg_attr(not(feature = "slow_tests"), allow(dead_code))]
    fn test_cnr_2000() -> Result<()> {
        let basename = "../data/cnr-2000";

        let graph = BvGraph::with_basename(basename).load()?;
        let transpose = BvGraph::with_basename(format!("{basename}-t")).load()?;
        // SAFETY: assumes the `.dcf` file is not modified while it is mapped.
        let cumulative = unsafe { DCF::load_mmap(format!("{basename}.dcf"), Flags::empty()) }?;

        let num_nodes = graph.num_nodes();
        let estimation_logic = HyperLogLogBuilder::new(num_nodes)
            .log_2_num_reg(6)
            .build()?;

        let mut seq_hyperball = SeqHyperBall {
            graph: &graph,
            curr_state: SliceEstimatorArray::new(estimation_logic.clone(), num_nodes),
            next_state: SliceEstimatorArray::new(estimation_logic.clone(), num_nodes),
        };
        let mut hyperball = HyperBallBuilder::with_transpose(
            &graph,
            &transpose,
            cumulative.uncase(),
            SliceEstimatorArray::new(estimation_logic.clone(), num_nodes),
            SliceEstimatorArray::new(estimation_logic.clone(), num_nodes),
        )
        .build(no_logging![]);

        let mut rng = rand::rngs::SmallRng::seed_from_u64(42);
        hyperball.init(&mut rng, no_logging![])?;
        seq_hyperball.init();

        // Iterate until the parallel implementation modifies no estimator,
        // checking after every iteration that both buffers agree with the
        // reference implementation.
        let mut still_changing = num_nodes as u64;
        while still_changing != 0 {
            hyperball.iterate(no_logging![])?;
            seq_hyperball.iterate();

            assert_eq!(
                hyperball.next_state.as_ref(),
                seq_hyperball.next_state.as_ref()
            );
            assert_eq!(
                hyperball.curr_state.as_ref(),
                seq_hyperball.curr_state.as_ref()
            );

            still_changing = hyperball
                .iteration_context
                .modified_estimators
                .load(Ordering::Relaxed);
        }

        Ok(())
    }
}