Skip to main content

somatize_runtime/executors/
pbt.rs

1//! Population-Based Training runner.
2//!
3//! PBT is a cyclic evolutionary process where each generation:
4//! 1. **Train**: each population member trains for N steps
5//! 2. **Evaluate**: each member is evaluated to produce a fitness score
6//! 3. **Exploit/Explore**: underperformers copy top performers, then mutate hyperparameters
7//!
8//! Each generation's training phase uses the existing sampler infrastructure.
9
10use crate::event_bus::EventBus;
11use crate::sampler::{hash_u64, pseudo_random};
12use somatize_core::error::Result;
13use somatize_core::event::Event;
14use somatize_core::search::{SearchDimension, SearchSpace};
15use somatize_core::strategy::{ExploitStrategy, ExploreStrategy};
16use somatize_core::value::Value;
17use std::collections::HashMap;
18use std::sync::Arc;
19
/// Configuration for a PBT run.
///
/// A run creates `population_size` members, then repeats the
/// train → evaluate → exploit/explore cycle `generations` times.
#[derive(Debug, Clone)]
pub struct PbtConfig {
    /// Number of members created when the population is initialized.
    pub population_size: usize,
    /// Number of train/evaluate/evolve cycles to execute.
    pub generations: usize,
    /// How underperforming members pick a better member to copy from.
    pub exploit: ExploitStrategy,
    /// How hyperparameters are mutated during the explore step.
    pub explore: ExploreStrategy,
    /// Space the initial hyperparameters are sampled from; also used when
    /// the explore strategy resamples parameters.
    pub search_space: SearchSpace,
    /// NOTE(review): not read anywhere in this file — presumably meant for
    /// `PbtExecutor::train` implementations to consult; confirm with callers.
    pub train_steps_per_generation: usize,
}
30
/// A single member of the population.
#[derive(Debug, Clone)]
pub struct PopulationMember {
    /// Identifier assigned at initialization (`member_<index>`); it stays with
    /// the member even when its params and state are overwritten by a donor.
    pub id: String,
    /// Current hyperparameters, keyed by search-dimension name.
    pub params: HashMap<String, serde_json::Value>,
    /// Opaque training state: starts as `Value::Empty` and is replaced by
    /// whatever `PbtExecutor::train` returns.
    pub state: Value,
    /// Most recent fitness score (higher = better); `None` until the member
    /// has been evaluated for the first time.
    pub fitness: Option<f64>,
}
39
/// Trait for the training + evaluation callback.
///
/// The runner calls `train` and then `evaluate` on every member once per
/// generation. Implementations must be `Send + Sync`.
pub trait PbtExecutor: Send + Sync {
    /// Train a member for one generation. Returns updated state.
    ///
    /// When this returns `Err`, the runner logs a warning and the member
    /// keeps its previous state.
    fn train(&self, member: &PopulationMember) -> Result<Value>;
    /// Evaluate a member. Returns fitness score (higher = better).
    ///
    /// When this returns `Err`, the runner logs a warning and records the
    /// member's fitness as negative infinity.
    fn evaluate(&self, member: &PopulationMember) -> Result<f64>;
}
47
48/// Function-based PBT executor for convenience.
49pub struct FnPbtExecutor<T, E> {
50    pub train_fn: T,
51    pub eval_fn: E,
52}
53
54impl<T, E> PbtExecutor for FnPbtExecutor<T, E>
55where
56    T: Fn(&PopulationMember) -> Result<Value> + Send + Sync,
57    E: Fn(&PopulationMember) -> Result<f64> + Send + Sync,
58{
59    fn train(&self, member: &PopulationMember) -> Result<Value> {
60        (self.train_fn)(member)
61    }
62    fn evaluate(&self, member: &PopulationMember) -> Result<f64> {
63        (self.eval_fn)(member)
64    }
65}
66
67/// Orchestrates the PBT evolutionary cycle.
68pub struct PbtRunner {
69    event_bus: Arc<EventBus>,
70}
71
72impl PbtRunner {
73    pub fn new(event_bus: Arc<EventBus>) -> Self {
74        Self { event_bus }
75    }
76
77    /// Run the full PBT evolutionary process.
78    ///
79    /// Returns the final population sorted by fitness (best first).
80    pub fn run(
81        &self,
82        config: &PbtConfig,
83        executor: &dyn PbtExecutor,
84    ) -> Result<Vec<PopulationMember>> {
85        let study_id = somatize_core::util::timestamp_id("pbt");
86        let mut rng_state: u64 = 42;
87
88        // Initialize population with random params
89        let mut population = self.initialize_population(config, &mut rng_state);
90
91        for generation in 0..config.generations {
92            self.event_bus.emit(Event::GenerationStarted {
93                study_id: study_id.clone(),
94                generation,
95                population_size: population.len(),
96            });
97
98            // Stage 1: Train
99            for member in &mut population {
100                match executor.train(member) {
101                    Ok(new_state) => member.state = new_state,
102                    Err(e) => {
103                        tracing::warn!("PBT train failed for {}: {e}", member.id);
104                    }
105                }
106            }
107
108            // Stage 2: Evaluate
109            for member in &mut population {
110                match executor.evaluate(member) {
111                    Ok(fitness) => member.fitness = Some(fitness),
112                    Err(e) => {
113                        tracing::warn!("PBT evaluate failed for {}: {e}", member.id);
114                        member.fitness = Some(f64::NEG_INFINITY);
115                    }
116                }
117            }
118
119            // Sort by fitness (descending)
120            population.sort_by(|a, b| {
121                b.fitness
122                    .unwrap_or(f64::NEG_INFINITY)
123                    .partial_cmp(&a.fitness.unwrap_or(f64::NEG_INFINITY))
124                    .unwrap_or(std::cmp::Ordering::Equal)
125            });
126
127            let best_fitness = population[0].fitness.unwrap_or(0.0);
128            let mean_fitness =
129                population.iter().filter_map(|m| m.fitness).sum::<f64>() / population.len() as f64;
130
131            // Stage 3: Exploit/Explore
132            self.evolve(
133                &mut population,
134                config,
135                generation,
136                &study_id,
137                &mut rng_state,
138            );
139
140            self.event_bus.emit(Event::GenerationCompleted {
141                study_id: study_id.clone(),
142                generation,
143                best_fitness,
144                mean_fitness,
145            });
146        }
147
148        // Final sort
149        population.sort_by(|a, b| {
150            b.fitness
151                .unwrap_or(f64::NEG_INFINITY)
152                .partial_cmp(&a.fitness.unwrap_or(f64::NEG_INFINITY))
153                .unwrap_or(std::cmp::Ordering::Equal)
154        });
155
156        Ok(population)
157    }
158
159    fn initialize_population(
160        &self,
161        config: &PbtConfig,
162        rng_state: &mut u64,
163    ) -> Vec<PopulationMember> {
164        let mut population = Vec::with_capacity(config.population_size);
165
166        for i in 0..config.population_size {
167            let params = sample_params(&config.search_space, rng_state);
168            population.push(PopulationMember {
169                id: format!("member_{i}"),
170                params,
171                state: Value::Empty,
172                fitness: None,
173            });
174        }
175
176        population
177    }
178
179    fn evolve(
180        &self,
181        population: &mut [PopulationMember],
182        config: &PbtConfig,
183        generation: usize,
184        study_id: &str,
185        rng_state: &mut u64,
186    ) {
187        let n = population.len();
188        if n < 2 {
189            return;
190        }
191
192        let cutoff = match &config.exploit {
193            ExploitStrategy::Truncation { fraction } => {
194                let c = ((n as f64) * fraction).ceil() as usize;
195                c.max(1).min(n / 2)
196            }
197            ExploitStrategy::Binary { .. } => n / 2,
198            _ => n / 2,
199        };
200
201        // Exploit: bottom performers copy from top
202        match &config.exploit {
203            ExploitStrategy::Truncation { .. } => {
204                for i in 0..cutoff {
205                    let bottom_idx = n - 1 - i;
206                    let top_idx = i;
207                    if bottom_idx <= top_idx {
208                        break;
209                    }
210
211                    let donor_id = population[top_idx].id.clone();
212                    let replaced_id = population[bottom_idx].id.clone();
213
214                    population[bottom_idx].params = population[top_idx].params.clone();
215                    population[bottom_idx].state = population[top_idx].state.clone();
216
217                    self.event_bus.emit(Event::MemberExploited {
218                        study_id: study_id.to_string(),
219                        generation,
220                        replaced_id,
221                        donor_id,
222                    });
223                }
224            }
225            ExploitStrategy::Binary { .. } => {
226                for i in cutoff..n {
227                    *rng_state = hash_u64(*rng_state, i as u64, generation as u64);
228                    let opponent = (*rng_state as usize) % cutoff;
229                    let my_fitness = population[i].fitness.unwrap_or(f64::NEG_INFINITY);
230                    let opp_fitness = population[opponent].fitness.unwrap_or(f64::NEG_INFINITY);
231                    if my_fitness < opp_fitness {
232                        let donor_id = population[opponent].id.clone();
233                        let replaced_id = population[i].id.clone();
234                        population[i].params = population[opponent].params.clone();
235                        population[i].state = population[opponent].state.clone();
236
237                        self.event_bus.emit(Event::MemberExploited {
238                            study_id: study_id.to_string(),
239                            generation,
240                            replaced_id,
241                            donor_id,
242                        });
243                    }
244                }
245            }
246            _ => {}
247        }
248
249        // Explore: mutate exploited members' hyperparameters
250        match &config.explore {
251            ExploreStrategy::Perturbation { factor } => {
252                for member in population[(n - cutoff)..].iter_mut() {
253                    perturb_params(&mut member.params, *factor, rng_state);
254                }
255            }
256            ExploreStrategy::Resample => {
257                for member in population[(n - cutoff)..].iter_mut() {
258                    member.params = sample_params(&config.search_space, rng_state);
259                }
260            }
261            _ => {}
262        }
263    }
264}
265
266/// Sample random parameters from a search space.
267fn sample_params(space: &SearchSpace, rng_state: &mut u64) -> HashMap<String, serde_json::Value> {
268    let mut params = HashMap::new();
269
270    for (dim_idx, dim) in space.dimensions.iter().enumerate() {
271        *rng_state = hash_u64(*rng_state, dim_idx as u64, 0);
272        let value = match dim {
273            SearchDimension::Float { low, high, .. } => {
274                let t = pseudo_random(*rng_state);
275                let v = low + t * (high - low);
276                serde_json::Value::from(v)
277            }
278            SearchDimension::Int { low, high, .. } => {
279                let t = pseudo_random(*rng_state);
280                let range = (*high - *low + 1) as f64;
281                let v = *low + (t * range) as i64;
282                serde_json::Value::from(v.min(*high))
283            }
284            SearchDimension::Categorical { choices, .. } => {
285                let t = pseudo_random(*rng_state);
286                let idx = (t * choices.len() as f64) as usize;
287                let idx = idx.min(choices.len() - 1);
288                choices[idx].clone()
289            }
290            _ => continue,
291        };
292        params.insert(dim.name().to_string(), value);
293    }
294
295    params
296}
297
298/// Perturb numeric parameters by a random factor in [1-factor, 1+factor].
299fn perturb_params(
300    params: &mut HashMap<String, serde_json::Value>,
301    factor: f64,
302    rng_state: &mut u64,
303) {
304    for (i, value) in params.values_mut().enumerate() {
305        if let Some(v) = value.as_f64() {
306            *rng_state = hash_u64(*rng_state, i as u64, 999);
307            let t = pseudo_random(*rng_state);
308            let perturbation = 1.0 + (t * 2.0 - 1.0) * factor;
309            *value = serde_json::Value::from(v * perturbation);
310        }
311    }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_core::search::Scale;

    /// Reads the `lr` hyperparameter, using `fallback` when it is missing or
    /// not a number.
    fn lr_or(member: &PopulationMember, fallback: f64) -> f64 {
        member
            .params
            .get("lr")
            .and_then(|v| v.as_f64())
            .unwrap_or(fallback)
    }

    /// Six members, three generations, one log-scaled `lr` dimension,
    /// truncation exploit and perturbation explore.
    fn test_config() -> PbtConfig {
        let lr_dimension = SearchDimension::Float {
            name: "lr".into(),
            low: 0.001,
            high: 1.0,
            scale: Scale::Log,
            default: None,
        };
        let mut search_space = SearchSpace::new();
        search_space.add(lr_dimension);

        PbtConfig {
            population_size: 6,
            generations: 3,
            exploit: ExploitStrategy::Truncation { fraction: 0.33 },
            explore: ExploreStrategy::Perturbation { factor: 0.2 },
            search_space,
            train_steps_per_generation: 10,
        }
    }

    /// A run returns the whole population, scored, best member first.
    #[test]
    fn pbt_basic_run() {
        let runner = PbtRunner::new(Arc::new(EventBus::new(256)));

        let executor = FnPbtExecutor {
            train_fn: |member: &PopulationMember| {
                let lr = lr_or(member, 0.01);
                Ok(Value::json(serde_json::json!({"lr": lr})))
            },
            eval_fn: |member: &PopulationMember| {
                let lr = lr_or(member, 0.01);
                Ok(-(lr - 0.1).abs())
            },
        };

        let result = runner.run(&test_config(), &executor).unwrap();

        assert_eq!(result.len(), 6);
        assert!(result.iter().all(|m| m.fitness.is_some()));
        // Sorted by fitness descending
        let best = result[0].fitness.unwrap();
        let worst = result.last().unwrap().fitness.unwrap();
        assert!(best >= worst);
    }

    /// Every generation produces exactly one start and one completion event.
    #[test]
    fn pbt_emits_events() {
        let bus = Arc::new(EventBus::new(256));
        let mut rx = bus.subscribe();
        let runner = PbtRunner::new(bus);

        let executor = FnPbtExecutor {
            train_fn: |_: &PopulationMember| Ok(Value::Empty),
            eval_fn: |_: &PopulationMember| Ok(1.0),
        };

        runner.run(&test_config(), &executor).unwrap();

        let mut gen_started = 0usize;
        let mut gen_completed = 0usize;
        while let Ok(event) = rx.try_recv() {
            match event {
                Event::GenerationStarted { .. } => gen_started += 1,
                Event::GenerationCompleted { .. } => gen_completed += 1,
                _ => {}
            }
        }

        assert_eq!(gen_started, 3);
        assert_eq!(gen_completed, 3);
    }

    /// A longer run still returns a complete, fully scored population.
    #[test]
    fn pbt_population_evolves() {
        let runner = PbtRunner::new(Arc::new(EventBus::new(64)));

        let executor = FnPbtExecutor {
            train_fn: |_: &PopulationMember| Ok(Value::Empty),
            // Fitness = -|lr - 0.1| (best at lr=0.1)
            eval_fn: |member: &PopulationMember| Ok(-(lr_or(member, 0.5) - 0.1).abs()),
        };

        let config = PbtConfig {
            generations: 10,
            ..test_config()
        };
        let result = runner.run(&config, &executor).unwrap();

        assert_eq!(result.len(), 6);
        // All should have fitness
        assert!(result.iter().all(|m| m.fitness.is_some()));
    }
}