somatize_runtime/
pruner.rs1use somatize_core::event::MetricRecord;
8
9pub trait Pruner: Send + Sync {
11 fn should_prune(
16 &self,
17 metric_name: &str,
18 current_value: f64,
19 step: usize,
20 history: &[TrialMetricHistory],
21 ) -> Option<String>;
22}
23
24pub struct TrialMetricHistory {
26 pub trial_id: String,
27 pub metrics: Vec<MetricRecord>,
28}
29
/// Prunes a trial whose metric falls strictly below the median of the values
/// recorded by the trials in the history at the same step.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MedianPruner {
    /// Steps strictly below this value are never pruned.
    pub n_warmup_steps: usize,
    /// Minimum number of trials that must have reported the metric at the
    /// current step before pruning is considered.
    pub min_trials: usize,
}
37
38impl MedianPruner {
39 pub fn new(n_warmup_steps: usize) -> Self {
40 Self {
41 n_warmup_steps,
42 min_trials: 1,
43 }
44 }
45
46 pub fn with_min_trials(mut self, min_trials: usize) -> Self {
47 self.min_trials = min_trials;
48 self
49 }
50}
51
52impl Pruner for MedianPruner {
53 fn should_prune(
54 &self,
55 metric_name: &str,
56 current_value: f64,
57 step: usize,
58 history: &[TrialMetricHistory],
59 ) -> Option<String> {
60 if step < self.n_warmup_steps {
61 return None;
62 }
63
64 let mut values_at_step: Vec<f64> = history
66 .iter()
67 .filter_map(|h| {
68 h.metrics
69 .iter()
70 .filter(|m| m.name == metric_name && m.step == step)
71 .map(|m| m.value)
72 .next_back()
73 })
74 .collect();
75
76 if values_at_step.len() < self.min_trials {
77 return None;
78 }
79
80 values_at_step.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
81 let median = if values_at_step.len().is_multiple_of(2) {
82 let mid = values_at_step.len() / 2;
83 (values_at_step[mid - 1] + values_at_step[mid]) / 2.0
84 } else {
85 values_at_step[values_at_step.len() / 2]
86 };
87
88 if current_value < median {
90 Some(format!(
91 "value {current_value:.4} below median {median:.4} at step {step}"
92 ))
93 } else {
94 None
95 }
96 }
97}
98
99pub struct PercentilePruner {
101 pub percentile: f64,
102 pub n_warmup_steps: usize,
103 pub min_trials: usize,
104}
105
106impl PercentilePruner {
107 pub fn new(percentile: f64, n_warmup_steps: usize) -> Self {
108 Self {
109 percentile,
110 n_warmup_steps,
111 min_trials: 1,
112 }
113 }
114}
115
116impl Pruner for PercentilePruner {
117 fn should_prune(
118 &self,
119 metric_name: &str,
120 current_value: f64,
121 step: usize,
122 history: &[TrialMetricHistory],
123 ) -> Option<String> {
124 if step < self.n_warmup_steps {
125 return None;
126 }
127
128 let mut values_at_step: Vec<f64> = history
129 .iter()
130 .filter_map(|h| {
131 h.metrics
132 .iter()
133 .filter(|m| m.name == metric_name && m.step == step)
134 .map(|m| m.value)
135 .next_back()
136 })
137 .collect();
138
139 if values_at_step.len() < self.min_trials {
140 return None;
141 }
142
143 values_at_step.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
144 let idx = ((self.percentile / 100.0) * values_at_step.len() as f64).floor() as usize;
145 let idx = idx.min(values_at_step.len() - 1);
146 let threshold = values_at_step[idx];
147
148 if current_value < threshold {
149 Some(format!(
150 "value {current_value:.4} below p{:.0} threshold {threshold:.4} at step {step}",
151 self.percentile
152 ))
153 } else {
154 None
155 }
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Builds one `TrialMetricHistory` per trial from a step-major table:
    /// `values_per_step[step][trial]` is the `"f1"` value that `trial` reported at
    /// `step`. Rows may be ragged (a trial missing from a row has no record for
    /// that step), and an empty table yields no trials.
    fn make_history(values_per_step: &[Vec<f64>]) -> Vec<TrialMetricHistory> {
        let n_trials = values_per_step.iter().map(Vec::len).max().unwrap_or(0);
        (0..n_trials)
            .map(|trial_idx| {
                let metrics = values_per_step
                    .iter()
                    .enumerate()
                    .filter_map(|(step, vals)| {
                        vals.get(trial_idx).map(|&value| MetricRecord {
                            name: "f1".into(),
                            value,
                            step,
                            timestamp: Utc::now(),
                        })
                    })
                    .collect();
                TrialMetricHistory {
                    trial_id: format!("t{trial_idx}"),
                    metrics,
                }
            })
            .collect()
    }

    #[test]
    fn make_history_handles_empty_table() {
        assert!(make_history(&[]).is_empty());
    }

    #[test]
    fn median_no_prune_during_warmup() {
        let pruner = MedianPruner::new(5);
        let history = make_history(&[vec![0.9, 0.8, 0.7]]);
        assert!(pruner.should_prune("f1", 0.1, 3, &history).is_none());
    }

    #[test]
    fn median_warmup_boundary_is_inclusive() {
        // Steps strictly below `n_warmup_steps` are skipped; the boundary step is evaluated.
        let pruner = MedianPruner::new(2);
        let row = vec![0.9, 0.8, 0.7];
        let history = make_history(&[row.clone(), row.clone(), row]);
        assert!(pruner.should_prune("f1", 0.1, 1, &history).is_none());
        assert!(pruner.should_prune("f1", 0.1, 2, &history).is_some());
    }

    #[test]
    fn median_prunes_below_median() {
        let pruner = MedianPruner::new(0);
        let history = make_history(&[vec![0.7, 0.8, 0.9]]);
        assert!(pruner.should_prune("f1", 0.5, 0, &history).is_some());
    }

    #[test]
    fn median_prune_reason_is_descriptive() {
        let pruner = MedianPruner::new(0);
        let history = make_history(&[vec![0.7, 0.8, 0.9]]);
        assert_eq!(
            pruner.should_prune("f1", 0.5, 0, &history).as_deref(),
            Some("value 0.5000 below median 0.8000 at step 0")
        );
    }

    #[test]
    fn median_keeps_above_median() {
        let pruner = MedianPruner::new(0);
        let history = make_history(&[vec![0.7, 0.8, 0.9]]);
        assert!(pruner.should_prune("f1", 0.85, 0, &history).is_none());
    }

    #[test]
    fn median_averages_middle_pair_for_even_count() {
        // Median of [0.2, 0.4, 0.6, 0.8] is 0.5.
        let pruner = MedianPruner::new(0);
        let history = make_history(&[vec![0.2, 0.4, 0.6, 0.8]]);
        assert!(pruner.should_prune("f1", 0.45, 0, &history).is_some());
        assert!(pruner.should_prune("f1", 0.55, 0, &history).is_none());
    }

    #[test]
    fn median_only_uses_requested_step() {
        let pruner = MedianPruner::new(0);
        let history = make_history(&[vec![0.9, 0.9], vec![0.1, 0.2]]);
        assert!(pruner.should_prune("f1", 0.5, 0, &history).is_some());
        assert!(pruner.should_prune("f1", 0.5, 1, &history).is_none());
    }

    #[test]
    fn median_ignores_other_metric_names() {
        let pruner = MedianPruner::new(0);
        let history = make_history(&[vec![0.7, 0.8, 0.9]]);
        assert!(pruner.should_prune("loss", 0.0, 0, &history).is_none());
    }

    #[test]
    fn median_no_prune_insufficient_history() {
        let pruner = MedianPruner::new(0).with_min_trials(5);
        let history = make_history(&[vec![0.7, 0.8]]);
        assert!(pruner.should_prune("f1", 0.1, 0, &history).is_none());
    }

    #[test]
    fn median_empty_history() {
        let pruner = MedianPruner::new(0);
        assert!(pruner.should_prune("f1", 0.5, 0, &[]).is_none());
    }

    #[test]
    fn percentile_prunes_below_threshold() {
        let pruner = PercentilePruner::new(25.0, 0);
        let history = make_history(&[vec![0.5, 0.9, 0.3, 0.7]]);
        assert!(pruner.should_prune("f1", 0.2, 0, &history).is_some());
    }

    #[test]
    fn percentile_keeps_above_threshold() {
        let pruner = PercentilePruner::new(25.0, 0);
        let history = make_history(&[vec![0.5, 0.9, 0.3, 0.7]]);
        assert!(pruner.should_prune("f1", 0.6, 0, &history).is_none());
    }

    #[test]
    fn percentile_extremes_are_min_and_max() {
        let history = make_history(&[vec![0.3, 0.5, 0.7]]);
        let p100 = PercentilePruner::new(100.0, 0);
        assert!(p100.should_prune("f1", 0.69, 0, &history).is_some());
        assert!(p100.should_prune("f1", 0.7, 0, &history).is_none());
        let p0 = PercentilePruner::new(0.0, 0);
        assert!(p0.should_prune("f1", 0.29, 0, &history).is_some());
        assert!(p0.should_prune("f1", 0.3, 0, &history).is_none());
    }

    #[test]
    fn percentile_warmup_respected() {
        let pruner = PercentilePruner::new(50.0, 10);
        let history = make_history(&[vec![0.9]]);
        assert!(pruner.should_prune("f1", 0.1, 5, &history).is_none());
    }
}