webgraph_algo/llp/
preds.rs

1/*
2 * SPDX-FileCopyrightText: 2024 Tommaso Fontana
3 * SPDX-FileCopyrightText: 2024 Sebastiano Vigna
4 *
5 * SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
6 */
7
8//! Predicates implementing stopping conditions.
9//!
10//! The implementation of [layered label propagation](super) requires a
11//! [predicate](Predicate) to stop the algorithm. This module provides a few
12//! such predicates: they evaluate to true if the updates should be stopped.
13//!
14//! You can combine the predicates using the `and` and `or` methods provided by
15//! the [`Predicate`] trait.
16//!
17//! # Examples
18//! ```
19//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
20//! use predicates::prelude::*;
21//! use webgraph::algo::llp::preds::{MinGain, MaxUpdates};
22//!
23//! let mut predicate = MinGain::try_from(0.001)?.boxed();
24//! predicate = predicate.or(MaxUpdates::from(100)).boxed();
25//! #     Ok(())
26//! # }
27//! ```
28
29use anyhow::ensure;
30use predicates::{reflection::PredicateReflection, Predicate};
31use std::fmt::Display;
32
/// This structure is passed to predicates to provide the
/// information that is needed to evaluate them.
///
/// Each field is a plain value that predicates read; nothing in this
/// structure is modified by evaluating a predicate.
#[doc(hidden)]
#[derive(Debug, Clone)]
pub struct PredParams {
    /// Number of nodes; [`MinModified`] and [`PercModified`] compare
    /// `modified` against a function of this value.
    pub num_nodes: usize,
    /// Number of arcs. Not read by any of the predicates defined in this
    /// module.
    pub num_arcs: u64,
    /// Gain of the objective function; compared by [`MinGain`] against its
    /// threshold.
    pub gain: f64,
    /// Average improvement of the gain; compared by [`MinAvgImprov`] against
    /// its threshold.
    pub avg_gain_impr: f64,
    /// Number of modified nodes; read by [`MinModified`] and
    /// [`PercModified`].
    pub modified: usize,
    /// Counter of the current update; [`MaxUpdates`] stops when this value
    /// plus one reaches its bound (presumably zero-based; confirm against
    /// the caller).
    pub update: usize,
}
44
/// Stop after at most the provided number of updates for a given ɣ.
///
/// The bound can be supplied as a `usize` or as an `Option<usize>`; `None`
/// selects [`DEFAULT_MAX_UPDATES`](Self::DEFAULT_MAX_UPDATES).
#[derive(Debug, Clone)]
pub struct MaxUpdates {
    /// Bound on the number of updates.
    max_updates: usize,
}

impl MaxUpdates {
    /// Bound used when none is specified: the largest `usize`.
    pub const DEFAULT_MAX_UPDATES: usize = usize::MAX;
}

impl From<Option<usize>> for MaxUpdates {
    /// Uses the given bound, or the default one if `max_updates` is `None`.
    fn from(max_updates: Option<usize>) -> Self {
        Self {
            max_updates: max_updates.unwrap_or(Self::DEFAULT_MAX_UPDATES),
        }
    }
}

impl From<usize> for MaxUpdates {
    /// Uses the given bound as is.
    fn from(max_updates: usize) -> Self {
        Self { max_updates }
    }
}

impl Default for MaxUpdates {
    /// Uses [`DEFAULT_MAX_UPDATES`](Self::DEFAULT_MAX_UPDATES) as bound.
    fn default() -> Self {
        Self {
            max_updates: Self::DEFAULT_MAX_UPDATES,
        }
    }
}

