// somatize_runtime/executors/pbt.rs
use crate::event_bus::EventBus;
use crate::sampler::{hash_u64, pseudo_random};
use somatize_core::error::Result;
use somatize_core::event::Event;
use somatize_core::search::{Scale, SearchDimension, SearchSpace};
use somatize_core::strategy::{ExploitStrategy, ExploreStrategy};
use somatize_core::value::Value;
use std::collections::HashMap;
use std::sync::Arc;

20#[derive(Debug, Clone)]
22pub struct PbtConfig {
23 pub population_size: usize,
24 pub generations: usize,
25 pub exploit: ExploitStrategy,
26 pub explore: ExploreStrategy,
27 pub search_space: SearchSpace,
28 pub train_steps_per_generation: usize,
29}
30
31#[derive(Debug, Clone)]
33pub struct PopulationMember {
34 pub id: String,
35 pub params: HashMap<String, serde_json::Value>,
36 pub state: Value,
37 pub fitness: Option<f64>,
38}
39
40pub trait PbtExecutor: Send + Sync {
42 fn train(&self, member: &PopulationMember) -> Result<Value>;
44 fn evaluate(&self, member: &PopulationMember) -> Result<f64>;
46}
47
/// Adapter that builds a [`PbtExecutor`] out of two closures.
pub struct FnPbtExecutor<T, E> {
    /// Invoked by [`PbtExecutor::train`].
    pub train_fn: T,
    /// Invoked by [`PbtExecutor::evaluate`].
    pub eval_fn: E,
}

