1use crate::{CrossValidator, KFold, Scoring};
12use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Dim, OwnedRepr};
13use scirs2_core::numeric::ToPrimitive;
14use scirs2_core::random::rngs::StdRng;
15use scirs2_core::random::{RngExt, SeedableRng};
16#[cfg(feature = "serde")]
17use serde::{Deserialize, Serialize};
18use sklears_core::{
19 error::{Result, SklearsError},
20 traits::{Fit, Predict, Score},
21};
22use std::collections::HashMap;
23use std::fmt::Debug;
24use std::marker::PhantomData;
25
/// Settings that control a population-based training (PBT) run.
#[derive(Debug, Clone)]
pub struct PBTConfig {
    /// Number of workers created by `PopulationBasedTraining::initialize_population`.
    pub population_size: usize,
    /// `PopulationBasedTrainingCV::fit` runs an exploit/explore step on every
    /// non-zero iteration that is a multiple of this value.
    pub perturbation_interval: usize,
    /// Fraction of the population whose parameters are overwritten during an
    /// exploit/explore step.
    pub replacement_fraction: f64,
    /// Perturbation strength handed to `PBTParameterSpace::perturb`.
    pub perturbation_factor: f64,
    /// Upper bound on the number of iterations performed by `fit`.
    pub max_iterations: usize,
    /// Score-history window length used by `check_convergence`; `None`
    /// disables the convergence check.
    pub patience: Option<usize>,
    /// Seed for the internal RNG; with `None` the RNG is seeded from
    /// `thread_rng()`.
    pub random_state: Option<u64>,
}

impl Default for PBTConfig {
    /// Returns the stock configuration: 20 workers, an exploit/explore step
    /// every 10 iterations, 25% replacement, perturbation factor 0.2, at most
    /// 100 iterations, a patience of 10 and no fixed seed.
    fn default() -> Self {
        PBTConfig {
            random_state: None,
            patience: Some(10),
            max_iterations: 100,
            perturbation_factor: 0.2,
            replacement_fraction: 0.25,
            perturbation_interval: 10,
            population_size: 20,
        }
    }
}
58
59#[derive(Debug, Clone)]
61pub struct PBTParameterSpace {
62 pub continuous: HashMap<String, (f64, f64)>,
64 pub discrete: HashMap<String, Vec<f64>>,
66}
67
68impl Default for PBTParameterSpace {
69 fn default() -> Self {
70 Self::new()
71 }
72}
73
74impl PBTParameterSpace {
75 pub fn new() -> Self {
76 Self {
77 continuous: HashMap::new(),
78 discrete: HashMap::new(),
79 }
80 }
81
82 pub fn add_continuous(mut self, name: &str, low: f64, high: f64) -> Self {
84 self.continuous.insert(name.to_string(), (low, high));
85 self
86 }
87
88 pub fn add_discrete(mut self, name: &str, values: Vec<f64>) -> Self {
90 self.discrete.insert(name.to_string(), values);
91 self
92 }
93
94 pub fn sample<R: RngExt>(&self, rng: &mut R) -> PBTParameters {
96 let mut params = PBTParameters::new();
97
98 for (name, (low, high)) in &self.continuous {
99 let value = rng.random_range(*low..*high + 1.0);
100 params.set(name.clone(), value);
101 }
102
103 for (name, values) in &self.discrete {
104 let idx = rng.random_range(0..values.len());
105 params.set(name.clone(), values[idx]);
106 }
107
108 params
109 }
110
111 pub fn perturb<R: RngExt>(
113 &self,
114 params: &PBTParameters,
115 factor: f64,
116 rng: &mut R,
117 ) -> PBTParameters {
118 let mut new_params = params.clone();
119
120 for (name, (low, high)) in &self.continuous {
121 if let Some(¤t_value) = params.get(name) {
122 let range = high - low;
123 let perturbation = rng.random_range(-factor..factor + 1.0) * range;
124 let new_value = (current_value + perturbation).clamp(*low, *high);
125 new_params.set(name.clone(), new_value);
126 }
127 }
128
129 for (name, values) in &self.discrete {
130 if rng.random_bool(factor.min(1.0)) {
131 let idx = rng.random_range(0..values.len());
133 new_params.set(name.clone(), values[idx]);
134 }
135 }
136
137 new_params
138 }
139}
140
/// A concrete assignment of hyperparameter names to `f64` values.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct PBTParameters {
    /// Backing map from hyperparameter name to its value.
    params: HashMap<String, f64>,
}