impl Display for MaxUpdates {
    /// Writes the predicate as `(max updates: <bound>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "(max updates: {})", self.max_updates)
    }
}
81
82impl PredicateReflection for MaxUpdates {}
83impl Predicate<PredParams> for MaxUpdates {
84    fn eval(&self, pred_params: &PredParams) -> bool {
85        pred_params.update + 1 >= self.max_updates
86    }
87}
88
89#[derive(Debug, Clone)]
90/// Stop if the gain of the objective function is below the given threshold.
91///
92/// The [default threshold](Self::DEFAULT_THRESHOLD) is the same as that
93/// of the Java implementation.
94pub struct MinGain {
95    threshold: f64,
96}
97
98impl MinGain {
99    pub const DEFAULT_THRESHOLD: f64 = 0.001;
100}
101
102impl TryFrom<Option<f64>> for MinGain {
103    type Error = anyhow::Error;
104    fn try_from(threshold: Option<f64>) -> anyhow::Result<Self> {
105        Ok(match threshold {
106            Some(threshold) => {
107                ensure!(!threshold.is_nan());
108                ensure!(threshold >= 0.0, "The threshold must be nonnegative");
109                MinGain { threshold }
110            }
111            None => Self::default(),
112        })
113    }
114}
115
116impl TryFrom<f64> for MinGain {
117    type Error = anyhow::Error;
118    fn try_from(threshold: f64) -> anyhow::Result<Self> {
119        Some(threshold).try_into()
120    }
121}
122
123impl Default for MinGain {
124    fn default() -> Self {
125        Self::try_from(Self::DEFAULT_THRESHOLD).unwrap()
126    }
127}
128
129impl Display for MinGain {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        f.write_fmt(format_args!("(min gain: {})", self.threshold))
132    }
133}
134
135impl PredicateReflection for MinGain {}
136impl Predicate<PredParams> for MinGain {
137    fn eval(&self, pred_params: &PredParams) -> bool {
138        pred_params.gain <= self.threshold
139    }
140}
141
142#[derive(Debug, Clone)]
143/// Stop if the average improvement of the gain of the objective function on
144/// a window of ten updates is below the given threshold.
145///
146/// This criterion is a second-order version of [`MinGain`]. It is very useful
147/// to avoid a large number of iteration which do not improve the objective
148/// function significantly.
149pub struct MinAvgImprov {
150    threshold: f64,
151}
152
153impl MinAvgImprov {
154    pub const DEFAULT_THRESHOLD: f64 = 0.1;
155}
156
157impl TryFrom<Option<f64>> for MinAvgImprov {
158    type Error = anyhow::Error;
159    fn try_from(threshold: Option<f64>) -> anyhow::Result<Self> {
160        Ok(match threshold {
161            Some(threshold) => {
162                ensure!(!threshold.is_nan());
163                MinAvgImprov { threshold }
164            }
165            None => Self::default(),
166        })
167    }
168}
169
170impl TryFrom<f64> for MinAvgImprov {
171    type Error = anyhow::Error;
172    fn try_from(threshold: f64) -> anyhow::Result<Self> {
173        Some(threshold).try_into()
174    }
175}
176
177impl Default for MinAvgImprov {
178    fn default() -> Self {
179        Self::try_from(Self::DEFAULT_THRESHOLD).unwrap()
180    }
181}
182
183impl Display for MinAvgImprov {
184    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185        f.write_fmt(format_args!(
186            "(min avg gain improvement: {})",
187            self.threshold
188        ))
189    }
190}
191
192impl PredicateReflection for MinAvgImprov {}
193impl Predicate<PredParams> for MinAvgImprov {
194    fn eval(&self, pred_params: &PredParams) -> bool {
195        pred_params.avg_gain_impr <= self.threshold
196    }
197}
198
/// Stop once the number of modified nodes is at most the square root of the
/// number of nodes.
///
/// The predicate has no parameters.
#[derive(Debug, Clone, Default)]
pub struct MinModified {}

impl Display for MinModified {
    /// Writes the fixed description `(min modified: √n)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "(min modified: √n)")
    }
}
209
210impl PredicateReflection for MinModified {}
211impl Predicate<PredParams> for MinModified {
212    fn eval(&self, pred_params: &PredParams) -> bool {
213        (pred_params.modified as f64) <= (pred_params.num_nodes as f64).sqrt()
214    }
215}
216
/// Stop once the number of modified nodes is at most a specified percentage
/// of the number of nodes.
///
/// The derived `Default` yields a threshold of zero.
#[derive(Debug, Clone, Default)]
pub struct PercModified {
    /// Threshold stored as a fraction of the number of nodes (the
    /// `TryFrom<f64>` conversion divides the percentage by 100).
    threshold: f64,
}

impl Display for PercModified {
    /// Writes the predicate as `(min modified: <percentage>%)`, converting
    /// the stored fraction back to a percentage.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let percentage = self.threshold * 100.0;
        write!(f, "(min modified: {}%)", percentage)
    }
}
229
230impl TryFrom<f64> for PercModified {
231    type Error = anyhow::Error;
232    fn try_from(threshold: f64) -> anyhow::Result<Self> {
233        ensure!(
234            threshold >= 0.0,
235            "The percent threshold must be nonnegative"
236        );
237        ensure!(
238            threshold <= 100.0,
239            "The percent threshold must be at most 100"
240        );
241        Ok(PercModified {
242            threshold: threshold / 100.0,
243        })
244    }
245}
246
247impl PredicateReflection for PercModified {}
248impl Predicate<PredParams> for PercModified {
249    fn eval(&self, pred_params: &PredParams) -> bool {
250        (pred_params.modified as f64) <= (pred_params.num_nodes as f64) * self.threshold
251    }
252}