54impl<T, E> PbtExecutor for FnPbtExecutor<T, E>
55where
56 T: Fn(&PopulationMember) -> Result<Value> + Send + Sync,
57 E: Fn(&PopulationMember) -> Result<f64> + Send + Sync,
58{
59 fn train(&self, member: &PopulationMember) -> Result<Value> {
60 (self.train_fn)(member)
61 }
62 fn evaluate(&self, member: &PopulationMember) -> Result<f64> {
63 (self.eval_fn)(member)
64 }
65}
66
67pub struct PbtRunner {
69 event_bus: Arc<EventBus>,
70}
71
72impl PbtRunner {
73 pub fn new(event_bus: Arc<EventBus>) -> Self {
74 Self { event_bus }
75 }
76
77 pub fn run(
81 &self,
82 config: &PbtConfig,
83 executor: &dyn PbtExecutor,
84 ) -> Result<Vec<PopulationMember>> {
85 let study_id = somatize_core::util::timestamp_id("pbt");
86 let mut rng_state: u64 = 42;
87
88 let mut population = self.initialize_population(config, &mut rng_state);
90
91 for generation in 0..config.generations {
92 self.event_bus.emit(Event::GenerationStarted {
93 study_id: study_id.clone(),
94 generation,
95 population_size: population.len(),
96 });
97
98 for member in &mut population {
100 match executor.train(member) {
101 Ok(new_state) => member.state = new_state,
102 Err(e) => {
103 tracing::warn!("PBT train failed for {}: {e}", member.id);
104 }
105 }
106 }
107
108 for member in &mut population {
110 match executor.evaluate(member) {
111 Ok(fitness) => member.fitness = Some(fitness),
112 Err(e) => {
113 tracing::warn!("PBT evaluate failed for {}: {e}", member.id);
114 member.fitness = Some(f64::NEG_INFINITY);
115 }
116 }
117 }
118
119 population.sort_by(|a, b| {
121 b.fitness
122 .unwrap_or(f64::NEG_INFINITY)
123 .partial_cmp(&a.fitness.unwrap_or(f64::NEG_INFINITY))
124 .unwrap_or(std::cmp::Ordering::Equal)
125 });
126
127 let best_fitness = population[0].fitness.unwrap_or(0.0);
128 let mean_fitness =
129 population.iter().filter_map(|m| m.fitness).sum::<f64>() / population.len() as f64;
130
131 self.evolve(
133 &mut population,
134 config,
135 generation,
136 &study_id,
137 &mut rng_state,
138 );
139
140 self.event_bus.emit(Event::GenerationCompleted {
141 study_id: study_id.clone(),
142 generation,
143 best_fitness,
144 mean_fitness,
145 });
146 }
147
148 population.sort_by(|a, b| {
150 b.fitness
151 .unwrap_or(f64::NEG_INFINITY)
152 .partial_cmp(&a.fitness.unwrap_or(f64::NEG_INFINITY))
153 .unwrap_or(std::cmp::Ordering::Equal)
154 });
155
156 Ok(population)
157 }
158
159 fn initialize_population(
160 &self,
161 config: &PbtConfig,
162 rng_state: &mut u64,
163 ) -> Vec<PopulationMember> {
164 let mut population = Vec::with_capacity(config.population_size);
165
166 for i in 0..config.population_size {
167 let params = sample_params(&config.search_space, rng_state);
168 population.push(PopulationMember {
169 id: format!("member_{i}"),
170 params,
171 state: Value::Empty,
172 fitness: None,
173 });
174 }
175
176 population
177 }
178
179 fn evolve(
180 &self,
181 population: &mut [PopulationMember],
182 config: &PbtConfig,
183 generation: usize,
184 study_id: &str,
185 rng_state: &mut u64,
186 ) {
187 let n = population.len();
188 if n < 2 {
189 return;
190 }
191
192 let cutoff = match &config.exploit {
193 ExploitStrategy::Truncation { fraction } => {
194 let c = ((n as f64) * fraction).ceil() as usize;
195 c.max(1).min(n / 2)
196 }
197 ExploitStrategy::Binary { .. } => n / 2,
198 _ => n / 2,
199 };
200
201 match &config.exploit {
203 ExploitStrategy::Truncation { .. } => {
204 for i in 0..cutoff {
205 let bottom_idx = n - 1 - i;
206 let top_idx = i;
207 if bottom_idx <= top_idx {
208 break;
209 }
210
211 let donor_id = population[top_idx].id.clone();
212 let replaced_id = population[bottom_idx].id.clone();
213
214 population[bottom_idx].params = population[top_idx].params.clone();
215 population[bottom_idx].state = population[top_idx].state.clone();
216
217 self.event_bus.emit(Event::MemberExploited {
218 study_id: study_id.to_string(),
219 generation,
220 replaced_id,
221 donor_id,
222 });
223 }
224 }
225 ExploitStrategy::Binary { .. } => {
226 for i in cutoff..n {
227 *rng_state = hash_u64(*rng_state, i as u64, generation as u64);
228 let opponent = (*rng_state as usize) % cutoff;
229 let my_fitness = population[i].fitness.unwrap_or(f64::NEG_INFINITY);
230 let opp_fitness = population[opponent].fitness.unwrap_or(f64::NEG_INFINITY);
231 if my_fitness < opp_fitness {
232 let donor_id = population[opponent].id.clone();
233 let replaced_id = population[i].id.clone();
234 population[i].params = population[opponent].params.clone();
235 population[i].state = population[opponent].state.clone();
236
237 self.event_bus.emit(Event::MemberExploited {
238 study_id: study_id.to_string(),
239 generation,
240 replaced_id,
241 donor_id,
242 });
243 }
244 }
245 }
246 _ => {}
247 }
248
249 match &config.explore {
251 ExploreStrategy::Perturbation { factor } => {
252 for member in population[(n - cutoff)..].iter_mut() {
253 perturb_params(&mut member.params, *factor, rng_state);
254 }
255 }
256 ExploreStrategy::Resample => {
257 for member in population[(n - cutoff)..].iter_mut() {
258 member.params = sample_params(&config.search_space, rng_state);
259 }
260 }
261 _ => {}
262 }
263 }
264}
265
266fn sample_params(space: &SearchSpace, rng_state: &mut u64) -> HashMap<String, serde_json::Value> {
268 let mut params = HashMap::new();
269
270 for (dim_idx, dim) in space.dimensions.iter().enumerate() {
271 *rng_state = hash_u64(*rng_state, dim_idx as u64, 0);
272 let value = match dim {
273 SearchDimension::Float { low, high, .. } => {
274 let t = pseudo_random(*rng_state);
275 let v = low + t * (high - low);
276 serde_json::Value::from(v)
277 }
278 SearchDimension::Int { low, high, .. } => {
279 let t = pseudo_random(*rng_state);
280 let range = (*high - *low + 1) as f64;
281 let v = *low + (t * range) as i64;
282 serde_json::Value::from(v.min(*high))
283 }
284 SearchDimension::Categorical { choices, .. } => {
285 let t = pseudo_random(*rng_state);
286 let idx = (t * choices.len() as f64) as usize;
287 let idx = idx.min(choices.len() - 1);
288 choices[idx].clone()
289 }
290 _ => continue,
291 };
292 params.insert(dim.name().to_string(), value);
293 }
294
295 params
296}
297
298fn perturb_params(
300 params: &mut HashMap<String, serde_json::Value>,
301 factor: f64,
302 rng_state: &mut u64,
303) {
304 for (i, value) in params.values_mut().enumerate() {
305 if let Some(v) = value.as_f64() {
306 *rng_state = hash_u64(*rng_state, i as u64, 999);
307 let t = pseudo_random(*rng_state);
308 let perturbation = 1.0 + (t * 2.0 - 1.0) * factor;
309 *value = serde_json::Value::from(v * perturbation);
310 }
311 }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_core::search::Scale;

    /// Returns the numeric `lr` hyperparameter of `member`, or `fallback` if absent.
    fn lr_of(member: &PopulationMember, fallback: f64) -> f64 {
        member
            .params
            .get("lr")
            .and_then(serde_json::Value::as_f64)
            .unwrap_or(fallback)
    }

    /// Six members over three generations, truncation exploit and perturbation explore,
    /// with a single log-scaled `lr` dimension in `[0.001, 1.0]`.
    fn test_config() -> PbtConfig {
        let mut space = SearchSpace::new();
        space.add(SearchDimension::Float {
            name: "lr".into(),
            low: 0.001,
            high: 1.0,
            scale: Scale::Log,
            default: None,
        });

        PbtConfig {
            population_size: 6,
            generations: 3,
            exploit: ExploitStrategy::Truncation { fraction: 0.33 },
            explore: ExploreStrategy::Perturbation { factor: 0.2 },
            search_space: space,
            train_steps_per_generation: 10,
        }
    }

    #[test]
    fn pbt_basic_run() {
        let runner = PbtRunner::new(Arc::new(EventBus::new(256)));

        let executor = FnPbtExecutor {
            train_fn: |member: &PopulationMember| {
                let lr = lr_of(member, 0.01);
                Ok(Value::json(serde_json::json!({"lr": lr})))
            },
            eval_fn: |member: &PopulationMember| Ok(-(lr_of(member, 0.01) - 0.1).abs()),
        };

        let population = runner.run(&test_config(), &executor).unwrap();

        assert_eq!(population.len(), 6);
        assert!(population.iter().all(|m| m.fitness.is_some()));
        let best = population.first().and_then(|m| m.fitness).unwrap();
        let worst = population.last().and_then(|m| m.fitness).unwrap();
        assert!(best >= worst);
    }

    #[test]
    fn pbt_emits_events() {
        let bus = Arc::new(EventBus::new(256));
        let mut rx = bus.subscribe();
        let runner = PbtRunner::new(bus);

        let executor = FnPbtExecutor {
            train_fn: |_: &PopulationMember| Ok(Value::Empty),
            eval_fn: |_: &PopulationMember| Ok(1.0),
        };

        runner.run(&test_config(), &executor).unwrap();

        let received: Vec<_> = std::iter::from_fn(|| rx.try_recv().ok()).collect();
        let (started, completed) =
            received
                .iter()
                .fold((0, 0), |(started, completed), event| match event {
                    Event::GenerationStarted { .. } => (started + 1, completed),
                    Event::GenerationCompleted { .. } => (started, completed + 1),
                    _ => (started, completed),
                });

        assert_eq!(started, 3);
        assert_eq!(completed, 3);
    }

    #[test]
    fn pbt_population_evolves() {
        let runner = PbtRunner::new(Arc::new(EventBus::new(64)));

        let executor = FnPbtExecutor {
            train_fn: |_: &PopulationMember| Ok(Value::Empty),
            eval_fn: |member: &PopulationMember| Ok(-(lr_of(member, 0.5) - 0.1).abs()),
        };

        let config = PbtConfig {
            generations: 10,
            ..test_config()
        };
        let population = runner.run(&config, &executor).unwrap();

        assert_eq!(population.len(), 6);
        assert!(population.iter().all(|m| m.fitness.is_some()));
    }
}