1use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::random::essentials::Uniform;
8use scirs2_core::random::{seeded_rng, Distribution};
9use sklears_core::prelude::SklearsError;
10use std::collections::HashMap;
11
/// K-fold cross-validation splitter.
///
/// Partitions `0..n_samples` into `n_splits` consecutive folds; each fold is
/// used once as the test set while the remaining folds form the train set.
#[derive(Debug, Clone)]
pub struct KFold {
    /// Number of folds. `split` returns an error when this exceeds the sample count.
    pub n_splits: usize,
    /// Whether to shuffle the indices (Fisher–Yates) before folding.
    pub shuffle: bool,
    /// Seed for the shuffle; `None` falls back to a time-based seed (second granularity).
    pub random_state: Option<u64>,
}
22
23impl KFold {
24 pub fn new(n_splits: usize, shuffle: bool, random_state: Option<u64>) -> Self {
26 Self {
27 n_splits,
28 shuffle,
29 random_state,
30 }
31 }
32
33 pub fn split(&self, n_samples: usize) -> Result<Vec<(Vec<usize>, Vec<usize>)>, SklearsError> {
35 if n_samples < self.n_splits {
36 return Err(SklearsError::InvalidInput(format!(
37 "Cannot split {} samples into {} folds",
38 n_samples, self.n_splits
39 )));
40 }
41
42 let mut indices: Vec<usize> = (0..n_samples).collect();
43
44 if self.shuffle {
45 use std::time::{SystemTime, UNIX_EPOCH};
46
47 let seed = self.random_state.unwrap_or_else(|| {
48 SystemTime::now()
49 .duration_since(UNIX_EPOCH)
50 .expect("operation should succeed")
51 .as_secs()
52 });
53
54 let mut rng = seeded_rng(seed);
55
56 for i in (1..indices.len()).rev() {
58 let uniform = Uniform::new(0, i + 1).expect("operation should succeed");
59 let j = uniform.sample(&mut rng);
60 indices.swap(i, j);
61 }
62 }
63
64 let fold_size = n_samples / self.n_splits;
65 let mut splits = Vec::new();
66
67 for fold_idx in 0..self.n_splits {
68 let test_start = fold_idx * fold_size;
69 let test_end = if fold_idx == self.n_splits - 1 {
70 n_samples
71 } else {
72 (fold_idx + 1) * fold_size
73 };
74
75 let test_indices: Vec<usize> = indices[test_start..test_end].to_vec();
76 let train_indices: Vec<usize> = indices[..test_start]
77 .iter()
78 .chain(&indices[test_end..])
79 .copied()
80 .collect();
81
82 splits.push((train_indices, test_indices));
83 }
84
85 Ok(splits)
86 }
87}
88
/// Stratified K-fold cross-validation splitter.
///
/// Like `KFold`, but each fold approximately preserves the per-class
/// proportions of the label vector passed to `split`.
#[derive(Debug, Clone)]
pub struct StratifiedKFold {
    /// Number of folds. `split` returns an error when this exceeds the sample count.
    pub n_splits: usize,
    /// Whether to shuffle indices within each class before folding.
    pub shuffle: bool,
    /// Seed for the shuffle; `None` falls back to a time-based seed (second granularity).
    pub random_state: Option<u64>,
}
99
100impl StratifiedKFold {
101 pub fn new(n_splits: usize, shuffle: bool, random_state: Option<u64>) -> Self {
103 Self {
104 n_splits,
105 shuffle,
106 random_state,
107 }
108 }
109
110 pub fn split(&self, y: &Array1<i32>) -> Result<Vec<(Vec<usize>, Vec<usize>)>, SklearsError> {
112 let n_samples = y.len();
113
114 if n_samples < self.n_splits {
115 return Err(SklearsError::InvalidInput(format!(
116 "Cannot split {} samples into {} folds",
117 n_samples, self.n_splits
118 )));
119 }
120
121 let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
123 for (idx, &label) in y.iter().enumerate() {
124 class_indices.entry(label).or_default().push(idx);
125 }
126
127 if self.shuffle {
129 use std::time::{SystemTime, UNIX_EPOCH};
130
131 let seed = self.random_state.unwrap_or_else(|| {
132 SystemTime::now()
133 .duration_since(UNIX_EPOCH)
134 .expect("operation should succeed")
135 .as_secs()
136 });
137
138 let mut rng = seeded_rng(seed);
139
140 for indices in class_indices.values_mut() {
141 for i in (1..indices.len()).rev() {
142 let uniform = Uniform::new(0, i + 1).expect("operation should succeed");
143 let j = uniform.sample(&mut rng);
144 indices.swap(i, j);
145 }
146 }
147 }
148
149 let mut splits: Vec<(Vec<usize>, Vec<usize>)> = vec![];
151
152 for fold_idx in 0..self.n_splits {
153 let mut train_indices = Vec::new();
154 let mut test_indices = Vec::new();
155
156 for indices in class_indices.values() {
157 let fold_size = indices.len() / self.n_splits;
158 let test_start = fold_idx * fold_size;
159 let test_end = if fold_idx == self.n_splits - 1 {
160 indices.len()
161 } else {
162 (fold_idx + 1) * fold_size
163 };
164
165 test_indices.extend(&indices[test_start..test_end]);
166 train_indices.extend(&indices[..test_start]);
167 train_indices.extend(&indices[test_end..]);
168 }
169
170 splits.push((train_indices, test_indices));
171 }
172
173 Ok(splits)
174 }
175}
176
/// Aggregated result of a cross-validation run.
#[derive(Debug, Clone)]
pub struct CVScore {
    /// Mean of the per-fold scores.
    pub mean: f64,
    /// Standard deviation of the per-fold scores.
    pub std: f64,
    /// Raw score from each fold, in fold order.
    pub scores: Vec<f64>,
}
187
/// Exhaustive grid of hyperparameter candidate values for grid search.
#[derive(Debug, Clone)]
pub struct ParameterGrid {
    // Maps parameter name to the candidate values to try for it.
    parameters: HashMap<String, Vec<f64>>,
}

