Skip to main content

somatize_runtime/
pruner.rs

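//! Early stopping strategies for optimization studies.
//!
//! A [`Pruner`] decides whether a trial should be stopped based on
//! intermediate metric values. Implementations: [`MedianPruner`],
//! [`PercentilePruner`]. Both treat larger metric values as better; when a
//! metric should be minimized, the caller is expected to invert it first.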
6
7use somatize_core::event::MetricRecord;
8
9/// A pruner decides whether to stop a trial early based on intermediate metrics.
10pub trait Pruner: Send + Sync {
11    /// Decide whether to prune given the current trial's metrics and
12    /// the history of completed trials' metrics at the same step.
13    ///
14    /// Returns `Some(reason)` if the trial should be pruned.
15    fn should_prune(
16        &self,
17        metric_name: &str,
18        current_value: f64,
19        step: usize,
20        history: &[TrialMetricHistory],
21    ) -> Option<String>;
22}
23
/// A completed trial's metric history, used as the comparison baseline when
/// deciding whether to prune a running trial.
pub struct TrialMetricHistory {
    /// Identifier of the completed trial. The pruners in this module do not read it.
    pub trial_id: String,
    /// Every metric record the trial reported. The pruners in this module use
    /// the last record whose `name` and `step` match the queried metric and step.
    pub metrics: Vec<MetricRecord>,
}
29
/// Prune if current value is below the median of completed trials at the same step.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MedianPruner {
    /// Don't prune before this many steps.
    pub n_warmup_steps: usize,
    /// Minimum completed trials needed before pruning starts.
    pub min_trials: usize,
}

impl MedianPruner {
    /// Create a pruner that never prunes at steps `< n_warmup_steps`.
    ///
    /// It requires at least one completed trial to compare against.
    #[must_use]
    pub fn new(n_warmup_steps: usize) -> Self {
        Self {
            n_warmup_steps,
            min_trials: 1,
        }
    }

    /// Require at least `min_trials` completed trials that reported the metric
    /// at a step before pruning is considered at that step.
    ///
    /// This consumes `self` and returns the updated pruner. Ignoring the
    /// result discards the configuration, hence `#[must_use]`.
    #[must_use]
    pub fn with_min_trials(mut self, min_trials: usize) -> Self {
        self.min_trials = min_trials;
        self
    }
}
51
52impl Pruner for MedianPruner {
53    fn should_prune(
54        &self,
55        metric_name: &str,
56        current_value: f64,
57        step: usize,
58        history: &[TrialMetricHistory],
59    ) -> Option<String> {
60        if step < self.n_warmup_steps {
61            return None;
62        }
63
64        // Collect values at this step from completed trials
65        let mut values_at_step: Vec<f64> = history
66            .iter()
67            .filter_map(|h| {
68                h.metrics
69                    .iter()
70                    .filter(|m| m.name == metric_name && m.step == step)
71                    .map(|m| m.value)
72                    .next_back()
73            })
74            .collect();
75
76        if values_at_step.len() < self.min_trials {
77            return None;
78        }
79
80        values_at_step.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
81        let median = if values_at_step.len().is_multiple_of(2) {
82            let mid = values_at_step.len() / 2;
83            (values_at_step[mid - 1] + values_at_step[mid]) / 2.0
84        } else {
85            values_at_step[values_at_step.len() / 2]
86        };
87
88        // Prune if below median (assuming maximize; for minimize, caller inverts)
89        if current_value < median {
90            Some(format!(
91                "value {current_value:.4} below median {median:.4} at step {step}"
92            ))
93        } else {
94            None
95        }
96    }
97}
98
/// Prune if current value is below the given percentile.
#[derive(Debug, Clone, PartialEq)]
pub struct PercentilePruner {
    /// Percentile threshold on a 0–100 scale (e.g. `25.0` for p25).
    /// `new` does not validate the range.
    pub percentile: f64,
    /// Don't prune before this many steps.
    pub n_warmup_steps: usize,
    /// Minimum completed trials needed before pruning starts.
    pub min_trials: usize,
}

impl PercentilePruner {
    /// Create a pruner that compares against the `percentile`-th percentile
    /// (0–100 scale) and never prunes at steps `< n_warmup_steps`.
    ///
    /// It requires at least one completed trial to compare against.
    #[must_use]
    pub fn new(percentile: f64, n_warmup_steps: usize) -> Self {
        Self {
            percentile,
            n_warmup_steps,
            min_trials: 1,
        }
    }

    /// Require at least `min_trials` completed trials that reported the metric
    /// at a step before pruning is considered at that step.
    ///
    /// This mirrors `MedianPruner::with_min_trials`. It consumes `self` and
    /// returns the updated pruner.
    #[must_use]
    pub fn with_min_trials(mut self, min_trials: usize) -> Self {
        self.min_trials = min_trials;
        self
    }
}
115
116impl Pruner for PercentilePruner {
117    fn should_prune(
118        &self,
119        metric_name: &str,
120        current_value: f64,
121        step: usize,
122        history: &[TrialMetricHistory],
123    ) -> Option<String> {
124        if step < self.n_warmup_steps {
125            return None;
126        }
127
128        let mut values_at_step: Vec<f64> = history
129            .iter()
130            .filter_map(|h| {
131                h.metrics
132                    .iter()
133                    .filter(|m| m.name == metric_name && m.step == step)
134                    .map(|m| m.value)
135                    .next_back()
136            })
137            .collect();
138
139        if values_at_step.len() < self.min_trials {
140            return None;
141        }
142
143        values_at_step.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
144        let idx = ((self.percentile / 100.0) * values_at_step.len() as f64).floor() as usize;
145        let idx = idx.min(values_at_step.len() - 1);
146        let threshold = values_at_step[idx];
147
148        if current_value < threshold {
149            Some(format!(
150                "value {current_value:.4} below p{:.0} threshold {threshold:.4} at step {step}",
151                self.percentile
152            ))
153        } else {
154            None
155        }
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Build one [`TrialMetricHistory`] per trial from a step-major table.
    ///
    /// `values_per_step[step][trial]` is the `"f1"` value that `trial` reported
    /// at `step`. The trial count comes from the first row, so the table must
    /// be non-empty. A later row that is too short just means that trial
    /// reported nothing at that step.
    fn make_history(values_per_step: &[Vec<f64>]) -> Vec<TrialMetricHistory> {
        let trial_count = values_per_step[0].len();
        let mut histories = Vec::with_capacity(trial_count);
        for trial_idx in 0..trial_count {
            let mut metrics = Vec::new();
            for (step, row) in values_per_step.iter().enumerate() {
                let Some(&value) = row.get(trial_idx) else {
                    continue;
                };
                metrics.push(MetricRecord {
                    name: "f1".into(),
                    value,
                    step,
                    timestamp: Utc::now(),
                });
            }
            histories.push(TrialMetricHistory {
                trial_id: format!("t{trial_idx}"),
                metrics,
            });
        }
        histories
    }

    // ── Median pruner ──

    #[test]
    fn median_no_prune_during_warmup() {
        let history = make_history(&[vec![0.9, 0.8, 0.7]]);
        // Step 3 is still inside the 5-step warm-up, so even a very low value survives.
        let verdict = MedianPruner::new(5).should_prune("f1", 0.1, 3, &history);
        assert!(verdict.is_none());
    }

    #[test]
    fn median_prunes_below_median() {
        // Step 0 holds [0.7, 0.8, 0.9], so the median is 0.8 and 0.5 falls below it.
        let history = make_history(&[vec![0.7, 0.8, 0.9]]);
        let verdict = MedianPruner::new(0).should_prune("f1", 0.5, 0, &history);
        assert!(verdict.is_some());
    }

    #[test]
    fn median_keeps_above_median() {
        // Same baseline as above; 0.85 clears the 0.8 median.
        let history = make_history(&[vec![0.7, 0.8, 0.9]]);
        let verdict = MedianPruner::new(0).should_prune("f1", 0.85, 0, &history);
        assert!(verdict.is_none());
    }

    #[test]
    fn median_no_prune_insufficient_history() {
        // Only two completed trials exist, but five are required before pruning starts.
        let history = make_history(&[vec![0.7, 0.8]]);
        let pruner = MedianPruner::new(0).with_min_trials(5);
        assert!(pruner.should_prune("f1", 0.1, 0, &history).is_none());
    }

    #[test]
    fn median_empty_history() {
        let verdict = MedianPruner::new(0).should_prune("f1", 0.5, 0, &[]);
        assert!(verdict.is_none());
    }

    // ── Percentile pruner ──

    #[test]
    fn percentile_prunes_below_threshold() {
        // Sorted step-0 values are [0.3, 0.5, 0.7, 0.9]; p25 gives index 1, threshold 0.5.
        let history = make_history(&[vec![0.5, 0.9, 0.3, 0.7]]);
        let verdict = PercentilePruner::new(25.0, 0).should_prune("f1", 0.2, 0, &history);
        assert!(verdict.is_some());
    }

    #[test]
    fn percentile_keeps_above_threshold() {
        let history = make_history(&[vec![0.5, 0.9, 0.3, 0.7]]);
        let verdict = PercentilePruner::new(25.0, 0).should_prune("f1", 0.6, 0, &history);
        assert!(verdict.is_none());
    }

    #[test]
    fn percentile_warmup_respected() {
        // Step 5 is inside the 10-step warm-up window.
        let history = make_history(&[vec![0.9]]);
        let verdict = PercentilePruner::new(50.0, 10).should_prune("f1", 0.1, 5, &history);
        assert!(verdict.is_none());
    }
}