use crate::{UtilsError, UtilsResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Dimension, OwnedRepr};
use sklears_core::types::{Float, Int};

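/// Check that all provided 1-D arrays have the same length.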
pub fn check_consistent_length<T>(arrays: &[&Array1<T>]) -> UtilsResult<()> {
    if arrays.is_empty() {
        return Ok(());
    }

    let first_length = arrays[0].len();
    for array in arrays.iter().skip(1) {
        if array.len() != first_length {
            return Err(UtilsError::ShapeMismatch {
                expected: vec![first_length],
                actual: vec![array.len()],
            });
        }
    }
    Ok(())
}

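/// Check that the 2-D array `x` has as many rows as `y` has elements.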
pub fn check_consistent_length_xy<T, U>(x: &Array2<T>, y: &Array1<U>) -> UtilsResult<()> {
    if x.nrows() != y.len() {
        return Err(UtilsError::ShapeMismatch {
            expected: vec![x.nrows()],
            actual: vec![y.len()],
        });
    }
    Ok(())
}

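/// Check that a 2-D array is non-empty and has at least one column.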
pub fn check_array_2d<T>(array: &Array2<T>) -> UtilsResult<()> {
    check_non_empty(array)?;

    if array.ncols() == 0 {
        return Err(UtilsError::InvalidParameter(
            "Array must have at least one column".to_string(),
        ));
    }

    Ok(())
}

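/// Validate a feature matrix `x` and integer target `y`: sample counts must match
/// and neither may be empty.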
pub fn check_x_y(x: &Array2<Float>, y: &Array1<Int>) -> UtilsResult<()> {
    if x.nrows() != y.len() {
        return Err(UtilsError::ShapeMismatch {
            expected: vec![x.nrows()],
            actual: vec![y.len()],
        });
    }

    if x.is_empty() || y.is_empty() {
        return Err(UtilsError::EmptyInput);
    }

    Ok(())
}

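/// Validate a feature matrix `x` and continuous target `y`: sample counts must match
/// and neither may be empty.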
pub fn check_x_y_regression(x: &Array2<Float>, y: &Array1<Float>) -> UtilsResult<()> {
    if x.nrows() != y.len() {
        return Err(UtilsError::ShapeMismatch {
            expected: vec![x.nrows()],
            actual: vec![y.len()],
        });
    }

    if x.is_empty() || y.is_empty() {
        return Err(UtilsError::EmptyInput);
    }

    Ok(())
}

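/// Check that an array of any dimensionality contains at least one element.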
pub fn check_non_empty<T, D: Dimension>(array: &ArrayBase<OwnedRepr<T>, D>) -> UtilsResult<()> {
    if array.is_empty() {
        return Err(UtilsError::EmptyInput);
    }
    Ok(())
}

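/// Check that a floating-point parameter is strictly positive.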
pub fn check_positive(value: Float, name: &str) -> UtilsResult<()> {
    if value <= 0.0 {
        return Err(UtilsError::InvalidParameter(format!(
            "{name} must be positive, got {value}"
        )));
    }
    Ok(())
}

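/// Check that a floating-point parameter is non-negative.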
pub fn check_non_negative(value: Float, name: &str) -> UtilsResult<()> {
    if value < 0.0 {
        return Err(UtilsError::InvalidParameter(format!(
            "{name} must be non-negative, got {value}"
        )));
    }
    Ok(())
}

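/// Check that a floating-point parameter lies in the inclusive range `[min, max]`.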
pub fn check_range(value: Float, min: Float, max: Float, name: &str) -> UtilsResult<()> {
    if value < min || value > max {
        return Err(UtilsError::InvalidParameter(format!(
            "{name} must be in range [{min}, {max}], got {value}"
        )));
    }
    Ok(())
}

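/// Check that an integer parameter is strictly positive (non-zero).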
pub fn check_positive_int(value: usize, name: &str) -> UtilsResult<()> {
    if value == 0 {
        return Err(UtilsError::InvalidParameter(format!(
            "{name} must be positive, got {value}"
        )));
    }
    Ok(())
}

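/// Check that the number of samples meets the required minimum.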
pub fn check_min_samples(n_samples: usize, min_samples: usize) -> UtilsResult<()> {
    if n_samples < min_samples {
        return Err(UtilsError::InsufficientData {
            min: min_samples,
            actual: n_samples,
        });
    }
    Ok(())
}

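/// Check that a 2-D array contains only finite values (no NaN or infinity).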
pub fn check_finite(array: &Array2<Float>) -> UtilsResult<()> {
    for &value in array.iter() {
        if !value.is_finite() {
            return Err(UtilsError::InvalidParameter(
                "Array contains non-finite values (NaN or infinity)".to_string(),
            ));
        }
    }
    Ok(())
}

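/// Check that a 1-D array contains only finite values (no NaN or infinity).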
pub fn check_finite_1d(array: &Array1<Float>) -> UtilsResult<()> {
    for &value in array.iter() {
        if !value.is_finite() {
            return Err(UtilsError::InvalidParameter(
                "Array contains non-finite values (NaN or infinity)".to_string(),
            ));
        }
    }
    Ok(())
}

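/// Validate a feature matrix: non-empty, all values finite, and at least one feature column.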
pub fn validate_features(x: &Array2<Float>) -> UtilsResult<()> {
    check_non_empty(x)?;
    check_finite(x)?;

    if x.ncols() == 0 {
        return Err(UtilsError::InvalidParameter(
            "Feature matrix must have at least one feature".to_string(),
        ));
    }

    Ok(())
}

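/// Validate an integer classification target: it must be non-empty.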
pub fn validate_target(y: &Array1<Int>) -> UtilsResult<()> {
    check_non_empty(y)?;
    Ok(())
}

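/// Validate a continuous regression target: it must be non-empty and finite.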
pub fn validate_target_regression(y: &Array1<Float>) -> UtilsResult<()> {
    check_non_empty(y)?;
    check_finite_1d(y)?;
    Ok(())
}

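/// Validate class labels and return the sorted, deduplicated set of classes.
/// All labels must be non-negative.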
pub fn validate_class_labels(y: &Array1<Int>) -> UtilsResult<Vec<Int>> {
    validate_target(y)?;

    let mut classes: Vec<Int> = y.iter().copied().collect();
    classes.sort_unstable();
    classes.dedup();

    for &class in &classes {
        if class < 0 {
            return Err(UtilsError::InvalidParameter(format!(
                "Class labels must be non-negative, found {class}"
            )));
        }
    }

    Ok(classes)
}

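/// Check that at least `min_classes` distinct classes are present.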
pub fn check_min_classes(classes: &[Int], min_classes: usize) -> UtilsResult<()> {
    if classes.len() < min_classes {
        return Err(UtilsError::InvalidParameter(format!(
            "Need at least {min_classes} classes, found {}",
            classes.len()
        )));
    }
    Ok(())
}

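/// Validate sample weights: the length must match `n_samples`, every weight must be
/// finite and non-negative, and the total weight must be strictly positive.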
pub fn validate_sample_weights(sample_weight: &Array1<Float>, n_samples: usize) -> UtilsResult<()> {
    if sample_weight.len() != n_samples {
        return Err(UtilsError::ShapeMismatch {
            expected: vec![n_samples],
            actual: vec![sample_weight.len()],
        });
    }

    check_finite_1d(sample_weight)?;

    for &weight in sample_weight.iter() {
        if weight < 0.0 {
            return Err(UtilsError::InvalidParameter(
                "Sample weights must be non-negative".to_string(),
            ));
        }
    }

    if sample_weight.sum() <= 0.0 {
        return Err(UtilsError::InvalidParameter(
            "Sum of sample weights must be positive".to_string(),
        ));
    }

    Ok(())
}

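/// Check that `a` and `b` have compatible shapes for the matrix product `a * b`,
/// i.e. the number of columns of `a` equals the number of rows of `b`.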
pub fn check_matmul_shapes(a: &Array2<Float>, b: &Array2<Float>) -> UtilsResult<()> {
    if a.ncols() != b.nrows() {
        return Err(UtilsError::ShapeMismatch {
            expected: vec![a.nrows(), a.ncols(), b.ncols()],
            actual: vec![a.nrows(), a.ncols(), b.nrows(), b.ncols()],
        });
    }
    Ok(())
}

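/// Validate a learning rate: it must be strictly positive and at most 1.0.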
pub fn validate_learning_rate(learning_rate: Float) -> UtilsResult<()> {
    check_positive(learning_rate, "learning_rate")?;
    check_range(learning_rate, 0.0, 1.0, "learning_rate")?;
    Ok(())
}

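/// Validate a regularization strength `alpha`: it must be non-negative.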
pub fn validate_regularization(alpha: Float) -> UtilsResult<()> {
    check_non_negative(alpha, "alpha")?;
    Ok(())
}

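/// Validate a convergence tolerance: it must be strictly positive.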
pub fn validate_tolerance(tol: Float) -> UtilsResult<()> {
    check_positive(tol, "tol")?;
    Ok(())
}

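/// Validate a maximum iteration count: it must be at least 1.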
pub fn validate_max_iter(max_iter: usize) -> UtilsResult<()> {
    check_positive_int(max_iter, "max_iter")?;
    Ok(())
}

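/// Validate cross-validation fold assignments: one fold index per sample, every index
/// in `[0, n_folds)`, and every fold assigned at least one sample.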
pub fn validate_cv_folds(folds: &Array1<i32>, n_samples: usize, n_folds: usize) -> UtilsResult<()> {
    if folds.len() != n_samples {
        return Err(UtilsError::ShapeMismatch {
            expected: vec![n_samples],
            actual: vec![folds.len()],
        });
    }

    // Every fold index must lie in [0, n_folds).
    for &fold_idx in folds.iter() {
        if fold_idx < 0 || fold_idx >= n_folds as i32 {
            return Err(UtilsError::InvalidParameter(format!(
                "Fold index {fold_idx} is out of range [0, {n_folds})"
            )));
        }
    }

    // Every fold must receive at least one sample.
    let mut fold_counts = vec![0; n_folds];
    for &fold_idx in folds.iter() {
        fold_counts[fold_idx as usize] += 1;
    }

    for (i, &count) in fold_counts.iter().enumerate() {
        if count == 0 {
            return Err(UtilsError::InvalidParameter(format!(
                "Fold {i} has no samples assigned"
            )));
        }
    }

    Ok(())
}

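/// Validate feature importance values: the length must match `n_features`, every value
/// must be finite and non-negative, and at least one value must be non-zero.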
pub fn validate_feature_importance(
    importance: &Array1<Float>,
    n_features: usize,
) -> UtilsResult<()> {
    if importance.len() != n_features {
        return Err(UtilsError::ShapeMismatch {
            expected: vec![n_features],
            actual: vec![importance.len()],
        });
    }

    // Each importance value must be finite and non-negative.
    for (i, &value) in importance.iter().enumerate() {
        if value < 0.0 {
            return Err(UtilsError::InvalidParameter(format!(
                "Feature importance at index {i} is negative: {value}"
            )));
        }
        if !value.is_finite() {
            return Err(UtilsError::InvalidParameter(format!(
                "Feature importance at index {i} is not finite: {value}"
            )));
        }
    }

    // At least one feature must carry non-zero importance.
    if importance.iter().all(|&x| x == 0.0) {
        return Err(UtilsError::InvalidParameter(
            "All feature importance values are zero".to_string(),
        ));
    }

    Ok(())
}

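/// Validate classification predictions: the length must match `n_samples` and every
/// prediction must be one of the provided valid class labels.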
pub fn validate_classification_predictions(
    predictions: &Array1<i32>,
    n_samples: usize,
    valid_classes: &[i32],
) -> UtilsResult<()> {
    if predictions.len() != n_samples {
        return Err(UtilsError::ShapeMismatch {
            expected: vec![n_samples],
            actual: vec![predictions.len()],
        });
    }

    for (i, &pred) in predictions.iter().enumerate() {
        if !valid_classes.contains(&pred) {
            return Err(UtilsError::InvalidParameter(format!(
                "Prediction at index {i} ({pred}) is not a valid class label"
            )));
        }
    }

    Ok(())
}

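/// Validate regression predictions: the length must match `n_samples` and all values
/// must be finite.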
pub fn validate_regression_predictions(
    predictions: &Array1<Float>,
    n_samples: usize,
) -> UtilsResult<()> {
    if predictions.len() != n_samples {
        return Err(UtilsError::ShapeMismatch {
            expected: vec![n_samples],
            actual: vec![predictions.len()],
        });
    }

    for (i, &value) in predictions.iter().enumerate() {
        if !value.is_finite() {
            return Err(UtilsError::InvalidParameter(format!(
                "Prediction at index {i} is not finite: {value}"
            )));
        }
    }

    Ok(())
}

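/// Validate a sparse matrix in CSR-style layout: `data` and `indices` must have equal
/// length, `indptr` must contain `n_rows + 1` non-decreasing entries starting at 0 and
/// ending at `data.len()`, column indices must be within `n_cols`, and all stored
/// values must be finite.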
pub fn validate_sparse_matrix(
    data: &Array1<Float>,
    indices: &Array1<usize>,
    indptr: &Array1<usize>,
    n_rows: usize,
    n_cols: usize,
) -> UtilsResult<()> {
    // data and indices must describe the same number of stored entries.
    if data.len() != indices.len() {
        return Err(UtilsError::ShapeMismatch {
            expected: vec![data.len()],
            actual: vec![indices.len()],
        });
    }

    if indptr.len() != n_rows + 1 {
        return Err(UtilsError::ShapeMismatch {
            expected: vec![n_rows + 1],
            actual: vec![indptr.len()],
        });
    }

    // Row pointers must start at 0, be non-decreasing, and end at the number of entries.
    if indptr[0] != 0 {
        return Err(UtilsError::InvalidParameter(
            "indptr must start with 0".to_string(),
        ));
    }

    for i in 1..indptr.len() {
        if indptr[i] < indptr[i - 1] {
            return Err(UtilsError::InvalidParameter(
                "indptr must be non-decreasing".to_string(),
            ));
        }
    }

    if indptr[indptr.len() - 1] != data.len() {
        return Err(UtilsError::InvalidParameter(
            "Last indptr value must equal data length".to_string(),
        ));
    }

    // Column indices must be in bounds.
    for &col_idx in indices.iter() {
        if col_idx >= n_cols {
            return Err(UtilsError::InvalidParameter(format!(
                "Column index {col_idx} is out of bounds for matrix with {n_cols} columns"
            )));
        }
    }

    // All stored values must be finite.
    for (i, &value) in data.iter().enumerate() {
        if !value.is_finite() {
            return Err(UtilsError::InvalidParameter(format!(
                "Data value at index {i} is not finite: {value}"
            )));
        }
    }

    Ok(())
}

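/// Validate time-series data: one timestamp per row, at least `min_samples` rows,
/// strictly increasing finite timestamps, and finite data values.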
pub fn validate_time_series(
    data: &Array2<Float>,
    timestamps: &Array1<Float>,
    min_samples: usize,
) -> UtilsResult<()> {
    // One timestamp per row of data.
    if data.nrows() != timestamps.len() {
        return Err(UtilsError::ShapeMismatch {
            expected: vec![data.nrows()],
            actual: vec![timestamps.len()],
        });
    }

    if data.nrows() < min_samples {
        return Err(UtilsError::InsufficientData {
            min: min_samples,
            actual: data.nrows(),
        });
    }

    // Timestamps must be strictly increasing.
    for i in 1..timestamps.len() {
        if timestamps[i] <= timestamps[i - 1] {
            return Err(UtilsError::InvalidParameter(format!(
                "Timestamps must be strictly increasing. Found {} <= {} at index {}",
                timestamps[i],
                timestamps[i - 1],
                i
            )));
        }
    }

    // Timestamps must be finite.
    for (i, &ts) in timestamps.iter().enumerate() {
        if !ts.is_finite() {
            return Err(UtilsError::InvalidParameter(format!(
                "Timestamp at index {i} is not finite: {ts}"
            )));
        }
    }

    // All data values must be finite.
    for ((i, j), &value) in data.indexed_iter() {
        if !value.is_finite() {
            return Err(UtilsError::InvalidParameter(format!(
                "Data value at index ({i}, {j}) is not finite: {value}"
            )));
        }
    }

    Ok(())
}

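/// Validate a discrete probability distribution: every entry must be finite and
/// non-negative, and the entries must sum to 1.0 within `tolerance`.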
pub fn validate_probability_distribution(
    probabilities: &Array1<Float>,
    tolerance: Float,
) -> UtilsResult<()> {
    // Each probability must be finite and non-negative.
    for (i, &prob) in probabilities.iter().enumerate() {
        if prob < 0.0 {
            return Err(UtilsError::InvalidParameter(format!(
                "Probability at index {i} is negative: {prob}"
            )));
        }
        if !prob.is_finite() {
            return Err(UtilsError::InvalidParameter(format!(
                "Probability at index {i} is not finite: {prob}"
            )));
        }
    }

    // The probabilities must sum to 1.0 within the given tolerance.
    let sum: Float = probabilities.sum();
    if (sum - 1.0).abs() > tolerance {
        return Err(UtilsError::InvalidParameter(format!(
            "Probabilities must sum to 1.0 (±{tolerance}), got {sum}"
        )));
    }

    Ok(())
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    #[test]
    fn test_check_consistent_length() {
        let a = array![1, 2, 3];
        let b = array![4, 5, 6];
        let c = array![7, 8];

        assert!(check_consistent_length(&[&a, &b]).is_ok());
        assert!(check_consistent_length(&[&a, &c]).is_err());
    }

    #[test]
    fn test_check_x_y() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let y_good = array![0, 1, 0];
        let y_bad = array![0, 1];

        assert!(check_x_y(&x, &y_good).is_ok());
        assert!(check_x_y(&x, &y_bad).is_err());
    }

    #[test]
    fn test_check_positive() {
        assert!(check_positive(1.0, "test").is_ok());
        assert!(check_positive(0.0, "test").is_err());
        assert!(check_positive(-1.0, "test").is_err());
    }

    #[test]
    fn test_check_range() {
        assert!(check_range(0.5, 0.0, 1.0, "test").is_ok());
        assert!(check_range(-0.1, 0.0, 1.0, "test").is_err());
        assert!(check_range(1.1, 0.0, 1.0, "test").is_err());
    }

    #[test]
    fn test_validate_class_labels() {
        let y_good = array![0, 1, 2, 1, 0];
        let y_bad = array![0, 1, -1, 1, 0];

        let classes = validate_class_labels(&y_good).unwrap();
        assert_eq!(classes, vec![0, 1, 2]);

        assert!(validate_class_labels(&y_bad).is_err());
    }

    #[test]
    fn test_check_finite() {
        let good = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let bad = Array2::from_shape_vec((2, 2), vec![1.0, Float::NAN, 3.0, 4.0]).unwrap();

        assert!(check_finite(&good).is_ok());
        assert!(check_finite(&bad).is_err());
    }

    #[test]
    fn test_validate_sample_weights() {
        let good_weights = array![1.0, 2.0, 1.5];
        let bad_weights = array![1.0, -1.0, 1.5];
        let zero_sum_weights = array![0.0, 0.0, 0.0];

        assert!(validate_sample_weights(&good_weights, 3).is_ok());
        assert!(validate_sample_weights(&bad_weights, 3).is_err());
        assert!(validate_sample_weights(&zero_sum_weights, 3).is_err());
    }

    #[test]
    fn test_validate_cv_folds() {
        let good_folds = array![0, 1, 2, 0, 1, 2];
        let bad_folds_range = array![0, 1, 3, 0, 1, 2]; // fold index 3 is out of range for 3 folds
        let bad_folds_missing = array![0, 0, 1, 1, 1, 1]; // fold 2 has no samples

        assert!(validate_cv_folds(&good_folds, 6, 3).is_ok());
        assert!(validate_cv_folds(&bad_folds_range, 6, 3).is_err());
        assert!(validate_cv_folds(&bad_folds_missing, 6, 3).is_err());
    }

    #[test]
    fn test_validate_feature_importance() {
        let good_importance = array![0.5, 0.3, 0.2];
        let bad_importance_negative = array![0.5, -0.1, 0.2];
        let bad_importance_all_zero = array![0.0, 0.0, 0.0];

        assert!(validate_feature_importance(&good_importance, 3).is_ok());
        assert!(validate_feature_importance(&bad_importance_negative, 3).is_err());
        assert!(validate_feature_importance(&bad_importance_all_zero, 3).is_err());
    }

    #[test]
    fn test_validate_classification_predictions() {
        let good_predictions = array![0, 1, 2, 1, 0];
        let bad_predictions = array![0, 1, 3, 1, 0]; // 3 is not a valid class label
        let valid_classes = vec![0, 1, 2];

        assert!(validate_classification_predictions(&good_predictions, 5, &valid_classes).is_ok());
        assert!(validate_classification_predictions(&bad_predictions, 5, &valid_classes).is_err());
    }

    #[test]
    fn test_validate_regression_predictions() {
        let good_predictions = array![1.5, 2.3, -0.5, 10.0];
        let bad_predictions = array![1.5, Float::NAN, -0.5, 10.0];

        assert!(validate_regression_predictions(&good_predictions, 4).is_ok());
        assert!(validate_regression_predictions(&bad_predictions, 4).is_err());
    }

    #[test]
    fn test_validate_sparse_matrix() {
        // 3x3 CSR matrix with six stored values.
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let indices = array![0, 2, 2, 0, 1, 2];
        let indptr = array![0, 2, 3, 6];

        assert!(validate_sparse_matrix(&data, &indices, &indptr, 3, 3).is_ok());

        // indptr must start at 0.
        let bad_indptr = array![1, 2, 3, 6];
        assert!(validate_sparse_matrix(&data, &indices, &bad_indptr, 3, 3).is_err());
    }

    #[test]
    fn test_validate_time_series() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let good_timestamps = array![1.0, 2.0, 3.0];
        let bad_timestamps = array![1.0, 1.5, 1.2]; // not strictly increasing

        assert!(validate_time_series(&data, &good_timestamps, 2).is_ok());
        assert!(validate_time_series(&data, &bad_timestamps, 2).is_err());
    }

    #[test]
    fn test_validate_probability_distribution() {
        let good_probs = array![0.3, 0.5, 0.2];
        let bad_probs_negative = array![0.3, -0.1, 0.8];
        let bad_probs_sum = array![0.3, 0.5, 0.3]; // sums to 1.1

        assert!(validate_probability_distribution(&good_probs, 1e-6).is_ok());
        assert!(validate_probability_distribution(&bad_probs_negative, 1e-6).is_err());
        assert!(validate_probability_distribution(&bad_probs_sum, 1e-6).is_err());
    }
}