impl ParameterGrid {
    /// Creates an empty grid.
    pub fn new() -> Self {
        Self {
            parameters: HashMap::new(),
        }
    }

    /// Adds candidate `values` for parameter `name` (builder style).
    pub fn add_parameter(mut self, name: String, values: Vec<f64>) -> Self {
        self.parameters.insert(name, values);
        self
    }

    /// Returns the Cartesian product of all parameter values.
    ///
    /// Parameters are expanded in sorted-name order so the combination order
    /// is deterministic across runs; iterating the `HashMap` directly would
    /// randomize it per process.
    ///
    /// An empty grid yields a single empty combination; a parameter with an
    /// empty value list yields no combinations at all.
    pub fn combinations(&self) -> Vec<HashMap<String, f64>> {
        // Sorted names give a reproducible expansion order.
        let mut names: Vec<&String> = self.parameters.keys().collect();
        names.sort();

        let mut result = vec![HashMap::new()];

        for name in names {
            let values = &self.parameters[name];
            let mut expanded = Vec::with_capacity(result.len() * values.len());

            // Cross every partial combination with every value of this parameter.
            for combination in &result {
                for &value in values {
                    let mut next = combination.clone();
                    next.insert(name.clone(), value);
                    expanded.push(next);
                }
            }

            result = expanded;
        }

        result
    }

    /// Number of combinations a full grid expansion would produce.
    ///
    /// NOTE: returns 0 for an empty grid even though `combinations` returns
    /// one (empty) combination in that case; the existing tests assert this
    /// asymmetry, so it is preserved here.
    pub fn n_combinations(&self) -> usize {
        if self.parameters.is_empty() {
            return 0;
        }

        self.parameters.values().map(|v| v.len()).product()
    }
}
242
243impl Default for ParameterGrid {
244 fn default() -> Self {
245 Self::new()
246 }
247}
248
/// Uniform sampling ranges for randomized hyperparameter search.
#[derive(Debug, Clone)]
pub struct ParameterDistribution {
    // Maps parameter name to its inclusive (min, max) sampling range.
    parameters: HashMap<String, (f64, f64)>, }
