1use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::random::essentials::Uniform;
8use scirs2_core::random::{seeded_rng, Distribution};
9use sklears_core::prelude::SklearsError;
10use std::collections::HashMap;
11
/// K-fold cross-validation splitter.
///
/// Partitions `0..n_samples` into `n_splits` consecutive folds (optionally
/// shuffled first); each fold serves once as the test set while the
/// remaining indices form the training set.
#[derive(Debug, Clone)]
pub struct KFold {
    /// Number of folds; must not exceed the number of samples.
    pub n_splits: usize,
    /// Whether to shuffle indices before splitting.
    pub shuffle: bool,
    /// Seed for the shuffle; when `None`, a time-based seed is used.
    pub random_state: Option<u64>,
}
22
23impl KFold {
24 pub fn new(n_splits: usize, shuffle: bool, random_state: Option<u64>) -> Self {
26 Self {
27 n_splits,
28 shuffle,
29 random_state,
30 }
31 }
32
33 pub fn split(&self, n_samples: usize) -> Result<Vec<(Vec<usize>, Vec<usize>)>, SklearsError> {
35 if n_samples < self.n_splits {
36 return Err(SklearsError::InvalidInput(format!(
37 "Cannot split {} samples into {} folds",
38 n_samples, self.n_splits
39 )));
40 }
41
42 let mut indices: Vec<usize> = (0..n_samples).collect();
43
44 if self.shuffle {
45 use std::time::{SystemTime, UNIX_EPOCH};
46
47 let seed = self.random_state.unwrap_or_else(|| {
48 SystemTime::now()
49 .duration_since(UNIX_EPOCH)
50 .unwrap()
51 .as_secs()
52 });
53
54 let mut rng = seeded_rng(seed);
55
56 for i in (1..indices.len()).rev() {
58 let uniform = Uniform::new(0, i + 1).unwrap();
59 let j = uniform.sample(&mut rng);
60 indices.swap(i, j);
61 }
62 }
63
64 let fold_size = n_samples / self.n_splits;
65 let mut splits = Vec::new();
66
67 for fold_idx in 0..self.n_splits {
68 let test_start = fold_idx * fold_size;
69 let test_end = if fold_idx == self.n_splits - 1 {
70 n_samples
71 } else {
72 (fold_idx + 1) * fold_size
73 };
74
75 let test_indices: Vec<usize> = indices[test_start..test_end].to_vec();
76 let train_indices: Vec<usize> = indices[..test_start]
77 .iter()
78 .chain(&indices[test_end..])
79 .copied()
80 .collect();
81
82 splits.push((train_indices, test_indices));
83 }
84
85 Ok(splits)
86 }
87}
88
/// Stratified K-fold cross-validation splitter.
///
/// Like `KFold`, but each class's sample indices are split independently so
/// folds approximately preserve the per-class proportions of `y`.
#[derive(Debug, Clone)]
pub struct StratifiedKFold {
    /// Number of folds; must not exceed the number of samples.
    pub n_splits: usize,
    /// Whether to shuffle each class's indices before splitting.
    pub shuffle: bool,
    /// Seed for the shuffle; when `None`, a time-based seed is used.
    pub random_state: Option<u64>,
}
99
100impl StratifiedKFold {
101 pub fn new(n_splits: usize, shuffle: bool, random_state: Option<u64>) -> Self {
103 Self {
104 n_splits,
105 shuffle,
106 random_state,
107 }
108 }
109
110 pub fn split(&self, y: &Array1<i32>) -> Result<Vec<(Vec<usize>, Vec<usize>)>, SklearsError> {
112 let n_samples = y.len();
113
114 if n_samples < self.n_splits {
115 return Err(SklearsError::InvalidInput(format!(
116 "Cannot split {} samples into {} folds",
117 n_samples, self.n_splits
118 )));
119 }
120
121 let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
123 for (idx, &label) in y.iter().enumerate() {
124 class_indices.entry(label).or_default().push(idx);
125 }
126
127 if self.shuffle {
129 use std::time::{SystemTime, UNIX_EPOCH};
130
131 let seed = self.random_state.unwrap_or_else(|| {
132 SystemTime::now()
133 .duration_since(UNIX_EPOCH)
134 .unwrap()
135 .as_secs()
136 });
137
138 let mut rng = seeded_rng(seed);
139
140 for indices in class_indices.values_mut() {
141 for i in (1..indices.len()).rev() {
142 let uniform = Uniform::new(0, i + 1).unwrap();
143 let j = uniform.sample(&mut rng);
144 indices.swap(i, j);
145 }
146 }
147 }
148
149 let mut splits: Vec<(Vec<usize>, Vec<usize>)> = vec![];
151
152 for fold_idx in 0..self.n_splits {
153 let mut train_indices = Vec::new();
154 let mut test_indices = Vec::new();
155
156 for indices in class_indices.values() {
157 let fold_size = indices.len() / self.n_splits;
158 let test_start = fold_idx * fold_size;
159 let test_end = if fold_idx == self.n_splits - 1 {
160 indices.len()
161 } else {
162 (fold_idx + 1) * fold_size
163 };
164
165 test_indices.extend(&indices[test_start..test_end]);
166 train_indices.extend(&indices[..test_start]);
167 train_indices.extend(&indices[test_end..]);
168 }
169
170 splits.push((train_indices, test_indices));
171 }
172
173 Ok(splits)
174 }
175}
176
/// Summary statistics for a set of cross-validation scores.
#[derive(Debug, Clone)]
pub struct CVScore {
    /// Mean of the per-fold scores.
    pub mean: f64,
    /// Standard deviation of the per-fold scores.
    pub std: f64,
    /// The raw per-fold scores.
    pub scores: Vec<f64>,
}
187
/// Exhaustive grid of hyper-parameter values for grid search.
#[derive(Debug, Clone)]
pub struct ParameterGrid {
    // Maps parameter name -> list of candidate values.
    parameters: HashMap<String, Vec<f64>>,
}

