1pub mod bayesian;
8
9pub use bayesian::BayesianSampler;
10
11use somatize_core::error::Result;
12use somatize_core::search::{Scale, SearchDimension, SearchSpace};
13use std::collections::HashMap;
14
15pub trait Sampler: Send + Sync {
17 fn sample(
19 &mut self,
20 space: &SearchSpace,
21 trial_index: usize,
22 ) -> Result<Option<HashMap<String, serde_json::Value>>>;
23
24 fn n_trials(&self) -> Option<usize>;
26}
27
28pub struct GridSampler {
38 points_per_dim: usize,
39 dim_values: Option<Vec<(String, Vec<serde_json::Value>)>>,
41 total: Option<usize>,
43}
44
45impl GridSampler {
46 pub fn new(points_per_dim: usize) -> Self {
47 Self {
48 points_per_dim,
49 dim_values: None,
50 total: None,
51 }
52 }
53
54 fn ensure_dims(&mut self, space: &SearchSpace) {
56 if self.dim_values.is_some() {
57 return;
58 }
59 let dims: Vec<(String, Vec<serde_json::Value>)> = space
60 .active_dimensions()
61 .iter()
62 .map(|dim| {
63 let name = dim.name().to_string();
64 let values = self.discretize(dim);
65 (name, values)
66 })
67 .collect();
68
69 let total = if dims.is_empty() {
70 1 } else {
72 dims.iter().map(|(_, v)| v.len()).product()
73 };
74
75 self.dim_values = Some(dims);
76 self.total = Some(total);
77 }
78
79 fn sample_at(&self, trial_index: usize) -> Option<HashMap<String, serde_json::Value>> {
82 let dims = self.dim_values.as_ref()?;
83 let total = self.total?;
84
85 if trial_index >= total {
86 return None;
87 }
88
89 if dims.is_empty() {
90 return Some(HashMap::new());
91 }
92
93 let mut params = HashMap::new();
94 let mut remaining = trial_index;
95
96 for (name, values) in dims.iter().rev() {
99 let dim_size = values.len();
100 let dim_idx = remaining % dim_size;
101 remaining /= dim_size;
102 params.insert(name.clone(), values[dim_idx].clone());
103 }
104
105 Some(params)
106 }
107
108 fn discretize(&self, dim: &SearchDimension) -> Vec<serde_json::Value> {
109 match dim {
110 SearchDimension::Float {
111 low, high, scale, ..
112 } => linspace(*low, *high, self.points_per_dim, *scale)
113 .into_iter()
114 .map(|v| serde_json::json!(v))
115 .collect(),
116 SearchDimension::Int {
117 low, high, scale, ..
118 } => {
119 let n = self.points_per_dim.min((*high - *low + 1) as usize);
120 linspace(*low as f64, *high as f64, n, *scale)
121 .into_iter()
122 .map(|v| serde_json::json!(v.round() as i64))
123 .collect()
124 }
125 SearchDimension::Categorical { choices, .. } => choices.clone(),
126 SearchDimension::Conditional { dimension, .. } => self.discretize(dimension),
127 _ => vec![serde_json::Value::Null],
128 }
129 }
130}
131
132impl Sampler for GridSampler {
133 fn sample(
134 &mut self,
135 space: &SearchSpace,
136 trial_index: usize,
137 ) -> Result<Option<HashMap<String, serde_json::Value>>> {
138 self.ensure_dims(space);
139 Ok(self.sample_at(trial_index))
140 }
141
142 fn n_trials(&self) -> Option<usize> {
143 self.total
144 }
145}
146
/// Sampler that draws each trial independently and deterministically from a seeded hash.
pub struct RandomSampler {
    /// Number of trials to produce; `sample` returns `None` for indices at or beyond it.
    n_trials: usize,
    /// Seed mixed into every per-(trial, dimension) hash.
    seed: u64,
}
156
157impl RandomSampler {
158 pub fn new(n_trials: usize, seed: Option<u64>) -> Self {
159 Self {
160 n_trials,
161 seed: seed.unwrap_or(42),
162 }
163 }
164
165 fn sample_dim(&self, dim: &SearchDimension, rng_state: u64) -> serde_json::Value {
166 let t = pseudo_random(rng_state); match dim {
168 SearchDimension::Float {
169 low, high, scale, ..
170 } => {
171 let val = sample_float(*low, *high, *scale, t);
172 serde_json::json!(val)
173 }
174 SearchDimension::Int { low, high, .. } => {
175 let range = (*high - *low + 1) as f64;
176 let val = *low + (t * range).floor() as i64;
177 let val = val.min(*high);
178 serde_json::json!(val)
179 }
180 SearchDimension::Categorical { choices, .. } => {
181 let idx = (t * choices.len() as f64).floor() as usize;
182 let idx = idx.min(choices.len() - 1);
183 choices[idx].clone()
184 }
185 SearchDimension::Conditional { dimension, .. } => self.sample_dim(dimension, rng_state),
186 _ => serde_json::Value::Null,
187 }
188 }
189}
190
191impl Sampler for RandomSampler {
192 fn sample(
193 &mut self,
194 space: &SearchSpace,
195 trial_index: usize,
196 ) -> Result<Option<HashMap<String, serde_json::Value>>> {
197 if trial_index >= self.n_trials {
198 return Ok(None);
199 }
200
201 let mut params = HashMap::new();
202 for (i, dim) in space.active_dimensions().iter().enumerate() {
203 let rng_state = hash_u64(self.seed, trial_index as u64, i as u64);
205 let value = self.sample_dim(dim, rng_state);
206 params.insert(dim.name().to_string(), value);
207 }
208
209 Ok(Some(params))
210 }
211
212 fn n_trials(&self) -> Option<usize> {
213 Some(self.n_trials)
214 }
215}
216
217fn linspace(low: f64, high: f64, n: usize, scale: Scale) -> Vec<f64> {
223 if n <= 1 {
224 return vec![(low + high) / 2.0];
225 }
226 match scale {
227 Scale::Linear => (0..n)
228 .map(|i| low + (high - low) * (i as f64 / (n - 1) as f64))
229 .collect(),
230 Scale::Log => {
231 let log_low = low.max(1e-12).ln();
232 let log_high = high.max(1e-12).ln();
233 (0..n)
234 .map(|i| (log_low + (log_high - log_low) * (i as f64 / (n - 1) as f64)).exp())
235 .collect()
236 }
237 Scale::ReverseLog => {
238 linspace(low, high, n, Scale::Log)
240 .into_iter()
241 .rev()
242 .collect()
243 }
244 }
245}
246
247pub fn sample_float(low: f64, high: f64, scale: Scale, t: f64) -> f64 {
249 match scale {
250 Scale::Linear => low + (high - low) * t,
251 Scale::Log => {
252 let log_low = low.max(1e-12).ln();
253 let log_high = high.max(1e-12).ln();
254 (log_low + (log_high - log_low) * t).exp()
255 }
256 Scale::ReverseLog => {
257 let val = sample_float(low, high, Scale::Log, 1.0 - t);
258 low + high - val
259 }
260 }
261}
262
263pub fn pseudo_random(state: u64) -> f64 {
265 let h = splitmix64(state);
266 (h >> 11) as f64 / (1u64 << 53) as f64
267}
268
269pub fn hash_u64(seed: u64, a: u64, b: u64) -> u64 {
271 splitmix64(
272 seed.wrapping_add(a.wrapping_mul(6364136223846793005))
273 .wrapping_add(b),
274 )
275}
276
/// One step of the SplitMix64 generator: advance `x` by the golden-ratio increment
/// `0x9e3779b97f4a7c15`, then apply the xor-shift/multiply finalizer.
pub fn splitmix64(x: u64) -> u64 {
    const GAMMA: u64 = 0x9e3779b97f4a7c15;
    const MIX_1: u64 = 0xbf58476d1ce4e5b9;
    const MIX_2: u64 = 0x94d049bb133111eb;

    let z = x.wrapping_add(GAMMA);
    let z = (z ^ (z >> 30)).wrapping_mul(MIX_1);
    let z = (z ^ (z >> 27)).wrapping_mul(MIX_2);
    z ^ (z >> 31)
}
284
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Requests trials `0..limit` from `sampler`, stopping at the first `None`.
    fn drain<S: Sampler>(
        sampler: &mut S,
        space: &SearchSpace,
        limit: usize,
    ) -> Vec<HashMap<String, serde_json::Value>> {
        (0..limit)
            .map_while(|idx| sampler.sample(space, idx).unwrap())
            .collect()
    }

    /// Two dimensions: a log-scaled learning rate and a three-way kernel choice.
    fn sample_space() -> SearchSpace {
        let mut space = SearchSpace::new();
        space.add(SearchDimension::Float {
            name: "lr".into(),
            low: 0.001,
            high: 0.1,
            scale: Scale::Log,
            default: None,
        });
        space.add(SearchDimension::Categorical {
            name: "kernel".into(),
            choices: vec![json!("rbf"), json!("linear"), json!("poly")],
        });
        space
    }

    #[test]
    fn grid_sampler_generates_all_combinations() {
        let space = sample_space();
        let mut grid = GridSampler::new(3);
        let trials = drain(&mut grid, &space, usize::MAX);

        // 3 learning rates x 3 kernels.
        assert_eq!(trials.len(), 9);
        assert!(trials
            .iter()
            .all(|t| t.contains_key("lr") && t.contains_key("kernel")));

        for kernel in ["rbf", "linear", "poly"] {
            assert!(trials.iter().any(|t| t["kernel"] == json!(kernel)));
        }
    }

    #[test]
    fn grid_sampler_respects_log_scale() {
        let mut space = SearchSpace::new();
        space.add(SearchDimension::Float {
            name: "lr".into(),
            low: 0.001,
            high: 1.0,
            scale: Scale::Log,
            default: None,
        });

        let mut grid = GridSampler::new(3);
        let lrs: Vec<f64> = (0..3)
            .map(|idx| grid.sample(&space, idx).unwrap().unwrap()["lr"].as_f64().unwrap())
            .collect();
        let (v0, v1, v2) = (lrs[0], lrs[1], lrs[2]);

        assert!(v0 < v1 && v1 < v2);
        // Log spacing: the upper gap is wider than the lower one.
        assert!((v1 - v0) < (v2 - v1));
    }

    #[test]
    fn grid_sampler_int_dimension() {
        let mut space = SearchSpace::new();
        space.add(SearchDimension::Int {
            name: "n".into(),
            low: 1,
            high: 5,
            scale: Scale::Linear,
        });

        let mut grid = GridSampler::new(5);
        let values: Vec<i64> = drain(&mut grid, &space, usize::MAX)
            .iter()
            .map(|p| p["n"].as_i64().unwrap())
            .collect();
        assert_eq!(values, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn grid_empty_space() {
        let space = SearchSpace::new();
        let mut grid = GridSampler::new(3);

        let first = grid.sample(&space, 0).unwrap();
        assert!(matches!(first, Some(ref p) if p.is_empty()));
        assert!(grid.sample(&space, 1).unwrap().is_none());
    }

    #[test]
    fn random_sampler_generates_n_trials() {
        let space = sample_space();
        let mut rs = RandomSampler::new(10, Some(42));
        assert_eq!(drain(&mut rs, &space, 20).len(), 10);
    }

    #[test]
    fn random_sampler_respects_bounds() {
        let mut space = SearchSpace::new();
        space.add(SearchDimension::Float {
            name: "x".into(),
            low: 0.0,
            high: 1.0,
            scale: Scale::Linear,
            default: None,
        });
        space.add(SearchDimension::Int {
            name: "n".into(),
            low: 5,
            high: 10,
            scale: Scale::Linear,
        });

        let mut rs = RandomSampler::new(100, Some(123));
        for params in drain(&mut rs, &space, 100) {
            let x = params["x"].as_f64().unwrap();
            let n = params["n"].as_i64().unwrap();
            assert!((0.0..=1.0).contains(&x), "x={x} out of bounds");
            assert!((5..=10).contains(&n), "n={n} out of bounds");
        }
    }

    #[test]
    fn random_sampler_deterministic_with_seed() {
        let space = sample_space();
        let (mut first, mut second) = (
            RandomSampler::new(5, Some(42)),
            RandomSampler::new(5, Some(42)),
        );

        for idx in 0..5 {
            assert_eq!(
                first.sample(&space, idx).unwrap().unwrap(),
                second.sample(&space, idx).unwrap().unwrap()
            );
        }
    }

    #[test]
    fn random_sampler_different_seeds_differ() {
        let space = sample_space();
        let mut seeded_42 = RandomSampler::new(5, Some(42));
        let mut seeded_99 = RandomSampler::new(5, Some(99));

        let a = seeded_42.sample(&space, 0).unwrap().unwrap();
        let b = seeded_99.sample(&space, 0).unwrap().unwrap();
        assert_ne!(a["lr"], b["lr"]);
    }

    #[test]
    fn linspace_linear() {
        assert_eq!(
            linspace(0.0, 10.0, 5, Scale::Linear),
            vec![0.0, 2.5, 5.0, 7.5, 10.0]
        );
    }

    #[test]
    fn linspace_single_point() {
        assert_eq!(linspace(0.0, 10.0, 1, Scale::Linear), vec![5.0]);
    }

    #[test]
    fn linspace_log_denser_at_low_end() {
        let vals = linspace(0.001, 1.0, 5, Scale::Log);
        let gaps: Vec<f64> = vals.windows(2).map(|w| w[1] - w[0]).collect();
        for (offset, pair) in gaps.windows(2).enumerate() {
            let i = offset + 1;
            assert!(pair[1] > pair[0], "gap[{i}] should be > gap[{}]", i - 1);
        }
    }
}