impl PBTParameters {
    /// Creates an empty parameter set.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `value` under `name`, overwriting any previous value.
    pub fn set(&mut self, name: String, value: f64) {
        let _previous = self.params.insert(name, value);
    }

    /// Looks up the value stored under `name`, if any.
    pub fn get(&self, name: &str) -> Option<&f64> {
        let stored = &self.params;
        stored.get(name)
    }

    /// Iterates over all `(name, value)` pairs in unspecified order.
    pub fn iter(&self) -> impl Iterator<Item = (&String, &f64)> {
        let stored = &self.params;
        stored.iter()
    }
}
173
174#[derive(Debug, Clone)]
176pub struct PBTWorker<E> {
177 pub id: usize,
179 pub parameters: PBTParameters,
181 pub estimator: Option<E>,
183 pub score_history: Vec<f64>,
185 pub current_score: f64,
187 pub step: usize,
189}
190
191impl<E> PBTWorker<E> {
192 pub fn new(id: usize, parameters: PBTParameters) -> Self {
193 Self {
194 id,
195 parameters,
196 estimator: None,
197 score_history: Vec::new(),
198 current_score: f64::NEG_INFINITY,
199 step: 0,
200 }
201 }
202
203 pub fn should_be_replaced(&self, threshold_percentile: f64, all_scores: &[f64]) -> bool {
205 if all_scores.is_empty() {
206 return false;
207 }
208
209 let mut sorted_scores = all_scores.to_vec();
210 sorted_scores.sort_by(|a, b| a.partial_cmp(b).expect("operation should succeed"));
211
212 let threshold_idx = (threshold_percentile * sorted_scores.len() as f64) as usize;
213 let threshold = sorted_scores
214 .get(threshold_idx)
215 .unwrap_or(&f64::NEG_INFINITY);
216
217 self.current_score < *threshold
218 }
219}
220
221pub struct PopulationBasedTraining<E, X, Y> {
223 config: PBTConfig,
224 parameter_space: PBTParameterSpace,
225 population: Vec<PBTWorker<E>>,
226 rng: StdRng,
227 _phantom: PhantomData<(X, Y)>,
228}
229
230impl<E, X, Y> PopulationBasedTraining<E, X, Y>
231where
232 E: Clone + Debug,
233{
234 pub fn new(config: PBTConfig, parameter_space: PBTParameterSpace) -> Self {
236 let rng = match config.random_state {
237 Some(seed) => StdRng::seed_from_u64(seed),
238 None => StdRng::from_rng(&mut scirs2_core::random::thread_rng()),
239 };
240
241 Self {
242 config,
243 parameter_space,
244 population: Vec::new(),
245 rng,
246 _phantom: PhantomData,
247 }
248 }
249
250 pub fn initialize_population(&mut self) {
252 self.population.clear();
253
254 for i in 0..self.config.population_size {
255 let parameters = self.parameter_space.sample(&mut self.rng);
256 let worker = PBTWorker::new(i, parameters);
257 self.population.push(worker);
258 }
259 }
260
261 pub fn population(&self) -> &[PBTWorker<E>] {
263 &self.population
264 }
265
266 pub fn best_worker(&self) -> Option<&PBTWorker<E>> {
268 self.population.iter().max_by(|a, b| {
269 a.current_score
270 .partial_cmp(&b.current_score)
271 .expect("operation should succeed")
272 })
273 }
274
275 pub fn update_worker_score(&mut self, worker_id: usize, score: f64) -> Result<()> {
277 let worker = self
278 .population
279 .get_mut(worker_id)
280 .ok_or_else(|| SklearsError::InvalidInput(format!("Worker {} not found", worker_id)))?;
281
282 worker.current_score = score;
283 worker.score_history.push(score);
284 worker.step += 1;
285
286 Ok(())
287 }
288
289 pub fn exploit_and_explore(&mut self) -> Result<()> {
291 let population_size = self.population.len();
292 if population_size == 0 {
293 return Err(SklearsError::InvalidOperation(
294 "Empty population".to_string(),
295 ));
296 }
297
298 let scores: Vec<f64> = self.population.iter().map(|w| w.current_score).collect();
300
301 let num_to_replace = (self.config.replacement_fraction * population_size as f64) as usize;
303 let mut worker_scores: Vec<(usize, f64)> =
304 scores.iter().enumerate().map(|(i, &s)| (i, s)).collect();
305 worker_scores.sort_by(|a, b| a.1.partial_cmp(&b.1).expect("operation should succeed"));
306
307 let worst_indices: Vec<usize> = worker_scores
309 .iter()
310 .take(num_to_replace)
311 .map(|(i, _)| *i)
312 .collect();
313 let best_indices: Vec<usize> = worker_scores
314 .iter()
315 .rev()
316 .take(num_to_replace)
317 .map(|(i, _)| *i)
318 .collect();
319
320 for (&worst_idx, &best_idx) in worst_indices.iter().zip(best_indices.iter()) {
322 if worst_idx != best_idx {
323 let best_params = self.population[best_idx].parameters.clone();
325
326 let perturbed_params = self.parameter_space.perturb(
328 &best_params,
329 self.config.perturbation_factor,
330 &mut self.rng,
331 );
332
333 self.population[worst_idx].parameters = perturbed_params;
335 self.population[worst_idx].current_score = f64::NEG_INFINITY;
336 self.population[worst_idx].score_history.clear();
337 self.population[worst_idx].step = 0;
338 self.population[worst_idx].estimator = None;
339 }
340 }
341
342 Ok(())
343 }
344
345 pub fn check_convergence(&self) -> bool {
347 if let Some(patience) = self.config.patience {
348 for worker in &self.population {
349 if worker.score_history.len() >= patience {
350 let recent_scores =
351 &worker.score_history[worker.score_history.len() - patience..];
352 let max_recent = recent_scores
353 .iter()
354 .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
355 let current = worker.current_score;
356
357 if current <= max_recent {
359 return true;
360 }
361 }
362 }
363 }
364 false
365 }
366
367 pub fn get_statistics(&self) -> PBTStatistics {
369 let scores: Vec<f64> = self.population.iter().map(|w| w.current_score).collect();
370
371 let best_score = scores.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
372 let worst_score = scores.iter().fold(f64::INFINITY, |a, &b| a.min(b));
373 let mean_score = scores.iter().sum::<f64>() / scores.len() as f64;
374
375 let variance = scores
376 .iter()
377 .map(|&x| (x - mean_score).powi(2))
378 .sum::<f64>()
379 / scores.len() as f64;
380 let std_dev = variance.sqrt();
381
382 PBTStatistics {
383 generation: self.population.iter().map(|w| w.step).max().unwrap_or(0),
384 population_size: self.population.len(),
385 best_score,
386 worst_score,
387 mean_score,
388 std_dev,
389 }
390 }
391}
392
/// Snapshot of the population's scores, as produced by
/// `PopulationBasedTraining::get_statistics`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug)]
pub struct PBTStatistics {
    /// Largest step counter among the workers.
    pub generation: usize,
    /// Number of workers in the population.
    pub population_size: usize,
    /// Highest current score in the population.
    pub best_score: f64,
    /// Lowest current score in the population.
    pub worst_score: f64,
    /// Arithmetic mean of the current scores.
    pub mean_score: f64,
    /// Population standard deviation of the current scores.
    pub std_dev: f64,
}
404
405#[derive(Debug, Clone)]
407pub struct PBTResult<E> {
408 pub best_params: PBTParameters,
410 pub best_score: f64,
412 pub best_estimator: Option<E>,
414 pub history: Vec<PBTStatistics>,
416 pub final_population: Vec<PBTWorker<E>>,
418}
419
420pub type PBTConfigFn<E> = Box<dyn Fn(E, &PBTParameters) -> Result<E>>;
422
423pub struct PopulationBasedTrainingCV<E, X, Y> {
425 base_estimator: E,
426 config: PBTConfig,
427 parameter_space: PBTParameterSpace,
428 config_fn: PBTConfigFn<E>,
429 cv: Box<dyn CrossValidator>,
430 scoring: Option<Scoring>,
431 _phantom: PhantomData<(X, Y)>,
432}
433
434impl<E, X, Y> PopulationBasedTrainingCV<E, X, Y>
435where
436 E: Clone + Debug,
437{
438 pub fn new(
440 base_estimator: E,
441 config: PBTConfig,
442 parameter_space: PBTParameterSpace,
443 config_fn: PBTConfigFn<E>,
444 ) -> Self {
445 Self {
446 base_estimator,
447 config,
448 parameter_space,
449 config_fn,
450 cv: Box::new(KFold::new(5)),
451 scoring: None,
452 _phantom: PhantomData,
453 }
454 }
455
456 pub fn cv<CV>(mut self, cv: CV) -> Self
458 where
459 CV: CrossValidator + 'static,
460 {
461 self.cv = Box::new(cv);
462 self
463 }
464
465 pub fn scoring(mut self, scoring: Scoring) -> Self {
467 self.scoring = Some(scoring);
468 self
469 }
470}
471
472impl<E> PopulationBasedTrainingCV<E, Array2<f64>, Array1<f64>>
473where
474 E: Clone
475 + Debug
476 + Fit<
477 ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>, f64>,
478 ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>, f64>,
479 > + Score<
480 ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>, f64>,
481 ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>, f64>,
482 >,
483 E::Fitted: Clone
484 + Debug
485 + Predict<
486 ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>, f64>,
487 ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>, f64>,
488 > + Score<
489 ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>, f64>,
490 ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>, f64>,
491 >,
492{
493 pub fn fit(self, x: &Array2<f64>, y: &Array1<f64>) -> Result<PBTResult<E::Fitted>> {
495 let mut pbt: PopulationBasedTraining<E::Fitted, Array2<f64>, Array1<f64>> =
496 PopulationBasedTraining::new(self.config.clone(), self.parameter_space.clone());
497 pbt.initialize_population();
498
499 let mut history = Vec::new();
500 let mut best_global_score = f64::NEG_INFINITY;
501 let mut best_global_params = None;
502 let mut best_global_estimator = None;
503
504 for iteration in 0..self.config.max_iterations {
505 let population_size = pbt.population.len();
507 for worker_id in 0..population_size {
508 let worker_params = pbt.population[worker_id].parameters.clone();
509
510 let configured_estimator =
512 (self.config_fn)(self.base_estimator.clone(), &worker_params)?;
513
514 let cv_splits = self.cv.split(x.nrows(), None);
516 let mut cv_scores = Vec::new();
517
518 for (train_idx, test_idx) in cv_splits {
519 let x_train_view = x.select(scirs2_core::ndarray::Axis(0), &train_idx);
520 let y_train_view = y.select(scirs2_core::ndarray::Axis(0), &train_idx);
521 let x_test_view = x.select(scirs2_core::ndarray::Axis(0), &test_idx);
522 let y_test_view = y.select(scirs2_core::ndarray::Axis(0), &test_idx);
523
524 let x_train = Array2::from_shape_vec(
525 (x_train_view.nrows(), x_train_view.ncols()),
526 x_train_view.iter().copied().collect(),
527 )?;
528 let y_train = Array1::from_vec(y_train_view.iter().copied().collect());
529 let x_test = Array2::from_shape_vec(
530 (x_test_view.nrows(), x_test_view.ncols()),
531 x_test_view.iter().copied().collect(),
532 )?;
533 let y_test = Array1::from_vec(y_test_view.iter().copied().collect());
534
535 let fitted_estimator = configured_estimator.clone().fit(&x_train, &y_train)?;
536 let score = fitted_estimator.score(&x_test, &y_test)?;
537 cv_scores.push(score);
538 }
539
540 let mean_score = cv_scores
541 .iter()
542 .copied()
543 .map(|x| x.to_f64().unwrap_or(0.0))
544 .sum::<f64>()
545 / cv_scores.len() as f64;
546 pbt.update_worker_score(worker_id, mean_score)?;
547
548 if mean_score > best_global_score {
550 best_global_score = mean_score;
551 best_global_params = Some(worker_params);
552
553 let final_estimator = configured_estimator.fit(x, y)?;
555 best_global_estimator = Some(final_estimator);
556 }
557 }
558
559 let stats = pbt.get_statistics();
561 history.push(stats);
562
563 if iteration > 0 && iteration % self.config.perturbation_interval == 0 {
565 pbt.exploit_and_explore()?;
566 }
567
568 if pbt.check_convergence() {
570 break;
571 }
572 }
573
574 Ok(PBTResult {
575 best_params: best_global_params.unwrap_or_else(PBTParameters::new),
576 best_score: best_global_score,
577 best_estimator: best_global_estimator,
578 history,
579 final_population: pbt.population,
580 })
581 }
582}
583
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Sampling yields a value for every declared hyperparameter.
    #[test]
    fn test_pbt_parameter_space() {
        let mut rng = StdRng::seed_from_u64(42);
        let sampled = PBTParameterSpace::new()
            .add_continuous("learning_rate", 0.001, 0.1)
            .add_discrete("n_estimators", vec![10.0, 50.0, 100.0])
            .sample(&mut rng);

        for name in ["learning_rate", "n_estimators"] {
            assert!(sampled.get(name).is_some());
        }
    }

    /// A new worker keeps its id and starts with the sentinel score.
    #[test]
    fn test_pbt_worker() {
        let mut worker_params = PBTParameters::new();
        worker_params.set(String::from("learning_rate"), 0.01);

        let worker = PBTWorker::<i32>::new(0, worker_params);

        assert_eq!(worker.id, 0);
        assert_eq!(worker.current_score, f64::NEG_INFINITY);
    }

    /// The default configuration exposes the documented sizes.
    #[test]
    fn test_pbt_config() {
        let PBTConfig {
            population_size,
            perturbation_interval,
            ..
        } = PBTConfig::default();

        assert_eq!(population_size, 20);
        assert_eq!(perturbation_interval, 10);
    }

    /// Perturbation changes a continuous value but keeps it within bounds.
    #[test]
    fn test_parameter_perturbation() {
        let space = PBTParameterSpace::new().add_continuous("learning_rate", 0.001, 0.1);

        let mut start = PBTParameters::new();
        start.set(String::from("learning_rate"), 0.05);

        let mut rng = StdRng::seed_from_u64(42);
        let perturbed = space.perturb(&start, 0.1, &mut rng);

        let before = start
            .get("learning_rate")
            .expect("operation should succeed");
        let after = perturbed
            .get("learning_rate")
            .expect("operation should succeed");

        assert_ne!(before, after);
        assert!((0.001..=0.1).contains(after));
    }
}
640}