// sklears_model_selection/population_based_training.rs
1//! Population-based Training (PBT) for hyperparameter optimization
2//!
3//! Population-based Training is a method for training a population of neural networks (or ML models)
4//! in parallel while optimizing hyperparameters. It combines parallel search with evolutionary
5//! methods to exploit and explore the hyperparameter space dynamically.
6//!
7//! The key idea is to periodically evaluate the performance of all models in the population,
8//! then replace the worst-performing models with copies of the best-performing models,
9//! with perturbed hyperparameters to continue exploration.
10
11use crate::{CrossValidator, KFold, Scoring};
12use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Dim, OwnedRepr};
13use scirs2_core::numeric::ToPrimitive;
14use scirs2_core::random::rngs::StdRng;
15use scirs2_core::random::{RngExt, SeedableRng};
16#[cfg(feature = "serde")]
17use serde::{Deserialize, Serialize};
18use sklears_core::{
19    error::{Result, SklearsError},
20    traits::{Fit, Predict, Score},
21};
22use std::collections::HashMap;
23use std::fmt::Debug;
24use std::marker::PhantomData;
25
/// Configuration for Population-based Training
///
/// Controls the population size, how aggressively the population is
/// exploited/explored, and the stopping criteria of the optimization loop.
#[derive(Debug, Clone)]
pub struct PBTConfig {
    /// Population size (number of parallel workers)
    pub population_size: usize,
    /// How often to perform exploitation and exploration (in training steps)
    pub perturbation_interval: usize,
    /// Fraction of population to replace in each perturbation (expected in [0, 1])
    pub replacement_fraction: f64,
    /// Scale of hyperparameter perturbation: half-width of the uniform offset
    /// for continuous parameters, and the probability of re-sampling each
    /// discrete parameter (see `PBTParameterSpace::perturb`)
    pub perturbation_factor: f64,
    /// Maximum number of training iterations
    pub max_iterations: usize,
    /// Early stopping patience (in evaluations); `None` disables early stopping
    pub patience: Option<usize>,
    /// Random seed for reproducibility; `None` seeds from the thread RNG
    pub random_state: Option<u64>,
}
44
45impl Default for PBTConfig {
46    fn default() -> Self {
47        Self {
48            population_size: 20,
49            perturbation_interval: 10,
50            replacement_fraction: 0.25,
51            perturbation_factor: 0.2,
52            max_iterations: 100,
53            patience: Some(10),
54            random_state: None,
55        }
56    }
57}
58
/// Parameter space definition for PBT
///
/// Both parameter kinds are keyed by name; discrete values are stored as
/// `f64` so a single `PBTParameters` map can hold either kind.
#[derive(Debug, Clone)]
pub struct PBTParameterSpace {
    /// Continuous parameters with their (low, high) bounds
    pub continuous: HashMap<String, (f64, f64)>,
    /// Discrete parameters with their possible values
    pub discrete: HashMap<String, Vec<f64>>,
}
67
68impl Default for PBTParameterSpace {
69    fn default() -> Self {
70        Self::new()
71    }
72}
73
74impl PBTParameterSpace {
75    pub fn new() -> Self {
76        Self {
77            continuous: HashMap::new(),
78            discrete: HashMap::new(),
79        }
80    }
81
82    /// Add a continuous parameter with bounds
83    pub fn add_continuous(mut self, name: &str, low: f64, high: f64) -> Self {
84        self.continuous.insert(name.to_string(), (low, high));
85        self
86    }
87
88    /// Add a discrete parameter with possible values
89    pub fn add_discrete(mut self, name: &str, values: Vec<f64>) -> Self {
90        self.discrete.insert(name.to_string(), values);
91        self
92    }
93
94    /// Sample a random parameter configuration
95    pub fn sample<R: RngExt>(&self, rng: &mut R) -> PBTParameters {
96        let mut params = PBTParameters::new();
97
98        for (name, (low, high)) in &self.continuous {
99            let value = rng.random_range(*low..*high + 1.0);
100            params.set(name.clone(), value);
101        }
102
103        for (name, values) in &self.discrete {
104            let idx = rng.random_range(0..values.len());
105            params.set(name.clone(), values[idx]);
106        }
107
108        params
109    }
110
111    /// Perturb parameters with exploration
112    pub fn perturb<R: RngExt>(
113        &self,
114        params: &PBTParameters,
115        factor: f64,
116        rng: &mut R,
117    ) -> PBTParameters {
118        let mut new_params = params.clone();
119
120        for (name, (low, high)) in &self.continuous {
121            if let Some(&current_value) = params.get(name) {
122                let range = high - low;
123                let perturbation = rng.random_range(-factor..factor + 1.0) * range;
124                let new_value = (current_value + perturbation).clamp(*low, *high);
125                new_params.set(name.clone(), new_value);
126            }
127        }
128
129        for (name, values) in &self.discrete {
130            if rng.random_bool(factor.min(1.0)) {
131                // Probability of changing discrete parameter
132                let idx = rng.random_range(0..values.len());
133                new_params.set(name.clone(), values[idx]);
134            }
135        }
136
137        new_params
138    }
139}
140
/// Parameter configuration for a single worker
///
/// A flat name → value map; discrete choices are stored as `f64` alongside
/// continuous values. Serializable when the `serde` feature is enabled.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct PBTParameters {
    // Internal storage; accessed through `set`/`get`/`iter`.
    params: HashMap<String, f64>,
}
147
148impl Default for PBTParameters {
149    fn default() -> Self {
150        Self::new()
151    }
152}
153
154impl PBTParameters {
155    pub fn new() -> Self {
156        Self {
157            params: HashMap::new(),
158        }
159    }
160
161    pub fn set(&mut self, name: String, value: f64) {
162        self.params.insert(name, value);
163    }
164
165    pub fn get(&self, name: &str) -> Option<&f64> {
166        self.params.get(name)
167    }
168
169    pub fn iter(&self) -> impl Iterator<Item = (&String, &f64)> {
170        self.params.iter()
171    }
172}
173
/// A worker in the PBT population
///
/// Each worker carries its own hyperparameters, optional trained model
/// state, and per-worker score bookkeeping.
#[derive(Debug, Clone)]
pub struct PBTWorker<E> {
    /// Worker ID (index into the population)
    pub id: usize,
    /// Current hyperparameters
    pub parameters: PBTParameters,
    /// Current model state (`None` until trained or after a reset)
    pub estimator: Option<E>,
    /// Performance history, one entry per evaluation
    pub score_history: Vec<f64>,
    /// Current performance score (`f64::NEG_INFINITY` sentinel before the
    /// first evaluation)
    pub current_score: f64,
    /// Training step count
    pub step: usize,
}
190
191impl<E> PBTWorker<E> {
192    pub fn new(id: usize, parameters: PBTParameters) -> Self {
193        Self {
194            id,
195            parameters,
196            estimator: None,
197            score_history: Vec::new(),
198            current_score: f64::NEG_INFINITY,
199            step: 0,
200        }
201    }
202
203    /// Check if this worker should be replaced based on performance
204    pub fn should_be_replaced(&self, threshold_percentile: f64, all_scores: &[f64]) -> bool {
205        if all_scores.is_empty() {
206            return false;
207        }
208
209        let mut sorted_scores = all_scores.to_vec();
210        sorted_scores.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
211
212        let threshold_idx = (threshold_percentile * sorted_scores.len() as f64) as usize;
213        let threshold = sorted_scores
214            .get(threshold_idx)
215            .unwrap_or(&f64::NEG_INFINITY);
216
217        self.current_score < *threshold
218    }
219}
220
/// Population-based Training optimizer
///
/// Owns the worker population and the RNG that drives sampling and
/// perturbation. `X`/`Y` are the data types the optimizer is used with.
pub struct PopulationBasedTraining<E, X, Y> {
    /// Optimization settings.
    config: PBTConfig,
    /// Search space workers draw hyperparameters from.
    parameter_space: PBTParameterSpace,
    /// Current worker population.
    population: Vec<PBTWorker<E>>,
    /// RNG for reproducible sampling/perturbation (seedable via config).
    rng: StdRng,
    /// Ties the otherwise-unused `X`/`Y` type parameters to the struct.
    _phantom: PhantomData<(X, Y)>,
}
229
230impl<E, X, Y> PopulationBasedTraining<E, X, Y>
231where
232    E: Clone + Debug,
233{
234    /// Create a new PBT optimizer
235    pub fn new(config: PBTConfig, parameter_space: PBTParameterSpace) -> Self {
236        let rng = match config.random_state {
237            Some(seed) => StdRng::seed_from_u64(seed),
238            None => StdRng::from_rng(&mut scirs2_core::random::thread_rng()),
239        };
240
241        Self {
242            config,
243            parameter_space,
244            population: Vec::new(),
245            rng,
246            _phantom: PhantomData,
247        }
248    }
249
250    /// Initialize the population with random hyperparameters
251    pub fn initialize_population(&mut self) {
252        self.population.clear();
253
254        for i in 0..self.config.population_size {
255            let parameters = self.parameter_space.sample(&mut self.rng);
256            let worker = PBTWorker::new(i, parameters);
257            self.population.push(worker);
258        }
259    }
260
261    /// Get the current population
262    pub fn population(&self) -> &[PBTWorker<E>] {
263        &self.population
264    }
265
266    /// Get the best worker in the current population
267    pub fn best_worker(&self) -> Option<&PBTWorker<E>> {
268        self.population.iter().max_by(|a, b| {
269            a.current_score
270                .partial_cmp(&b.current_score)
271                .expect("operation should succeed")
272        })
273    }
274
275    /// Update worker performance
276    pub fn update_worker_score(&mut self, worker_id: usize, score: f64) -> Result<()> {
277        let worker = self
278            .population
279            .get_mut(worker_id)
280            .ok_or_else(|| SklearsError::InvalidInput(format!("Worker {} not found", worker_id)))?;
281
282        worker.current_score = score;
283        worker.score_history.push(score);
284        worker.step += 1;
285
286        Ok(())
287    }
288
289    /// Perform exploitation and exploration step
290    pub fn exploit_and_explore(&mut self) -> Result<()> {
291        let population_size = self.population.len();
292        if population_size == 0 {
293            return Err(SklearsError::InvalidOperation(
294                "Empty population".to_string(),
295            ));
296        }
297
298        // Get current scores for ranking
299        let scores: Vec<f64> = self.population.iter().map(|w| w.current_score).collect();
300
301        // Identify workers to replace (bottom fraction)
302        let num_to_replace = (self.config.replacement_fraction * population_size as f64) as usize;
303        let mut worker_scores: Vec<(usize, f64)> =
304            scores.iter().enumerate().map(|(i, &s)| (i, s)).collect();
305        worker_scores.sort_by(|a, b| a.1.partial_cmp(&b.1).expect("operation should succeed"));
306
307        // Get indices of worst and best performers
308        let worst_indices: Vec<usize> = worker_scores
309            .iter()
310            .take(num_to_replace)
311            .map(|(i, _)| *i)
312            .collect();
313        let best_indices: Vec<usize> = worker_scores
314            .iter()
315            .rev()
316            .take(num_to_replace)
317            .map(|(i, _)| *i)
318            .collect();
319
320        // Replace worst performers with perturbed copies of best performers
321        for (&worst_idx, &best_idx) in worst_indices.iter().zip(best_indices.iter()) {
322            if worst_idx != best_idx {
323                // Copy the best performer's parameters
324                let best_params = self.population[best_idx].parameters.clone();
325
326                // Perturb the parameters for exploration
327                let perturbed_params = self.parameter_space.perturb(
328                    &best_params,
329                    self.config.perturbation_factor,
330                    &mut self.rng,
331                );
332
333                // Update the worst performer
334                self.population[worst_idx].parameters = perturbed_params;
335                self.population[worst_idx].current_score = f64::NEG_INFINITY;
336                self.population[worst_idx].score_history.clear();
337                self.population[worst_idx].step = 0;
338                self.population[worst_idx].estimator = None;
339            }
340        }
341
342        Ok(())
343    }
344
345    /// Check if any worker has converged
346    pub fn check_convergence(&self) -> bool {
347        if let Some(patience) = self.config.patience {
348            for worker in &self.population {
349                if worker.score_history.len() >= patience {
350                    let recent_scores =
351                        &worker.score_history[worker.score_history.len() - patience..];
352                    let max_recent = recent_scores
353                        .iter()
354                        .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
355                    let current = worker.current_score;
356
357                    // If no improvement in patience steps, consider converged
358                    if current <= max_recent {
359                        return true;
360                    }
361                }
362            }
363        }
364        false
365    }
366
367    /// Get optimization statistics
368    pub fn get_statistics(&self) -> PBTStatistics {
369        let scores: Vec<f64> = self.population.iter().map(|w| w.current_score).collect();
370
371        let best_score = scores.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
372        let worst_score = scores.iter().fold(f64::INFINITY, |a, &b| a.min(b));
373        let mean_score = scores.iter().sum::<f64>() / scores.len() as f64;
374
375        let variance = scores
376            .iter()
377            .map(|&x| (x - mean_score).powi(2))
378            .sum::<f64>()
379            / scores.len() as f64;
380        let std_dev = variance.sqrt();
381
382        PBTStatistics {
383            generation: self.population.iter().map(|w| w.step).max().unwrap_or(0),
384            population_size: self.population.len(),
385            best_score,
386            worst_score,
387            mean_score,
388            std_dev,
389        }
390    }
391}
392
/// Statistics for PBT optimization
///
/// A snapshot of the population's current scores, recorded once per
/// iteration during optimization.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct PBTStatistics {
    /// Highest training-step count across the population.
    pub generation: usize,
    /// Number of workers in the population.
    pub population_size: usize,
    /// Best current score in the population.
    pub best_score: f64,
    /// Worst current score in the population.
    pub worst_score: f64,
    /// Mean of the current scores.
    pub mean_score: f64,
    /// Population standard deviation of the current scores.
    pub std_dev: f64,
}
404
/// Result of PBT optimization
#[derive(Debug, Clone)]
pub struct PBTResult<E> {
    /// Best hyperparameters found
    pub best_params: PBTParameters,
    /// Best score achieved (cross-validated mean)
    pub best_score: f64,
    /// Best model/estimator (`None` when nothing was successfully trained)
    pub best_estimator: Option<E>,
    /// Optimization history, one statistics snapshot per iteration
    pub history: Vec<PBTStatistics>,
    /// Final population at the end of optimization
    pub final_population: Vec<PBTWorker<E>>,
}
419
/// Configuration function type for PBT
///
/// Maps a base estimator plus a hyperparameter set to a configured
/// estimator; boxed so arbitrary closures can be stored behind one type.
pub type PBTConfigFn<E> = Box<dyn Fn(E, &PBTParameters) -> Result<E>>;
422
/// Population-based Training with cross-validation
///
/// Wraps a base estimator and evaluates each candidate hyperparameter set
/// via the configured [`CrossValidator`].
pub struct PopulationBasedTrainingCV<E, X, Y> {
    /// Prototype estimator, cloned for every configuration evaluated.
    base_estimator: E,
    /// PBT loop settings.
    config: PBTConfig,
    /// Hyperparameter search space.
    parameter_space: PBTParameterSpace,
    /// Applies a parameter set to a clone of the base estimator.
    config_fn: PBTConfigFn<E>,
    /// Cross-validation splitter (defaults to 5-fold, see `new`).
    cv: Box<dyn CrossValidator>,
    /// Optional scoring strategy.
    /// NOTE(review): stored but not consulted by `fit`, which uses the
    /// estimator's own `Score` impl — confirm intended behavior.
    scoring: Option<Scoring>,
    /// Ties the otherwise-unused `X`/`Y` type parameters to the struct.
    _phantom: PhantomData<(X, Y)>,
}
433
434impl<E, X, Y> PopulationBasedTrainingCV<E, X, Y>
435where
436    E: Clone + Debug,
437{
438    /// Create a new PBT with cross-validation
439    pub fn new(
440        base_estimator: E,
441        config: PBTConfig,
442        parameter_space: PBTParameterSpace,
443        config_fn: PBTConfigFn<E>,
444    ) -> Self {
445        Self {
446            base_estimator,
447            config,
448            parameter_space,
449            config_fn,
450            cv: Box::new(KFold::new(5)),
451            scoring: None,
452            _phantom: PhantomData,
453        }
454    }
455
456    /// Set cross-validation strategy
457    pub fn cv<CV>(mut self, cv: CV) -> Self
458    where
459        CV: CrossValidator + 'static,
460    {
461        self.cv = Box::new(cv);
462        self
463    }
464
465    /// Set scoring function
466    pub fn scoring(mut self, scoring: Scoring) -> Self {
467        self.scoring = Some(scoring);
468        self
469    }
470}
471
impl<E> PopulationBasedTrainingCV<E, Array2<f64>, Array1<f64>>
where
    E: Clone
        + Debug
        + Fit<
            ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>, f64>,
            ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>, f64>,
        > + Score<
            ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>, f64>,
            ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>, f64>,
        >,
    E::Fitted: Clone
        + Debug
        + Predict<
            ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>, f64>,
            ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>, f64>,
        > + Score<
            ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>, f64>,
            ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>, f64>,
        >,
{
    /// Fit the PBT optimizer with cross-validation
    ///
    /// Runs up to `config.max_iterations` rounds. Each round scores every
    /// worker by cross-validation, tracks the best configuration seen so far
    /// (refitting it on the full dataset), periodically applies
    /// exploit-and-explore, and stops early on convergence.
    ///
    /// NOTE(review): `self.scoring` is never consulted here — scores come
    /// from the estimator's own `Score` impl. Confirm whether a custom
    /// `Scoring` is supposed to take effect.
    pub fn fit(self, x: &Array2<f64>, y: &Array1<f64>) -> Result<PBTResult<E::Fitted>> {
        // The PBT bookkeeping stores fitted models, hence `E::Fitted` here.
        let mut pbt: PopulationBasedTraining<E::Fitted, Array2<f64>, Array1<f64>> =
            PopulationBasedTraining::new(self.config.clone(), self.parameter_space.clone());
        pbt.initialize_population();

        let mut history = Vec::new();
        let mut best_global_score = f64::NEG_INFINITY;
        let mut best_global_params = None;
        let mut best_global_estimator = None;

        for iteration in 0..self.config.max_iterations {
            // Evaluate all workers using cross-validation
            let population_size = pbt.population.len();
            for worker_id in 0..population_size {
                let worker_params = pbt.population[worker_id].parameters.clone();

                // Configure estimator with current parameters
                let configured_estimator =
                    (self.config_fn)(self.base_estimator.clone(), &worker_params)?;

                // Perform cross-validation
                let cv_splits = self.cv.split(x.nrows(), None);
                let mut cv_scores = Vec::new();

                for (train_idx, test_idx) in cv_splits {
                    // `select` copies the chosen rows out of x/y.
                    let x_train_view = x.select(scirs2_core::ndarray::Axis(0), &train_idx);
                    let y_train_view = y.select(scirs2_core::ndarray::Axis(0), &train_idx);
                    let x_test_view = x.select(scirs2_core::ndarray::Axis(0), &test_idx);
                    let y_test_view = y.select(scirs2_core::ndarray::Axis(0), &test_idx);

                    // Rebuild contiguous owned arrays so fit/score receive the
                    // owned-representation types required by the trait bounds.
                    let x_train = Array2::from_shape_vec(
                        (x_train_view.nrows(), x_train_view.ncols()),
                        x_train_view.iter().copied().collect(),
                    )?;
                    let y_train = Array1::from_vec(y_train_view.iter().copied().collect());
                    let x_test = Array2::from_shape_vec(
                        (x_test_view.nrows(), x_test_view.ncols()),
                        x_test_view.iter().copied().collect(),
                    )?;
                    let y_test = Array1::from_vec(y_test_view.iter().copied().collect());

                    // Train on the fold's train split, score on its test split.
                    let fitted_estimator = configured_estimator.clone().fit(&x_train, &y_train)?;
                    let score = fitted_estimator.score(&x_test, &y_test)?;
                    cv_scores.push(score);
                }

                // Mean CV score is the worker's fitness for this round.
                let mean_score = cv_scores
                    .iter()
                    .copied()
                    .map(|x| x.to_f64().unwrap_or(0.0))
                    .sum::<f64>()
                    / cv_scores.len() as f64;
                pbt.update_worker_score(worker_id, mean_score)?;

                // Track global best
                if mean_score > best_global_score {
                    best_global_score = mean_score;
                    best_global_params = Some(worker_params);

                    // Train final model on full dataset
                    // (costs one extra full fit every time a new best appears)
                    let final_estimator = configured_estimator.fit(x, y)?;
                    best_global_estimator = Some(final_estimator);
                }
            }

            // Record statistics
            let stats = pbt.get_statistics();
            history.push(stats);

            // Perform exploitation and exploration on the configured interval
            if iteration > 0 && iteration % self.config.perturbation_interval == 0 {
                pbt.exploit_and_explore()?;
            }

            // Check for convergence and stop early if reached
            if pbt.check_convergence() {
                break;
            }
        }

        Ok(PBTResult {
            // Falls back to an empty parameter set if no worker ever scored.
            best_params: best_global_params.unwrap_or_else(PBTParameters::new),
            best_score: best_global_score,
            best_estimator: best_global_estimator,
            history,
            final_population: pbt.population,
        })
    }
}
583
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Sampling must produce a value for every declared parameter.
    #[test]
    fn test_pbt_parameter_space() {
        let space = PBTParameterSpace::new()
            .add_continuous("learning_rate", 0.001, 0.1)
            .add_discrete("n_estimators", vec![10.0, 50.0, 100.0]);

        let mut seeded_rng = StdRng::seed_from_u64(42);
        let sampled = space.sample(&mut seeded_rng);

        assert!(sampled.get("learning_rate").is_some());
        assert!(sampled.get("n_estimators").is_some());
    }

    /// A fresh worker carries its id and the -inf sentinel score.
    #[test]
    fn test_pbt_worker() {
        let mut initial = PBTParameters::new();
        initial.set("learning_rate".to_string(), 0.01);

        let worker = PBTWorker::<i32>::new(0, initial);
        assert_eq!(worker.id, 0);
        assert_eq!(worker.current_score, f64::NEG_INFINITY);
    }

    /// Default config matches the documented defaults.
    #[test]
    fn test_pbt_config() {
        let defaults = PBTConfig::default();
        assert_eq!(defaults.population_size, 20);
        assert_eq!(defaults.perturbation_interval, 10);
    }

    /// Perturbation moves a continuous value while keeping it in bounds.
    #[test]
    fn test_parameter_perturbation() {
        let space = PBTParameterSpace::new().add_continuous("learning_rate", 0.001, 0.1);

        let mut base = PBTParameters::new();
        base.set("learning_rate".to_string(), 0.05);

        let mut seeded_rng = StdRng::seed_from_u64(42);
        let perturbed = space.perturb(&base, 0.1, &mut seeded_rng);

        let before = base
            .get("learning_rate")
            .expect("operation should succeed");
        let after = perturbed
            .get("learning_rate")
            .expect("operation should succeed");

        // Should be different but within bounds
        assert_ne!(before, after);
        assert!(*after >= 0.001 && *after <= 0.1);
    }
}