255impl ParameterDistribution {
256 pub fn new() -> Self {
258 Self {
259 parameters: HashMap::new(),
260 }
261 }
262
263 pub fn add_parameter(mut self, name: String, min: f64, max: f64) -> Self {
265 self.parameters.insert(name, (min, max));
266 self
267 }
268
269 pub fn sample(&self, n_iter: usize, random_state: Option<u64>) -> Vec<HashMap<String, f64>> {
271 use std::time::{SystemTime, UNIX_EPOCH};
272
273 let seed = random_state.unwrap_or_else(|| {
274 SystemTime::now()
275 .duration_since(UNIX_EPOCH)
276 .expect("operation should succeed")
277 .as_secs()
278 });
279
280 let mut rng = seeded_rng(seed);
281
282 (0..n_iter)
283 .map(|_| {
284 self.parameters
285 .iter()
286 .map(|(name, &(min, max))| {
287 let uniform =
288 Uniform::new_inclusive(min, max).expect("operation should succeed");
289 (name.clone(), uniform.sample(&mut rng))
290 })
291 .collect()
292 })
293 .collect()
294 }
295}
296
297impl Default for ParameterDistribution {
298 fn default() -> Self {
299 Self::new()
300 }
301}
302
/// Scores how well a preprocessing transform preserves properties of the data.
pub trait PreprocessingMetric {
    /// Compares `x_transformed` against `x_original` and returns a score;
    /// for the implementations in this file, higher means better preservation.
    fn evaluate(&self, x_original: &Array2<f64>, x_transformed: &Array2<f64>) -> f64;
}
308
/// Metric: average per-column ratio of variance retained by a transform.
pub struct VariancePreservationMetric;
311
312impl PreprocessingMetric for VariancePreservationMetric {
313 fn evaluate(&self, x_original: &Array2<f64>, x_transformed: &Array2<f64>) -> f64 {
314 let mut total_variance_ratio = 0.0;
315
316 for j in 0..x_original.ncols() {
317 let original_col = x_original.column(j);
318 let transformed_col = x_transformed.column(j);
319
320 let original_var = Self::compute_variance(original_col);
321 let transformed_var = Self::compute_variance(transformed_col);
322
323 if original_var > 1e-10 {
324 total_variance_ratio += transformed_var / original_var;
325 }
326 }
327
328 total_variance_ratio / x_original.ncols() as f64
329 }
330}
331
332impl VariancePreservationMetric {
333 fn compute_variance<'a, I>(values: I) -> f64
334 where
335 I: IntoIterator<Item = &'a f64>,
336 {
337 let vals: Vec<f64> = values
338 .into_iter()
339 .copied()
340 .filter(|v| !v.is_nan())
341 .collect();
342
343 if vals.is_empty() {
344 return 0.0;
345 }
346
347 let mean = vals.iter().sum::<f64>() / vals.len() as f64;
348 vals.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / vals.len() as f64
349 }
350}
351
/// Metric: average absolute per-column correlation between original and transformed data.
pub struct InformationPreservationMetric;
354
355impl PreprocessingMetric for InformationPreservationMetric {
356 fn evaluate(&self, x_original: &Array2<f64>, x_transformed: &Array2<f64>) -> f64 {
357 let mut total_correlation = 0.0;
359 let mut count = 0;
360
361 for j in 0..x_original.ncols().min(x_transformed.ncols()) {
362 let corr = Self::compute_correlation(x_original, x_transformed, j);
363 if !corr.is_nan() {
364 total_correlation += corr.abs();
365 count += 1;
366 }
367 }
368
369 if count > 0 {
370 total_correlation / count as f64
371 } else {
372 0.0
373 }
374 }
375}
376
377impl InformationPreservationMetric {
378 fn compute_correlation(x1: &Array2<f64>, x2: &Array2<f64>, col_idx: usize) -> f64 {
379 let col1 = x1.column(col_idx);
380 let col2 = x2.column(col_idx);
381
382 let pairs: Vec<(f64, f64)> = col1
383 .iter()
384 .zip(col2.iter())
385 .filter(|(a, b)| !a.is_nan() && !b.is_nan())
386 .map(|(&a, &b)| (a, b))
387 .collect();
388
389 if pairs.len() < 2 {
390 return 0.0;
391 }
392
393 let mean1 = pairs.iter().map(|(a, _)| a).sum::<f64>() / pairs.len() as f64;
394 let mean2 = pairs.iter().map(|(_, b)| b).sum::<f64>() / pairs.len() as f64;
395
396 let mut cov = 0.0;
397 let mut var1 = 0.0;
398 let mut var2 = 0.0;
399
400 for (a, b) in &pairs {
401 let d1 = a - mean1;
402 let d2 = b - mean2;
403 cov += d1 * d2;
404 var1 += d1 * d1;
405 var2 += d2 * d2;
406 }
407
408 if var1 < 1e-10 || var2 < 1e-10 {
409 return 0.0;
410 }
411
412 cov / (var1 * var2).sqrt()
413 }
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::essentials::Normal;
    use scirs2_core::random::{seeded_rng, Distribution};

    /// Builds an `nrows` x `ncols` matrix of seeded standard-normal draws
    /// (deterministic for a given `seed`).
    fn generate_test_data(nrows: usize, ncols: usize, seed: u64) -> Array2<f64> {
        let mut rng = seeded_rng(seed);
        let normal = Normal::new(0.0, 1.0).expect("operation should succeed");

        let data: Vec<f64> = (0..nrows * ncols)
            .map(|_| normal.sample(&mut rng))
            .collect();

        Array2::from_shape_vec((nrows, ncols), data).expect("shape and data length should match")
    }

    // Every fold's train and test sets must be non-empty and partition all
    // 100 samples.
    #[test]
    fn test_kfold_split() {
        let kfold = KFold::new(5, false, Some(42));
        let splits = kfold.split(100).expect("operation should succeed");

        assert_eq!(splits.len(), 5);

        for (train, test) in &splits {
            assert!(train.len() > 0);
            assert!(test.len() > 0);
            assert_eq!(train.len() + test.len(), 100);
        }
    }

    // A shuffled split (fixed seed) should order indices differently from an
    // unshuffled one.
    #[test]
    fn test_kfold_shuffle() {
        let kfold1 = KFold::new(3, true, Some(42));
        let splits1 = kfold1.split(30).expect("operation should succeed");

        let kfold2 = KFold::new(3, false, None);
        let splits2 = kfold2.split(30).expect("operation should succeed");

        // Compare the first fold's train indices between the two splitters.
        let different = splits1[0].0 != splits2[0].0;
        assert!(different);
    }

    // With 5 samples per class and 3 folds, each fold's test set should
    // contain at least one sample of every class.
    #[test]
    fn test_stratified_kfold() {
        let y = Array1::from_vec(vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2]);

        let stratified = StratifiedKFold::new(3, false, Some(42));
        let splits = stratified.split(&y).expect("operation should succeed");

        assert_eq!(splits.len(), 3);

        for (train_indices, test_indices) in &splits {
            let _train_classes: Vec<i32> = train_indices.iter().map(|&i| y[i]).collect();
            let test_classes: Vec<i32> = test_indices.iter().map(|&i| y[i]).collect();

            // Count how many test samples fall in each class.
            let test_0 = test_classes.iter().filter(|&&c| c == 0).count();
            let test_1 = test_classes.iter().filter(|&&c| c == 1).count();
            let test_2 = test_classes.iter().filter(|&&c| c == 2).count();

            assert!(test_0 > 0);
            assert!(test_1 > 0);
            assert!(test_2 > 0);
        }
    }

    // 3 alpha values x 2 beta values = 6 combinations.
    #[test]
    fn test_parameter_grid() {
        let grid = ParameterGrid::new()
            .add_parameter("alpha".to_string(), vec![0.1, 1.0, 10.0])
            .add_parameter("beta".to_string(), vec![0.5, 1.5]);

        let combinations = grid.combinations();

        assert_eq!(combinations.len(), 6); assert_eq!(grid.n_combinations(), 6);

        let has_alpha_0_1 = combinations.iter().any(|c| c.get("alpha") == Some(&0.1));
        assert!(has_alpha_0_1);
    }

    // Every sampled value must fall inside its registered inclusive range.
    #[test]
    fn test_parameter_distribution() {
        let dist = ParameterDistribution::new()
            .add_parameter("alpha".to_string(), 0.0, 1.0)
            .add_parameter("beta".to_string(), 0.0, 10.0);

        let samples = dist.sample(10, Some(42));

        assert_eq!(samples.len(), 10);

        for sample in &samples {
            let alpha = sample.get("alpha").expect("sampling should succeed");
            let beta = sample.get("beta").expect("sampling should succeed");

            assert!(*alpha >= 0.0 && *alpha <= 1.0);
            assert!(*beta >= 0.0 && *beta <= 10.0);
        }
    }

    // An identity transform should preserve (nearly) all variance.
    #[test]
    fn test_variance_preservation_metric() {
        let x_original = generate_test_data(100, 5, 42);
        let x_transformed = x_original.clone();

        let metric = VariancePreservationMetric;
        let score = metric.evaluate(&x_original, &x_transformed);

        assert!((score - 1.0).abs() < 0.1);
    }

    // An identity transform should yield near-perfect per-column correlation.
    #[test]
    fn test_information_preservation_metric() {
        let x_original = generate_test_data(100, 5, 123);
        let x_transformed = x_original.clone();

        let metric = InformationPreservationMetric;
        let score = metric.evaluate(&x_original, &x_transformed);

        assert!(score > 0.9);
    }

    // Requesting more folds than samples must fail with an error.
    #[test]
    fn test_kfold_edge_case_small_dataset() {
        let kfold = KFold::new(5, false, Some(42));
        let result = kfold.split(3);

        assert!(result.is_err());
    }

    // Documents the intentional asymmetry: an empty grid has one (empty)
    // combination, but n_combinations() reports 0.
    #[test]
    fn test_empty_parameter_grid() {
        let grid = ParameterGrid::new();
        let combinations = grid.combinations();

        assert_eq!(combinations.len(), 1);
        assert!(combinations[0].is_empty());
        assert_eq!(grid.n_combinations(), 0);
    }
}