impl ParameterGrid {
    /// Creates an empty grid with no parameters registered.
    pub fn new() -> Self {
        Self {
            parameters: HashMap::new(),
        }
    }

    /// Registers a parameter and its candidate values (builder style).
    pub fn add_parameter(mut self, name: String, values: Vec<f64>) -> Self {
        self.parameters.insert(name, values);
        self
    }

    /// Returns the Cartesian product of all registered parameter values.
    ///
    /// An empty grid yields a single empty combination. The order of the
    /// combinations follows `HashMap` iteration order and is unspecified.
    pub fn combinations(&self) -> Vec<HashMap<String, f64>> {
        // Seed with one empty combination; each parameter multiplies the set.
        let mut combos: Vec<HashMap<String, f64>> = vec![HashMap::new()];

        for (name, values) in &self.parameters {
            combos = combos
                .iter()
                .flat_map(|base| {
                    values.iter().map(move |&value| {
                        let mut extended = base.clone();
                        extended.insert(name.clone(), value);
                        extended
                    })
                })
                .collect();
        }

        combos
    }

    /// Product of the value counts across all parameters.
    ///
    /// NOTE(review): an empty grid reports 0 here while `combinations()`
    /// returns one (empty) combination; existing tests pin this asymmetry,
    /// so it is preserved rather than changed.
    pub fn n_combinations(&self) -> usize {
        if self.parameters.is_empty() {
            0
        } else {
            self.parameters.values().map(Vec::len).product()
        }
    }
}
242
243impl Default for ParameterGrid {
244 fn default() -> Self {
245 Self::new()
246 }
247}
248
/// Uniform sampling ranges for randomized hyper-parameter search.
#[derive(Debug, Clone)]
pub struct ParameterDistribution {
    // Maps parameter name -> (lower bound, upper bound), both inclusive.
    parameters: HashMap<String, (f64, f64)>,
}
254
255impl ParameterDistribution {
256 pub fn new() -> Self {
258 Self {
259 parameters: HashMap::new(),
260 }
261 }
262
263 pub fn add_parameter(mut self, name: String, min: f64, max: f64) -> Self {
265 self.parameters.insert(name, (min, max));
266 self
267 }
268
269 pub fn sample(&self, n_iter: usize, random_state: Option<u64>) -> Vec<HashMap<String, f64>> {
271 use std::time::{SystemTime, UNIX_EPOCH};
272
273 let seed = random_state.unwrap_or_else(|| {
274 SystemTime::now()
275 .duration_since(UNIX_EPOCH)
276 .unwrap()
277 .as_secs()
278 });
279
280 let mut rng = seeded_rng(seed);
281
282 (0..n_iter)
283 .map(|_| {
284 self.parameters
285 .iter()
286 .map(|(name, &(min, max))| {
287 let uniform = Uniform::new_inclusive(min, max).unwrap();
288 (name.clone(), uniform.sample(&mut rng))
289 })
290 .collect()
291 })
292 .collect()
293 }
294}
295
296impl Default for ParameterDistribution {
297 fn default() -> Self {
298 Self::new()
299 }
300}
301
/// Scores how well a preprocessing transform preserves properties of the data.
pub trait PreprocessingMetric {
    /// Compares `x_transformed` against `x_original` and returns a score;
    /// by the implementations in this file, higher means better preservation.
    fn evaluate(&self, x_original: &Array2<f64>, x_transformed: &Array2<f64>) -> f64;
}
307
/// Metric scoring how much per-column variance survives a transform.
pub struct VariancePreservationMetric;
310
311impl PreprocessingMetric for VariancePreservationMetric {
312 fn evaluate(&self, x_original: &Array2<f64>, x_transformed: &Array2<f64>) -> f64 {
313 let mut total_variance_ratio = 0.0;
314
315 for j in 0..x_original.ncols() {
316 let original_col = x_original.column(j);
317 let transformed_col = x_transformed.column(j);
318
319 let original_var = Self::compute_variance(original_col);
320 let transformed_var = Self::compute_variance(transformed_col);
321
322 if original_var > 1e-10 {
323 total_variance_ratio += transformed_var / original_var;
324 }
325 }
326
327 total_variance_ratio / x_original.ncols() as f64
328 }
329}
330
331impl VariancePreservationMetric {
332 fn compute_variance<'a, I>(values: I) -> f64
333 where
334 I: IntoIterator<Item = &'a f64>,
335 {
336 let vals: Vec<f64> = values
337 .into_iter()
338 .copied()
339 .filter(|v| !v.is_nan())
340 .collect();
341
342 if vals.is_empty() {
343 return 0.0;
344 }
345
346 let mean = vals.iter().sum::<f64>() / vals.len() as f64;
347 vals.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / vals.len() as f64
348 }
349}
350
/// Metric scoring column-wise correlation between original and transformed data.
pub struct InformationPreservationMetric;
353
354impl PreprocessingMetric for InformationPreservationMetric {
355 fn evaluate(&self, x_original: &Array2<f64>, x_transformed: &Array2<f64>) -> f64 {
356 let mut total_correlation = 0.0;
358 let mut count = 0;
359
360 for j in 0..x_original.ncols().min(x_transformed.ncols()) {
361 let corr = Self::compute_correlation(x_original, x_transformed, j);
362 if !corr.is_nan() {
363 total_correlation += corr.abs();
364 count += 1;
365 }
366 }
367
368 if count > 0 {
369 total_correlation / count as f64
370 } else {
371 0.0
372 }
373 }
374}
375
376impl InformationPreservationMetric {
377 fn compute_correlation(x1: &Array2<f64>, x2: &Array2<f64>, col_idx: usize) -> f64 {
378 let col1 = x1.column(col_idx);
379 let col2 = x2.column(col_idx);
380
381 let pairs: Vec<(f64, f64)> = col1
382 .iter()
383 .zip(col2.iter())
384 .filter(|(a, b)| !a.is_nan() && !b.is_nan())
385 .map(|(&a, &b)| (a, b))
386 .collect();
387
388 if pairs.len() < 2 {
389 return 0.0;
390 }
391
392 let mean1 = pairs.iter().map(|(a, _)| a).sum::<f64>() / pairs.len() as f64;
393 let mean2 = pairs.iter().map(|(_, b)| b).sum::<f64>() / pairs.len() as f64;
394
395 let mut cov = 0.0;
396 let mut var1 = 0.0;
397 let mut var2 = 0.0;
398
399 for (a, b) in &pairs {
400 let d1 = a - mean1;
401 let d2 = b - mean2;
402 cov += d1 * d2;
403 var1 += d1 * d1;
404 var2 += d2 * d2;
405 }
406
407 if var1 < 1e-10 || var2 < 1e-10 {
408 return 0.0;
409 }
410
411 cov / (var1 * var2).sqrt()
412 }
413}
414
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::essentials::Normal;
    use scirs2_core::random::{seeded_rng, Distribution};

    /// Deterministic `nrows x ncols` matrix of standard-normal draws.
    fn generate_test_data(nrows: usize, ncols: usize, seed: u64) -> Array2<f64> {
        let mut rng = seeded_rng(seed);
        let normal = Normal::new(0.0, 1.0).unwrap();

        let values: Vec<f64> = (0..nrows * ncols)
            .map(|_| normal.sample(&mut rng))
            .collect();

        Array2::from_shape_vec((nrows, ncols), values).unwrap()
    }

    #[test]
    fn test_kfold_split() {
        let splits = KFold::new(5, false, Some(42)).split(100).unwrap();

        assert_eq!(splits.len(), 5);

        // Each fold must partition all 100 indices into non-empty sets.
        for (train, test) in &splits {
            assert!(!train.is_empty());
            assert!(!test.is_empty());
            assert_eq!(train.len() + test.len(), 100);
        }
    }

    #[test]
    fn test_kfold_shuffle() {
        let shuffled = KFold::new(3, true, Some(42)).split(30).unwrap();
        let ordered = KFold::new(3, false, None).split(30).unwrap();

        // A seeded shuffle should rearrange the first training fold relative
        // to the deterministic, unshuffled splitter.
        assert!(shuffled[0].0 != ordered[0].0);
    }

    #[test]
    fn test_stratified_kfold() {
        let y = Array1::from_vec(vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2]);
        let splits = StratifiedKFold::new(3, false, Some(42)).split(&y).unwrap();

        assert_eq!(splits.len(), 3);

        for (train_indices, test_indices) in &splits {
            let _train_classes: Vec<i32> = train_indices.iter().map(|&i| y[i]).collect();
            let test_classes: Vec<i32> = test_indices.iter().map(|&i| y[i]).collect();

            // Every class must be represented in every test fold.
            for class in [0, 1, 2] {
                assert!(test_classes.iter().filter(|&&c| c == class).count() > 0);
            }
        }
    }

    #[test]
    fn test_parameter_grid() {
        let grid = ParameterGrid::new()
            .add_parameter("alpha".to_string(), vec![0.1, 1.0, 10.0])
            .add_parameter("beta".to_string(), vec![0.5, 1.5]);

        let combinations = grid.combinations();

        // 3 alpha values x 2 beta values.
        assert_eq!(combinations.len(), 6);
        assert_eq!(grid.n_combinations(), 6);
        assert!(combinations.iter().any(|c| c.get("alpha") == Some(&0.1)));
    }

    #[test]
    fn test_parameter_distribution() {
        let dist = ParameterDistribution::new()
            .add_parameter("alpha".to_string(), 0.0, 1.0)
            .add_parameter("beta".to_string(), 0.0, 10.0);

        let samples = dist.sample(10, Some(42));

        assert_eq!(samples.len(), 10);

        // Every draw must fall inside its declared [min, max] range.
        for sample in &samples {
            let alpha = sample.get("alpha").unwrap();
            let beta = sample.get("beta").unwrap();

            assert!((0.0..=1.0).contains(alpha));
            assert!((0.0..=10.0).contains(beta));
        }
    }

    #[test]
    fn test_variance_preservation_metric() {
        let data = generate_test_data(100, 5, 42);
        let identity = data.clone();

        // The identity transform should preserve essentially all variance.
        let score = VariancePreservationMetric.evaluate(&data, &identity);
        assert!((score - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_information_preservation_metric() {
        let data = generate_test_data(100, 5, 123);
        let identity = data.clone();

        // Identical data should be almost perfectly correlated column-wise.
        let score = InformationPreservationMetric.evaluate(&data, &identity);
        assert!(score > 0.9);
    }

    #[test]
    fn test_kfold_edge_case_small_dataset() {
        // 3 samples cannot be split into 5 folds.
        assert!(KFold::new(5, false, Some(42)).split(3).is_err());
    }

    #[test]
    fn test_empty_parameter_grid() {
        let grid = ParameterGrid::new();
        let combinations = grid.combinations();

        // An empty grid yields exactly one empty combination, yet reports
        // zero via n_combinations() — this pins the current contract.
        assert_eq!(combinations.len(), 1);
        assert!(combinations[0].is_empty());
        assert_eq!(grid.n_combinations(), 0);
    }
}