Skip to main content

somatize_runtime/sampler/
mod.rs

1//! Hyperparameter samplers for optimization studies.
2//!
3//! - [`GridSampler`] — exhaustive cartesian product, lazy index-based
4//! - [`RandomSampler`] — uniform sampling with deterministic seeds
5//! - [`BayesianSampler`] — TPE (Tree-Parzen Estimator) with explore/exploit
6
7pub mod bayesian;
8
9pub use bayesian::BayesianSampler;
10
11use somatize_core::error::Result;
12use somatize_core::search::{Scale, SearchDimension, SearchSpace};
13use std::collections::HashMap;
14
15/// A sampler produces hyperparameter configurations from a search space.
16pub trait Sampler: Send + Sync {
17    /// Sample the next set of parameters. Returns None when exhausted.
18    fn sample(
19        &mut self,
20        space: &SearchSpace,
21        trial_index: usize,
22    ) -> Result<Option<HashMap<String, serde_json::Value>>>;
23
24    /// Total number of trials this sampler will produce (if known).
25    fn n_trials(&self) -> Option<usize>;
26}
27
28// ──────────────────────────────────────────────
29// Grid Sampler
30// ──────────────────────────────────────────────
31
32/// Exhaustive grid search over all combinations.
33///
34/// Uses lazy index-based generation: instead of building the full cartesian
35/// product in memory, it computes the parameter set for a given trial index
36/// on the fly. Safe for large search spaces.
37pub struct GridSampler {
38    points_per_dim: usize,
39    /// Cached per-dimension discrete values (computed once, not the full grid).
40    dim_values: Option<Vec<(String, Vec<serde_json::Value>)>>,
41    /// Total number of combinations.
42    total: Option<usize>,
43}
44
45impl GridSampler {
46    pub fn new(points_per_dim: usize) -> Self {
47        Self {
48            points_per_dim,
49            dim_values: None,
50            total: None,
51        }
52    }
53
54    /// Compute discrete values for each dimension (once).
55    fn ensure_dims(&mut self, space: &SearchSpace) {
56        if self.dim_values.is_some() {
57            return;
58        }
59        let dims: Vec<(String, Vec<serde_json::Value>)> = space
60            .active_dimensions()
61            .iter()
62            .map(|dim| {
63                let name = dim.name().to_string();
64                let values = self.discretize(dim);
65                (name, values)
66            })
67            .collect();
68
69        let total = if dims.is_empty() {
70            1 // one combo with empty params
71        } else {
72            dims.iter().map(|(_, v)| v.len()).product()
73        };
74
75        self.dim_values = Some(dims);
76        self.total = Some(total);
77    }
78
79    /// Convert a flat trial index into a multi-dimensional index
80    /// and look up the parameter values. O(n_dims) per call.
81    fn sample_at(&self, trial_index: usize) -> Option<HashMap<String, serde_json::Value>> {
82        let dims = self.dim_values.as_ref()?;
83        let total = self.total?;
84
85        if trial_index >= total {
86            return None;
87        }
88
89        if dims.is_empty() {
90            return Some(HashMap::new());
91        }
92
93        let mut params = HashMap::new();
94        let mut remaining = trial_index;
95
96        // Decompose flat index into per-dimension indices
97        // like converting a number to mixed-radix representation
98        for (name, values) in dims.iter().rev() {
99            let dim_size = values.len();
100            let dim_idx = remaining % dim_size;
101            remaining /= dim_size;
102            params.insert(name.clone(), values[dim_idx].clone());
103        }
104
105        Some(params)
106    }
107
108    fn discretize(&self, dim: &SearchDimension) -> Vec<serde_json::Value> {
109        match dim {
110            SearchDimension::Float {
111                low, high, scale, ..
112            } => linspace(*low, *high, self.points_per_dim, *scale)
113                .into_iter()
114                .map(|v| serde_json::json!(v))
115                .collect(),
116            SearchDimension::Int {
117                low, high, scale, ..
118            } => {
119                let n = self.points_per_dim.min((*high - *low + 1) as usize);
120                linspace(*low as f64, *high as f64, n, *scale)
121                    .into_iter()
122                    .map(|v| serde_json::json!(v.round() as i64))
123                    .collect()
124            }
125            SearchDimension::Categorical { choices, .. } => choices.clone(),
126            SearchDimension::Conditional { dimension, .. } => self.discretize(dimension),
127            _ => vec![serde_json::Value::Null],
128        }
129    }
130}
131
132impl Sampler for GridSampler {
133    fn sample(
134        &mut self,
135        space: &SearchSpace,
136        trial_index: usize,
137    ) -> Result<Option<HashMap<String, serde_json::Value>>> {
138        self.ensure_dims(space);
139        Ok(self.sample_at(trial_index))
140    }
141
142    fn n_trials(&self) -> Option<usize> {
143        self.total
144    }
145}
146
147// ──────────────────────────────────────────────
148// Random Sampler
149// ──────────────────────────────────────────────
150
/// Random search: every dimension is sampled independently on each trial,
/// driven by a deterministic hash of the seed, the trial index and the
/// dimension's position.
pub struct RandomSampler {
    /// Number of trials to produce before reporting exhaustion.
    n_trials: usize,
    /// Base seed mixed into every per-trial, per-dimension RNG state.
    seed: u64,
}
156
157impl RandomSampler {
158    pub fn new(n_trials: usize, seed: Option<u64>) -> Self {
159        Self {
160            n_trials,
161            seed: seed.unwrap_or(42),
162        }
163    }
164
165    fn sample_dim(&self, dim: &SearchDimension, rng_state: u64) -> serde_json::Value {
166        let t = pseudo_random(rng_state); // [0.0, 1.0)
167        match dim {
168            SearchDimension::Float {
169                low, high, scale, ..
170            } => {
171                let val = sample_float(*low, *high, *scale, t);
172                serde_json::json!(val)
173            }
174            SearchDimension::Int { low, high, .. } => {
175                let range = (*high - *low + 1) as f64;
176                let val = *low + (t * range).floor() as i64;
177                let val = val.min(*high);
178                serde_json::json!(val)
179            }
180            SearchDimension::Categorical { choices, .. } => {
181                let idx = (t * choices.len() as f64).floor() as usize;
182                let idx = idx.min(choices.len() - 1);
183                choices[idx].clone()
184            }
185            SearchDimension::Conditional { dimension, .. } => self.sample_dim(dimension, rng_state),
186            _ => serde_json::Value::Null,
187        }
188    }
189}
190
191impl Sampler for RandomSampler {
192    fn sample(
193        &mut self,
194        space: &SearchSpace,
195        trial_index: usize,
196    ) -> Result<Option<HashMap<String, serde_json::Value>>> {
197        if trial_index >= self.n_trials {
198            return Ok(None);
199        }
200
201        let mut params = HashMap::new();
202        for (i, dim) in space.active_dimensions().iter().enumerate() {
203            // Different rng state per dimension per trial
204            let rng_state = hash_u64(self.seed, trial_index as u64, i as u64);
205            let value = self.sample_dim(dim, rng_state);
206            params.insert(dim.name().to_string(), value);
207        }
208
209        Ok(Some(params))
210    }
211
212    fn n_trials(&self) -> Option<usize> {
213        Some(self.n_trials)
214    }
215}
216
217// ──────────────────────────────────────────────
218// Helpers
219// ──────────────────────────────────────────────
220
221/// Generate evenly spaced values in a range, respecting scale.
222fn linspace(low: f64, high: f64, n: usize, scale: Scale) -> Vec<f64> {
223    if n <= 1 {
224        return vec![(low + high) / 2.0];
225    }
226    match scale {
227        Scale::Linear => (0..n)
228            .map(|i| low + (high - low) * (i as f64 / (n - 1) as f64))
229            .collect(),
230        Scale::Log => {
231            let log_low = low.max(1e-12).ln();
232            let log_high = high.max(1e-12).ln();
233            (0..n)
234                .map(|i| (log_low + (log_high - log_low) * (i as f64 / (n - 1) as f64)).exp())
235                .collect()
236        }
237        Scale::ReverseLog => {
238            // Reverse: denser at high end
239            linspace(low, high, n, Scale::Log)
240                .into_iter()
241                .rev()
242                .collect()
243        }
244    }
245}
246
247/// Sample a float from [low, high] given t in [0, 1), respecting scale.
248pub fn sample_float(low: f64, high: f64, scale: Scale, t: f64) -> f64 {
249    match scale {
250        Scale::Linear => low + (high - low) * t,
251        Scale::Log => {
252            let log_low = low.max(1e-12).ln();
253            let log_high = high.max(1e-12).ln();
254            (log_low + (log_high - log_low) * t).exp()
255        }
256        Scale::ReverseLog => {
257            let val = sample_float(low, high, Scale::Log, 1.0 - t);
258            low + high - val
259        }
260    }
261}
262
263/// Simple deterministic pseudo-random (public for use by BayesianSampler): hash-based, returns [0.0, 1.0).
264pub fn pseudo_random(state: u64) -> f64 {
265    let h = splitmix64(state);
266    (h >> 11) as f64 / (1u64 << 53) as f64
267}
268
269/// Simple hash combiner for generating unique RNG states.
270pub fn hash_u64(seed: u64, a: u64, b: u64) -> u64 {
271    splitmix64(
272        seed.wrapping_add(a.wrapping_mul(6364136223846793005))
273            .wrapping_add(b),
274    )
275}
276
/// SplitMix64 mixing step.
///
/// Adds the increment `0x9e3779b97f4a7c15`, applies two xor-shift/multiply
/// rounds, then finishes with a xor-shift. All arithmetic wraps, so it never
/// panics, and the same input always maps to the same output.
pub fn splitmix64(x: u64) -> u64 {
    let mut z = x.wrapping_add(0x9e3779b97f4a7c15);
    for (shift, factor) in [(30, 0xbf58476d1ce4e5b9u64), (27, 0x94d049bb133111eb)] {
        z = (z ^ (z >> shift)).wrapping_mul(factor);
    }
    z ^ (z >> 31)
}
284
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Two-dimensional space shared by several tests: a log-scaled float `lr`
    /// in [0.001, 0.1] and a three-way categorical `kernel`.
    fn sample_space() -> SearchSpace {
        let mut space = SearchSpace::new();
        space.add(SearchDimension::Float {
            name: "lr".into(),
            low: 0.001,
            high: 0.1,
            scale: Scale::Log,
            default: None,
        });
        space.add(SearchDimension::Categorical {
            name: "kernel".into(),
            choices: vec![json!("rbf"), json!("linear"), json!("poly")],
        });
        space
    }

    /// Pull configurations for trial indices 0, 1, 2, ... until the sampler
    /// reports exhaustion.
    fn drain(
        sampler: &mut impl Sampler,
        space: &SearchSpace,
    ) -> Vec<HashMap<String, serde_json::Value>> {
        (0..)
            .map_while(|i| sampler.sample(space, i).unwrap())
            .collect()
    }

    // ── Grid tests ──

    #[test]
    fn grid_sampler_generates_all_combinations() {
        let space = sample_space();
        let mut sampler = GridSampler::new(3);

        // 3 lr grid points x 3 kernel choices.
        let trials = drain(&mut sampler, &space);
        assert_eq!(trials.len(), 9);

        assert!(trials
            .iter()
            .all(|t| t.contains_key("lr") && t.contains_key("kernel")));

        for expected in [json!("rbf"), json!("linear"), json!("poly")] {
            assert!(trials.iter().any(|t| t["kernel"] == expected));
        }
    }

    #[test]
    fn grid_sampler_respects_log_scale() {
        let mut space = SearchSpace::new();
        space.add(SearchDimension::Float {
            name: "lr".into(),
            low: 0.001,
            high: 1.0,
            scale: Scale::Log,
            default: None,
        });

        let mut sampler = GridSampler::new(3);
        let lr: Vec<f64> = (0..3)
            .map(|i| {
                let params = sampler.sample(&space, i).unwrap().unwrap();
                params["lr"].as_f64().unwrap()
            })
            .collect();

        // Log spacing: strictly increasing, with a widening gap.
        assert!(lr[0] < lr[1] && lr[1] < lr[2]);
        assert!(lr[1] - lr[0] < lr[2] - lr[1]);
    }

    #[test]
    fn grid_sampler_int_dimension() {
        let mut space = SearchSpace::new();
        space.add(SearchDimension::Int {
            name: "n".into(),
            low: 1,
            high: 5,
            scale: Scale::Linear,
        });

        let mut sampler = GridSampler::new(5);
        let values: Vec<i64> = drain(&mut sampler, &space)
            .iter()
            .map(|p| p["n"].as_i64().unwrap())
            .collect();
        assert_eq!(values, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn grid_empty_space() {
        let space = SearchSpace::new();
        let mut sampler = GridSampler::new(3);

        // An empty space still has exactly one (empty) configuration.
        let first = sampler.sample(&space, 0).unwrap();
        assert!(matches!(&first, Some(params) if params.is_empty()));
        assert!(sampler.sample(&space, 1).unwrap().is_none());
    }

    // ── Random tests ──

    #[test]
    fn random_sampler_generates_n_trials() {
        let space = sample_space();
        let mut sampler = RandomSampler::new(10, Some(42));

        // Ask for up to 20 trials; only the configured 10 should come back.
        let produced = (0..20)
            .map_while(|i| sampler.sample(&space, i).unwrap())
            .count();
        assert_eq!(produced, 10);
    }

    #[test]
    fn random_sampler_respects_bounds() {
        let mut space = SearchSpace::new();
        space.add(SearchDimension::Float {
            name: "x".into(),
            low: 0.0,
            high: 1.0,
            scale: Scale::Linear,
            default: None,
        });
        space.add(SearchDimension::Int {
            name: "n".into(),
            low: 5,
            high: 10,
            scale: Scale::Linear,
        });

        let mut sampler = RandomSampler::new(100, Some(123));
        for trial in 0..100 {
            let params = sampler.sample(&space, trial).unwrap().unwrap();
            let x = params["x"].as_f64().unwrap();
            let n = params["n"].as_i64().unwrap();
            assert!((0.0..=1.0).contains(&x), "x={x} out of bounds");
            assert!((5..=10).contains(&n), "n={n} out of bounds");
        }
    }

    #[test]
    fn random_sampler_deterministic_with_seed() {
        let space = sample_space();
        let mut first = RandomSampler::new(5, Some(42));
        let mut second = RandomSampler::new(5, Some(42));

        for trial in 0..5 {
            assert_eq!(
                first.sample(&space, trial).unwrap().unwrap(),
                second.sample(&space, trial).unwrap().unwrap()
            );
        }
    }

    #[test]
    fn random_sampler_different_seeds_differ() {
        let space = sample_space();
        let first_lr = |seed: u64| {
            let mut sampler = RandomSampler::new(5, Some(seed));
            sampler.sample(&space, 0).unwrap().unwrap()["lr"].clone()
        };
        // Very unlikely to collide with different seeds.
        assert_ne!(first_lr(42), first_lr(99));
    }

    // ── Linspace tests ──

    #[test]
    fn linspace_linear() {
        assert_eq!(
            linspace(0.0, 10.0, 5, Scale::Linear),
            vec![0.0, 2.5, 5.0, 7.5, 10.0]
        );
    }

    #[test]
    fn linspace_single_point() {
        assert_eq!(linspace(0.0, 10.0, 1, Scale::Linear), vec![5.0]);
    }

    #[test]
    fn linspace_log_denser_at_low_end() {
        let vals = linspace(0.001, 1.0, 5, Scale::Log);
        // On a log grid each gap must be wider than the one before it.
        let gaps: Vec<f64> = vals.windows(2).map(|w| w[1] - w[0]).collect();
        (1..gaps.len()).for_each(|i| {
            assert!(gaps[i] > gaps[i - 1], "gap[{i}] should be > gap[{}]", i - 1);
        });
    }
}