1use crate::random::{
26 arrays::OptimizedArrayRandom, random_he_weights, random_xavier_weights, seeded_rng, thread_rng,
27 ParallelRng, Random, ThreadLocalRngPool,
28};
29use ::ndarray::{Array, Array1, Array2, Array3, Ix2};
30use rand_distr::{Normal, Uniform};
31use std::collections::HashMap;
32
33use crate::random::Rng;
34
35pub fn train_val_test_split<T: Clone>(
37 data: &[T],
38 train_ratio: f64,
39 val_ratio: f64,
40 test_ratio: f64,
41 seed: u64,
42) -> (Vec<T>, Vec<T>, Vec<T>) {
43 assert!(
44 (train_ratio + val_ratio + test_ratio - 1.0).abs() < 1e-6,
45 "Ratios must sum to 1.0"
46 );
47
48 let mut rng = seeded_rng(seed);
49 let mut indices: Vec<usize> = (0..data.len()).collect();
50
51 use rand::seq::SliceRandom;
52 indices.shuffle(&mut rng.rng);
53
54 let train_end = (data.len() as f64 * train_ratio) as usize;
55 let val_end = train_end + (data.len() as f64 * val_ratio) as usize;
56
57 let train_data = indices[..train_end]
58 .iter()
59 .map(|&i| data[i].clone())
60 .collect();
61 let val_data = indices[train_end..val_end]
62 .iter()
63 .map(|&i| data[i].clone())
64 .collect();
65 let test_data = indices[val_end..]
66 .iter()
67 .map(|&i| data[i].clone())
68 .collect();
69
70 (train_data, val_data, test_data)
71}
72
73pub fn train_test_split<T: Clone>(data: &[T], train_ratio: f64, seed: u64) -> (Vec<T>, Vec<T>) {
75 let test_ratio = 1.0 - train_ratio;
76 let (train, _, test) = train_val_test_split(data, train_ratio, 0.0, test_ratio, seed);
77 (train, test)
78}
79
80pub fn stratified_split<T: Clone, K: Clone + Eq + std::hash::Hash>(
82 data: &[(T, K)],
83 train_ratio: f64,
84 seed: u64,
85) -> (Vec<(T, K)>, Vec<(T, K)>) {
86 let mut rng = seeded_rng(seed);
87 use std::collections::HashMap;
88
89 let mut class_groups: HashMap<K, Vec<&(T, K)>> = HashMap::new();
91 for item in data {
92 class_groups.entry(item.1.clone()).or_default().push(item);
93 }
94
95 let mut train_data = Vec::new();
96 let mut test_data = Vec::new();
97
98 for (_, mut group) in class_groups {
100 use rand::seq::SliceRandom;
101 group.shuffle(&mut rng.rng);
102
103 let train_size = (group.len() as f64 * train_ratio) as usize;
104
105 for item in group.iter().take(train_size) {
106 train_data.push((*item).clone());
107 }
108 for item in group.iter().skip(train_size) {
109 test_data.push((*item).clone());
110 }
111 }
112
113 (train_data, test_data)
114}
115
116pub struct WeightInitializer;
118
119impl WeightInitializer {
120 pub fn xavier(fan_in: usize, fan_out: usize, seed: u64) -> Array2<f64> {
122 let mut rng = seeded_rng(seed);
123 random_xavier_weights(fan_in, fan_out, &mut rng)
124 }
125
126 pub fn he(fan_in: usize, fan_out: usize, seed: u64) -> Array2<f64> {
128 let mut rng = seeded_rng(seed);
129 random_he_weights(fan_in, fan_out, &mut rng)
130 }
131
132 pub fn lecun(fan_in: usize, fan_out: usize, seed: u64) -> Array2<f64> {
134 let mut rng = seeded_rng(seed);
135 let std_dev = (1.0 / fan_in as f64).sqrt();
136 Array::random_bulk(
137 Ix2(fan_out, fan_in),
138 Normal::new(0.0, std_dev).expect("Operation failed"),
139 &mut rng,
140 )
141 }
142
143 pub fn uniform(fan_in: usize, fan_out: usize, limit: f64, seed: u64) -> Array2<f64> {
145 let mut rng = seeded_rng(seed);
146 Array::random_bulk(
147 Ix2(fan_out, fan_in),
148 Uniform::new(-limit, limit).expect("Operation failed"),
149 &mut rng,
150 )
151 }
152
153 pub fn zeros(fan_in: usize, fan_out: usize) -> Array2<f64> {
155 Array2::zeros([fan_out, fan_in])
156 }
157
158 pub fn identity(size: usize) -> Array2<f64> {
160 Array2::eye(size)
161 }
162
163 pub fn orthogonal(fan_in: usize, fan_out: usize, seed: u64) -> Array2<f64> {
165 let mut rng = seeded_rng(seed);
166
167 let random_matrix = Array::random_bulk(
169 Ix2(fan_out, fan_in),
170 Normal::new(0.0, 1.0).expect("Operation failed"),
171 &mut rng,
172 );
173
174 let norm = (random_matrix.mapv(|x| x * x).sum() as f64).sqrt();
177 random_matrix / norm
178 }
179}
180
181pub struct CrossValidator {
183 k_folds: usize,
184 seed: u64,
185}
186
187impl CrossValidator {
188 pub fn new(k_folds: usize, seed: u64) -> Self {
190 Self { k_folds, seed }
191 }
192
193 pub fn split<T: Clone>(&self, data: &[T]) -> Vec<(Vec<T>, Vec<T>)> {
195 crate::random::scientific::cross_validation_splits(data, self.k_folds, self.seed)
196 }
197
198 pub fn leave_one_out<T: Clone>(&self, data: &[T]) -> Vec<(Vec<T>, Vec<T>)> {
200 (0..data.len())
201 .map(|i| {
202 let test_item = vec![data[i].clone()];
203 let train_data = data
204 .iter()
205 .enumerate()
206 .filter(|(idx, _)| *idx != i)
207 .map(|(_, item)| item.clone())
208 .collect();
209 (train_data, test_item)
210 })
211 .collect()
212 }
213
214 pub fn stratified_split<T: Clone, K: Clone + Eq + std::hash::Hash>(
216 &self,
217 data: &[(T, K)],
218 ) -> Vec<(Vec<(T, K)>, Vec<(T, K)>)> {
219 let mut rng = seeded_rng(self.seed);
220
221 let mut class_groups: HashMap<K, Vec<&(T, K)>> = HashMap::new();
223 for item in data {
224 class_groups.entry(item.1.clone()).or_default().push(item);
225 }
226
227 let mut folds = vec![Vec::new(); self.k_folds];
228
229 for (_, mut group) in class_groups {
231 use rand::seq::SliceRandom;
232 group.shuffle(&mut rng.rng);
233
234 for (i, item) in group.iter().enumerate() {
235 let fold_idx = i % self.k_folds;
236 folds[fold_idx].push((*item).clone());
237 }
238 }
239
240 (0..self.k_folds)
242 .map(|test_fold| {
243 let test_data = folds[test_fold].clone();
244 let train_data = folds
245 .iter()
246 .enumerate()
247 .filter(|(i, _)| *i != test_fold)
248 .flat_map(|(_, fold)| fold.iter().cloned())
249 .collect();
250 (train_data, test_data)
251 })
252 .collect()
253 }
254}
255
256pub struct DataAugmentor {
258 seed: u64,
259}
260
261impl DataAugmentor {
262 pub fn new(seed: u64) -> Self {
264 Self { seed }
265 }
266
267 pub fn add_noise(&self, data: &Array1<f64>, noise_std: f64) -> Array1<f64> {
269 let mut rng = seeded_rng(self.seed);
270 let noise_dist = Normal::new(0.0, noise_std).expect("Operation failed");
271
272 data + &Array::random_bulk(data.raw_dim(), noise_dist, &mut rng)
273 }
274
275 pub fn random_dropout(&self, data: &Array1<f64>, dropout_rate: f64) -> Array1<f64> {
277 let mut rng = seeded_rng(self.seed);
278 let keep_prob = 1.0 - dropout_rate;
279
280 data.mapv(|x| {
281 if rng.sample(Uniform::new(0.0, 1.0).expect("Operation failed")) < keep_prob {
282 x / keep_prob } else {
284 0.0
285 }
286 })
287 }
288
289 pub fn random_scale(&self, data: &Array1<f64>, scale_range: (f64, f64)) -> Array1<f64> {
291 let mut rng = seeded_rng(self.seed);
292 let scale_factor =
293 rng.sample(Uniform::new(scale_range.0, scale_range.1).expect("Operation failed"));
294 data * scale_factor
295 }
296
297 pub fn random_rotation_2d(&self, data: &Array1<f64>, max_angle: f64) -> Array1<f64> {
299 if data.len() != 2 {
300 return data.clone();
301 }
302
303 let mut rng = seeded_rng(self.seed);
304 let angle = rng.sample(Uniform::new(-max_angle, max_angle).expect("Operation failed"));
305
306 let cos_angle = angle.cos();
307 let sin_angle = angle.sin();
308
309 let x = data[0];
310 let y = data[1];
311
312 Array1::from(vec![
313 x * cos_angle - y * sin_angle,
314 x * sin_angle + y * cos_angle,
315 ])
316 }
317}
318
319pub struct BatchGenerator<T> {
321 data: Vec<T>,
322 batch_size: usize,
323 shuffle: bool,
324 seed: u64,
325 current_epoch: usize,
326}
327
328impl<T: Clone> BatchGenerator<T> {
329 pub fn new(data: Vec<T>, batch_size: usize, shuffle: bool, seed: u64) -> Self {
331 Self {
332 data,
333 batch_size,
334 shuffle,
335 seed,
336 current_epoch: 0,
337 }
338 }
339
340 pub fn epoch(&mut self) -> Vec<Vec<T>> {
342 let mut epoch_data = self.data.clone();
343
344 if self.shuffle {
345 let mut rng = seeded_rng(self.seed + self.current_epoch as u64);
346 use rand::seq::SliceRandom;
347 epoch_data.shuffle(&mut rng.rng);
348 }
349
350 let batches = epoch_data
351 .chunks(self.batch_size)
352 .map(|chunk| chunk.to_vec())
353 .collect();
354
355 self.current_epoch += 1;
356 batches
357 }
358
359 pub fn batches_per_epoch(&self) -> usize {
361 (self.data.len() + self.batch_size - 1) / self.batch_size
362 }
363
364 pub fn reset(&mut self) {
366 self.current_epoch = 0;
367 }
368}
369
370pub mod hyperopt {
372 use super::*;
373
374 pub struct RandomSearch {
376 seed: u64,
377 param_ranges: HashMap<String, (f64, f64)>,
378 }
379
380 impl RandomSearch {
381 pub fn new(seed: u64) -> Self {
383 Self {
384 seed,
385 param_ranges: HashMap::new(),
386 }
387 }
388
389 pub fn add_param_range(mut self, name: String, min: f64, max: f64) -> Self {
391 self.param_ranges.insert(name, (min, max));
392 self
393 }
394
395 pub fn sample_params(&self, n_trials: usize) -> Vec<HashMap<String, f64>> {
397 let mut rng = seeded_rng(self.seed);
398
399 (0..n_trials)
400 .map(|_| {
401 self.param_ranges
402 .iter()
403 .map(|(name, (min, max))| {
404 let value =
405 rng.sample(Uniform::new(*min, *max).expect("Operation failed"));
406 (name.clone(), value)
407 })
408 .collect()
409 })
410 .collect()
411 }
412 }
413
414 pub struct GridSearch {
416 param_grids: HashMap<String, Vec<f64>>,
417 }
418
419 impl GridSearch {
420 pub fn new() -> Self {
422 Self {
423 param_grids: HashMap::new(),
424 }
425 }
426
427 pub fn add_param_grid(mut self, name: String, values: Vec<f64>) -> Self {
429 self.param_grids.insert(name, values);
430 self
431 }
432
433 pub fn all_combinations(&self) -> Vec<HashMap<String, f64>> {
435 let param_names: Vec<String> = self.param_grids.keys().cloned().collect();
436 let param_values: Vec<Vec<f64>> = param_names
437 .iter()
438 .map(|name| self.param_grids[name].clone())
439 .collect();
440
441 let combinations =
442 crate::random::scientific::ExperimentalDesign::factorial_design(¶m_values);
443
444 combinations
445 .into_iter()
446 .map(|combo| {
447 param_names
448 .iter()
449 .zip(combo.iter())
450 .map(|(name, &value)| (name.clone(), value))
451 .collect()
452 })
453 .collect()
454 }
455 }
456}
457
458pub mod ensemble {
460 use super::*;
461
462 pub fn bootstrap_samples<T: Clone>(
464 data: &[T],
465 n_estimators: usize,
466 sample_ratio: f64,
467 seed: u64,
468 ) -> Vec<Vec<T>> {
469 let pool = ThreadLocalRngPool::new(seed);
470 let sample_size = (data.len() as f64 * sample_ratio) as usize;
471
472 (0..n_estimators)
473 .map(|i| {
474 pool.with_rng(|rng| {
475 (0..sample_size)
476 .map(|_| {
477 let idx = rng.random_range(0..data.len());
478 data[idx].clone()
479 })
480 .collect()
481 })
482 })
483 .collect()
484 }
485
486 pub fn random_subspace_features(
488 n_features: usize,
489 max_features: usize,
490 n_estimators: usize,
491 seed: u64,
492 ) -> Vec<Vec<usize>> {
493 let mut rng = seeded_rng(seed);
494
495 (0..n_estimators)
496 .map(|_| {
497 let mut features: Vec<usize> = (0..n_features).collect();
498 use rand::seq::SliceRandom;
499 features.shuffle(&mut rng.rng);
500 features.into_iter().take(max_features).collect()
501 })
502 .collect()
503 }
504}
505
506pub mod active_learning {
508 use super::*;
509
510 pub struct UncertaintySampler {
512 seed: u64,
513 }
514
515 impl UncertaintySampler {
516 pub fn new(seed: u64) -> Self {
517 Self { seed }
518 }
519
520 pub fn random_sample<T: Clone>(&self, candidates: &[T], n_samples: usize) -> Vec<T> {
522 let mut rng = seeded_rng(self.seed);
523 use rand::seq::SliceRandom;
524
525 let mut indices: Vec<usize> = (0..candidates.len()).collect();
527 indices.shuffle(&mut rng.rng);
528 indices
529 .into_iter()
530 .take(n_samples.min(candidates.len()))
531 .map(|i| candidates[i].clone())
532 .collect()
533 }
534
535 pub fn entropy_sampling<T: Clone>(
537 &self,
538 candidates: &[(T, f64)], n_samples: usize,
540 ) -> Vec<T> {
541 let mut scored_candidates = candidates.to_vec();
542 scored_candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("Operation failed"));
543
544 scored_candidates
545 .into_iter()
546 .take(n_samples)
547 .map(|(data, _)| data)
548 .collect()
549 }
550 }
551}