Skip to main content

somatize_runtime/sampler/
bayesian.rs

1use crate::sampler::{Sampler, hash_u64, pseudo_random, sample_float};
2use somatize_core::error::Result;
3use somatize_core::search::{SearchDimension, SearchSpace};
4use std::collections::HashMap;
5
6/// Bayesian optimization sampler using Tree-Parzen Estimator (TPE).
7///
8/// For the first `n_startup` trials, samples randomly. After that,
9/// uses the history of (params, metric) to model "good" vs "bad"
10/// parameter distributions and samples from the "good" distribution.
11///
12/// This is a simplified TPE: it splits trials into top/bottom quantiles
13/// and samples from the top quantile's parameter distributions.
14pub struct BayesianSampler {
15    n_trials: usize,
16    n_startup: usize,
17    seed: u64,
18    /// History: (params, metric_value) for completed trials.
19    history: Vec<(HashMap<String, serde_json::Value>, f64)>,
20    /// Quantile split: top gamma fraction is "good".
21    gamma: f64,
22}
23
24impl BayesianSampler {
25    pub fn new(n_trials: usize, n_startup: usize, seed: Option<u64>) -> Self {
26        Self {
27            n_trials,
28            n_startup: n_startup.max(2),
29            seed: seed.unwrap_or(42),
30            history: Vec::new(),
31            gamma: 0.25, // top 25% are "good"
32        }
33    }
34
35    /// Record a completed trial's result (for informing future samples).
36    pub fn record(&mut self, params: HashMap<String, serde_json::Value>, metric: f64) {
37        self.history.push((params, metric));
38    }
39
40    /// Sample using TPE: bias towards parameters seen in "good" trials.
41    fn sample_tpe(
42        &self,
43        space: &SearchSpace,
44        trial_index: usize,
45    ) -> HashMap<String, serde_json::Value> {
46        // Split history into good/bad by quantile
47        let mut sorted_history: Vec<(usize, f64)> = self
48            .history
49            .iter()
50            .enumerate()
51            .map(|(i, (_, v))| (i, *v))
52            .collect();
53        sorted_history.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
54
55        let n_good = (self.history.len() as f64 * self.gamma).ceil() as usize;
56        let n_good = n_good.max(1).min(self.history.len());
57        let good_indices: Vec<usize> = sorted_history[..n_good].iter().map(|(i, _)| *i).collect();
58
59        let mut params = HashMap::new();
60        for (dim_idx, dim) in space.active_dimensions().iter().enumerate() {
61            let rng_state = hash_u64(self.seed, trial_index as u64, dim_idx as u64);
62            let t = pseudo_random(rng_state);
63
64            // With 80% probability, sample near good trials' values for this dim.
65            // With 20% probability, sample uniformly (exploration).
66            let explore_prob = pseudo_random(hash_u64(
67                self.seed,
68                trial_index as u64,
69                dim_idx as u64 + 1000,
70            ));
71
72            let value = if explore_prob < 0.2 || good_indices.is_empty() {
73                // Explore: sample uniformly
74                self.sample_uniform(dim, t)
75            } else {
76                // Exploit: sample near a good trial's value
77                let good_idx = good_indices
78                    [((t * good_indices.len() as f64) as usize).min(good_indices.len() - 1)];
79                let good_params = &self.history[good_idx].0;
80
81                if let Some(good_val) = good_params.get(dim.name()) {
82                    self.sample_near(dim, good_val, rng_state)
83                } else {
84                    self.sample_uniform(dim, t)
85                }
86            };
87
88            params.insert(dim.name().to_string(), value);
89        }
90
91        params
92    }
93
94    fn sample_uniform(&self, dim: &SearchDimension, t: f64) -> serde_json::Value {
95        match dim {
96            SearchDimension::Float {
97                low, high, scale, ..
98            } => {
99                serde_json::json!(sample_float(*low, *high, *scale, t))
100            }
101            SearchDimension::Int { low, high, .. } => {
102                let range = (*high - *low + 1) as f64;
103                let val = *low + (t * range).floor() as i64;
104                serde_json::json!(val.min(*high))
105            }
106            SearchDimension::Categorical { choices, .. } => {
107                let idx = (t * choices.len() as f64).floor() as usize;
108                choices[idx.min(choices.len() - 1)].clone()
109            }
110            _ => serde_json::Value::Null,
111        }
112    }
113
114    /// Sample near a "good" value with gaussian-like perturbation.
115    fn sample_near(
116        &self,
117        dim: &SearchDimension,
118        center: &serde_json::Value,
119        rng_state: u64,
120    ) -> serde_json::Value {
121        let t = pseudo_random(hash_u64(rng_state, 777, 0));
122        let perturbation = (pseudo_random(hash_u64(rng_state, 888, 0)) - 0.5) * 0.3;
123
124        match dim {
125            SearchDimension::Float { low, high, .. } => {
126                if let Some(center_val) = center.as_f64() {
127                    let range = *high - *low;
128                    let new_val = (center_val + perturbation * range).clamp(*low, *high);
129                    serde_json::json!(new_val)
130                } else {
131                    self.sample_uniform(dim, t)
132                }
133            }
134            SearchDimension::Int { low, high, .. } => {
135                if let Some(center_val) = center.as_i64() {
136                    let range = (*high - *low) as f64;
137                    let new_val = (center_val as f64 + perturbation * range).round() as i64;
138                    serde_json::json!(new_val.clamp(*low, *high))
139                } else {
140                    self.sample_uniform(dim, t)
141                }
142            }
143            SearchDimension::Categorical { choices, .. } => {
144                // For categorical: mostly keep the good value, sometimes explore
145                if perturbation.abs() < 0.1 {
146                    center.clone()
147                } else {
148                    let idx = (t * choices.len() as f64).floor() as usize;
149                    choices[idx.min(choices.len() - 1)].clone()
150                }
151            }
152            _ => serde_json::Value::Null,
153        }
154    }
155}
156
157impl Sampler for BayesianSampler {
158    fn sample(
159        &mut self,
160        space: &SearchSpace,
161        trial_index: usize,
162    ) -> Result<Option<HashMap<String, serde_json::Value>>> {
163        if trial_index >= self.n_trials {
164            return Ok(None);
165        }
166
167        if trial_index < self.n_startup || self.history.is_empty() {
168            // Random startup phase
169            let mut params = HashMap::new();
170            for (i, dim) in space.active_dimensions().iter().enumerate() {
171                let rng_state = hash_u64(self.seed, trial_index as u64, i as u64);
172                let t = pseudo_random(rng_state);
173                params.insert(dim.name().to_string(), self.sample_uniform(dim, t));
174            }
175            Ok(Some(params))
176        } else {
177            Ok(Some(self.sample_tpe(space, trial_index)))
178        }
179    }
180
181    fn n_trials(&self) -> Option<usize> {
182        Some(self.n_trials)
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_core::search::Scale;

    /// Two-dimensional space shared by the tests:
    /// - a log-scaled float `lr` in `[0.001, 0.1]`;
    /// - a categorical `kernel` with choices "rbf" and "linear".
    fn sample_space() -> SearchSpace {
        let lr = SearchDimension::Float {
            name: "lr".into(),
            low: 0.001,
            high: 0.1,
            scale: Scale::Log,
            default: None,
        };
        let kernel = SearchDimension::Categorical {
            name: "kernel".into(),
            choices: vec![serde_json::json!("rbf"), serde_json::json!("linear")],
        };

        let mut space = SearchSpace::new();
        space.add(lr);
        space.add(kernel);
        space
    }

    #[test]
    fn startup_phase_is_random() {
        let mut sampler = BayesianSampler::new(20, 5, Some(42));
        let space = sample_space();

        // Each warm-up trial must populate both dimensions.
        let lrs: Vec<f64> = (0..5)
            .map(|trial| {
                let params = sampler.sample(&space, trial).unwrap().unwrap();
                assert!(params.contains_key("lr"));
                assert!(params.contains_key("kernel"));
                params["lr"].as_f64().unwrap()
            })
            .collect();

        // Random sampling must not return the same `lr` for every warm-up trial.
        assert!(lrs
            .windows(2)
            .any(|pair| (pair[0] - pair[1]).abs() > 1e-10));
    }

    #[test]
    fn tpe_phase_after_recording_history() {
        let mut sampler = BayesianSampler::new(20, 3, Some(42));
        let space = sample_space();

        // Feed back a synthetic objective that peaks at lr = 0.01.
        for trial in 0..5 {
            let params = sampler.sample(&space, trial).unwrap().unwrap();
            let lr = params["lr"].as_f64().unwrap();
            let score = 1.0 - (lr - 0.01).abs() * 10.0;
            sampler.record(params, score);
        }

        // trial_index >= n_startup with a non-empty history takes the TPE path.
        let proposal = sampler.sample(&space, 5).unwrap().unwrap();
        assert!(proposal.contains_key("lr"));
        let lr = proposal["lr"].as_f64().unwrap();
        assert!((0.001..=0.1).contains(&lr));
    }

    #[test]
    fn respects_n_trials_limit() {
        let mut sampler = BayesianSampler::new(10, 3, Some(42));
        let space = sample_space();

        for trial in 0..15 {
            let produced = sampler.sample(&space, trial).unwrap();
            assert_eq!(produced.is_some(), trial < 10, "trial {}", trial);
        }
    }

    #[test]
    fn deterministic_with_seed() {
        let space = sample_space();
        let (mut first, mut second) = (
            BayesianSampler::new(10, 3, Some(42)),
            BayesianSampler::new(10, 3, Some(42)),
        );

        for trial in 0..5 {
            let a = first.sample(&space, trial).unwrap().unwrap();
            let b = second.sample(&space, trial).unwrap().unwrap();
            assert_eq!(a, b);
        }
    }

    #[test]
    fn different_seeds_differ() {
        let space = sample_space();
        let mut seeded_42 = BayesianSampler::new(10, 3, Some(42));
        let mut seeded_99 = BayesianSampler::new(10, 3, Some(99));

        let from_42 = seeded_42.sample(&space, 0).unwrap().unwrap();
        let from_99 = seeded_99.sample(&space, 0).unwrap().unwrap();
        assert_ne!(from_42["lr"], from_99["lr"]);
    }
}