1use ndarray::{ArrayBase, Dimension, ScalarOperand};
7use num_traits::{Float, One, Zero};
8
9use crate::error::{CoreError, CoreResult, ErrorContext, ErrorLocation};
10
11pub fn check_in_bounds<T, S>(value: T, min: T, max: T, name: S) -> CoreResult<T>
29where
30 T: PartialOrd + std::fmt::Display + Copy,
31 S: Into<String>,
32{
33 if value < min || value > max {
34 return Err(CoreError::ValueError(
35 ErrorContext::new(format!(
36 "{} must be between {min} and {max}, got {value}",
37 name.into()
38 ))
39 .with_location(ErrorLocation::new(file!(), line!())),
40 ));
41 }
42 Ok(value)
43}
44
45pub fn check_positive<T, S>(value: T, name: S) -> CoreResult<T>
61where
62 T: PartialOrd + std::fmt::Display + Copy + Zero,
63 S: Into<String>,
64{
65 if value <= T::zero() {
66 return Err(CoreError::ValueError(
67 ErrorContext::new({
68 let name_str = name.into();
69 format!("{name_str} must be positive, got {value}")
70 })
71 .with_location(ErrorLocation::new(file!(), line!())),
72 ));
73 }
74 Ok(value)
75}
76
77pub fn check_non_negative<T, S>(value: T, name: S) -> CoreResult<T>
93where
94 T: PartialOrd + std::fmt::Display + Copy + Zero,
95 S: Into<String>,
96{
97 if value < T::zero() {
98 return Err(CoreError::ValueError(
99 ErrorContext::new({
100 let name_str = name.into();
101 format!("{name_str} must be non-negative, got {value}")
102 })
103 .with_location(ErrorLocation::new(file!(), line!())),
104 ));
105 }
106 Ok(value)
107}
108
109pub fn check_finite<T, S>(value: T, name: S) -> CoreResult<T>
125where
126 T: Float + std::fmt::Display + Copy,
127 S: Into<String>,
128{
129 if !value.is_finite() {
130 return Err(CoreError::ValueError(
131 ErrorContext::new({
132 let name_str = name.into();
133 format!("{name_str} must be finite, got {value}")
134 })
135 .with_location(ErrorLocation::new(file!(), line!())),
136 ));
137 }
138 Ok(value)
139}
140
141pub fn checkarray_finite<S, A, D>(array: &ArrayBase<S, D>, name: A) -> CoreResult<()>
157where
158 S: ndarray::Data,
159 D: Dimension,
160 S::Elem: Float + std::fmt::Display,
161 A: Into<String>,
162{
163 let name = name.into();
164 for (idx, &value) in array.indexed_iter() {
165 if !value.is_finite() {
166 return Err(CoreError::ValueError(
167 ErrorContext::new(format!(
168 "{name} must contain only finite values, got {value} at {idx:?}"
169 ))
170 .with_location(ErrorLocation::new(file!(), line!())),
171 ));
172 }
173 }
174 Ok(())
175}
176
177pub fn checkshape<S, D, A>(
194 array: &ArrayBase<S, D>,
195 expectedshape: &[usize],
196 name: A,
197) -> CoreResult<()>
198where
199 S: ndarray::Data,
200 D: Dimension,
201 A: Into<String>,
202{
203 let actualshape = array.shape();
204 if actualshape != expectedshape {
205 return Err(CoreError::ShapeError(
206 ErrorContext::new(format!(
207 "{} has incorrect shape: expected {expectedshape:?}, got {actualshape:?}",
208 name.into()
209 ))
210 .with_location(ErrorLocation::new(file!(), line!())),
211 ));
212 }
213 Ok(())
214}
215
216pub fn check_1d<S, D, A>(array: &ArrayBase<S, D>, name: A) -> CoreResult<()>
232where
233 S: ndarray::Data,
234 D: Dimension,
235 A: Into<String>,
236{
237 if array.ndim() != 1 {
238 return Err(CoreError::ShapeError(
239 ErrorContext::new({
240 let name_str = name.into();
241 let ndim = array.ndim();
242 format!("{name_str} must be 1D, got {ndim}D")
243 })
244 .with_location(ErrorLocation::new(file!(), line!())),
245 ));
246 }
247 Ok(())
248}
249
250pub fn check_2d<S, D, A>(array: &ArrayBase<S, D>, name: A) -> CoreResult<()>
266where
267 S: ndarray::Data,
268 D: Dimension,
269 A: Into<String>,
270{
271 if array.ndim() != 2 {
272 return Err(CoreError::ShapeError(
273 ErrorContext::new({
274 let name_str = name.into();
275 let ndim = array.ndim();
276 format!("{name_str} must be 2D, got {ndim}D")
277 })
278 .with_location(ErrorLocation::new(file!(), line!())),
279 ));
280 }
281 Ok(())
282}
283
284pub fn check_sameshape<S1, S2, D1, D2, A, B>(
302 a: &ArrayBase<S1, D1>,
303 a_name: A,
304 b: &ArrayBase<S2, D2>,
305 b_name: B,
306) -> CoreResult<()>
307where
308 S1: ndarray::Data,
309 S2: ndarray::Data,
310 D1: Dimension,
311 D2: Dimension,
312 A: Into<String>,
313 B: Into<String>,
314{
315 let ashape = a.shape();
316 let bshape = b.shape();
317 if ashape != bshape {
318 return Err(CoreError::ShapeError(
319 ErrorContext::new(format!(
320 "{} and {} must have the same shape, got {:?} and {:?}",
321 a_name.into(),
322 b_name.into(),
323 ashape,
324 bshape
325 ))
326 .with_location(ErrorLocation::new(file!(), line!())),
327 ));
328 }
329 Ok(())
330}
331
332pub fn check_square<S, D, A>(matrix: &ArrayBase<S, D>, name: A) -> CoreResult<()>
348where
349 S: ndarray::Data,
350 D: Dimension,
351 A: Into<String> + std::string::ToString,
352{
353 check_2d(matrix, name.to_string())?;
354 let shape = matrix.shape();
355 if shape[0] != shape[1] {
356 return Err(CoreError::ShapeError(
357 ErrorContext::new(format!(
358 "{} must be square, got shape {:?}",
359 name.into(),
360 shape
361 ))
362 .with_location(ErrorLocation::new(file!(), line!())),
363 ));
364 }
365 Ok(())
366}
367
368pub fn check_probability<T, S>(p: T, name: S) -> CoreResult<T>
384where
385 T: Float + std::fmt::Display + Copy,
386 S: Into<String>,
387{
388 if p < T::zero() || p > T::one() {
389 return Err(CoreError::ValueError(
390 ErrorContext::new(format!(
391 "{} must be between 0 and 1, got {}",
392 name.into(),
393 p
394 ))
395 .with_location(ErrorLocation::new(file!(), line!())),
396 ));
397 }
398 Ok(p)
399}
400
401pub fn check_probabilities<S, D, A>(probs: &ArrayBase<S, D>, name: A) -> CoreResult<()>
417where
418 S: ndarray::Data,
419 D: Dimension,
420 S::Elem: Float + std::fmt::Display,
421 A: Into<String>,
422{
423 let name = name.into();
424 for (idx, &p) in probs.indexed_iter() {
425 if p < S::Elem::zero() || p > S::Elem::one() {
426 return Err(CoreError::ValueError(
427 ErrorContext::new(format!(
428 "{name} must contain only values between 0 and 1, got {p} at {idx:?}"
429 ))
430 .with_location(ErrorLocation::new(file!(), line!())),
431 ));
432 }
433 }
434 Ok(())
435}
436
437pub fn check_probabilities_sum_to_one<S, D, A>(
450 probs: &ArrayBase<S, D>,
451 name: A,
452 tol: Option<S::Elem>,
453) -> CoreResult<()>
454where
455 S: ndarray::Data,
456 D: Dimension,
457 S::Elem: Float + std::fmt::Display + ScalarOperand,
458 A: Into<String> + std::string::ToString,
459{
460 let tol = tol.unwrap_or_else(|| {
461 let eps: f64 = 1e-10;
462 num_traits::cast(eps).unwrap_or_else(|| {
463 S::Elem::epsilon()
465 })
466 });
467
468 check_probabilities(probs, name.to_string())?;
469
470 let sum = probs.sum();
471 let one = S::Elem::one();
472
473 if (sum - one).abs() > tol {
474 return Err(CoreError::ValueError(
475 ErrorContext::new({
476 let name_str = name.into();
477 format!("{name_str} must sum to 1, got sum = {sum}")
478 })
479 .with_location(ErrorLocation::new(file!(), line!())),
480 ));
481 }
482
483 Ok(())
484}
485
486pub fn check_not_empty<S, D, A>(array: &ArrayBase<S, D>, name: A) -> CoreResult<()>
498where
499 S: ndarray::Data,
500 D: Dimension,
501 A: Into<String>,
502{
503 if array.is_empty() {
504 return Err(CoreError::ValueError(
505 ErrorContext::new({
506 let name_str = name.into();
507 format!("{name_str} cannot be empty")
508 })
509 .with_location(ErrorLocation::new(file!(), line!())),
510 ));
511 }
512 Ok(())
513}
514
515pub fn check_min_samples<S, D, A>(
528 array: &ArrayBase<S, D>,
529 min_samples: usize,
530 name: A,
531) -> CoreResult<()>
532where
533 S: ndarray::Data,
534 D: Dimension,
535 A: Into<String>,
536{
537 let n_samples = array.shape()[0];
538 if n_samples < min_samples {
539 return Err(CoreError::ValueError(
540 ErrorContext::new(format!(
541 "{} must have at least {} samples, got {}",
542 name.into(),
543 min_samples,
544 n_samples
545 ))
546 .with_location(ErrorLocation::new(file!(), line!())),
547 ));
548 }
549 Ok(())
550}
551
552pub mod clustering {
554 use super::*;
555
556 pub fn check_n_clusters_bounds<S, D>(
569 data: &ArrayBase<S, D>,
570 n_clusters: usize,
571 operation: &str,
572 ) -> CoreResult<()>
573 where
574 S: ndarray::Data,
575 D: Dimension,
576 {
577 let n_samples = data.shape()[0];
578
579 if n_clusters == 0 {
580 return Err(CoreError::ValueError(
581 ErrorContext::new(format!(
582 "{operation}: number of _clusters must be > 0, got {n_clusters}"
583 ))
584 .with_location(ErrorLocation::new(file!(), line!())),
585 ));
586 }
587
588 if n_clusters > n_samples {
589 return Err(CoreError::ValueError(
590 ErrorContext::new(format!(
591 "{operation}: number of _clusters ({n_clusters}) cannot exceed number of samples ({n_samples})"
592 ))
593 .with_location(ErrorLocation::new(file!(), line!())),
594 ));
595 }
596
597 Ok(())
598 }
599
600 pub fn validate_clustering_data<S, D>(
614 data: &ArrayBase<S, D>,
615 _operation: &str,
616 check_finite: bool,
617 min_samples: Option<usize>,
618 ) -> CoreResult<()>
619 where
620 S: ndarray::Data,
621 D: Dimension,
622 S::Elem: Float + std::fmt::Display,
623 {
624 check_not_empty(data, "data")?;
626
627 check_2d(data, "data")?;
629
630 if let Some(min) = min_samples {
632 check_min_samples(data, min, "data")?;
633 }
634
635 if check_finite {
637 checkarray_finite(data, "data")?;
638 }
639
640 Ok(())
641 }
642}
643
644pub mod parameters {
646 use super::*;
647
648 pub fn check_iteration_params<T>(
661 max_iter: usize,
662 tolerance: T,
663 operation: &str,
664 ) -> CoreResult<()>
665 where
666 T: Float + std::fmt::Display + Copy,
667 {
668 if max_iter == 0 {
669 return Err(CoreError::ValueError(
670 ErrorContext::new(format!("{operation}: max_iter must be > 0, got {max_iter}"))
671 .with_location(ErrorLocation::new(file!(), line!())),
672 ));
673 }
674
675 check_positive(tolerance, format!("{operation} tolerance"))?;
676
677 Ok(())
678 }
679
680 pub fn check_unit_interval<T>(value: T, name: &str, operation: &str) -> CoreResult<T>
693 where
694 T: Float + std::fmt::Display + Copy,
695 {
696 if value < T::zero() || value > T::one() {
697 return Err(CoreError::ValueError(
698 ErrorContext::new(format!(
699 "{operation}: {name} must be in [0, 1], got {value}"
700 ))
701 .with_location(ErrorLocation::new(file!(), line!())),
702 ));
703 }
704 Ok(value)
705 }
706
707 pub fn checkbandwidth<T>(bandwidth: T, operation: &str) -> CoreResult<T>
719 where
720 T: Float + std::fmt::Display + Copy,
721 {
722 check_positive(bandwidth, format!("{operation} bandwidth"))
723 }
724}
725
// Unit tests for the scalar/array validators and the `clustering` / `parameters` helpers.
// Most checks are table-driven: each case pairs an input with whether it should pass.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2};

    #[test]
    fn test_check_in_bounds() {
        let cases = [(5, true), (0, true), (10, true), (-1, false), (11, false)];
        for (value, should_pass) in cases {
            assert_eq!(
                check_in_bounds(value, 0, 10, "param").is_ok(),
                should_pass,
                "value = {value}"
            );
        }
    }

    #[test]
    fn test_check_positive() {
        for (value, should_pass) in [(5, true), (0, false), (-1, false)] {
            assert_eq!(check_positive(value, "param").is_ok(), should_pass, "value = {value}");
        }
        assert!(check_positive(0.1, "param").is_ok());
    }

    #[test]
    fn test_check_non_negative() {
        for (value, should_pass) in [(5, true), (0, true), (-1, false)] {
            assert_eq!(
                check_non_negative(value, "param").is_ok(),
                should_pass,
                "value = {value}"
            );
        }
        assert!(check_non_negative(-0.1, "param").is_err());
    }

    #[test]
    fn test_check_finite() {
        let cases = [
            (5.0, true),
            (0.0, true),
            (-1.0, true),
            (f64::INFINITY, false),
            (f64::NEG_INFINITY, false),
            (f64::NAN, false),
        ];
        for (value, should_pass) in cases {
            assert_eq!(check_finite(value, "param").is_ok(), should_pass, "value = {value}");
        }
    }

    #[test]
    fn test_checkarray_finite() {
        let cases = [
            (arr1(&[1.0, 2.0, 3.0]), true),
            (arr1(&[1.0, f64::INFINITY, 3.0]), false),
            (arr1(&[1.0, f64::NAN, 3.0]), false),
        ];
        for (array, should_pass) in &cases {
            assert_eq!(checkarray_finite(array, "array").is_ok(), *should_pass);
        }
    }

    #[test]
    fn test_checkshape() {
        let matrix = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        assert!(checkshape(&matrix, &[2, 2], "array").is_ok());
        assert!(checkshape(&matrix, &[2, 3], "array").is_err());
    }

    #[test]
    fn test_check_1d() {
        assert!(check_1d(&arr1(&[1.0, 2.0, 3.0]), "array").is_ok());
        assert!(check_1d(&arr2(&[[1.0, 2.0], [3.0, 4.0]]), "array").is_err());
    }

    #[test]
    fn test_check_2d() {
        assert!(check_2d(&arr2(&[[1.0, 2.0], [3.0, 4.0]]), "array").is_ok());
        assert!(check_2d(&arr1(&[1.0, 2.0, 3.0]), "array").is_err());
    }

    #[test]
    fn test_check_sameshape() {
        let square = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let same = arr2(&[[5.0, 6.0], [7.0, 8.0]]);
        let wide = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        assert!(check_sameshape(&square, "a", &same, "b").is_ok());
        assert!(check_sameshape(&square, "a", &wide, "c").is_err());
    }

    #[test]
    fn test_check_square() {
        assert!(check_square(&arr2(&[[1.0, 2.0], [3.0, 4.0]]), "matrix").is_ok());
        assert!(check_square(&arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), "matrix").is_err());
        assert!(check_square(&arr1(&[1.0, 2.0, 3.0]), "matrix").is_err());
    }

    #[test]
    fn test_check_probability() {
        let cases = [(0.0, true), (0.5, true), (1.0, true), (-0.1, false), (1.1, false)];
        for (p, should_pass) in cases {
            assert_eq!(check_probability(p, "p").is_ok(), should_pass, "p = {p}");
        }
    }

    #[test]
    fn test_check_probabilities() {
        let cases = [
            (arr1(&[0.0, 0.5, 1.0]), true),
            (arr1(&[0.0, 0.5, 1.1]), false),
            (arr1(&[-0.1, 0.5, 1.0]), false),
        ];
        for (probs, should_pass) in &cases {
            assert_eq!(check_probabilities(probs, "probs").is_ok(), *should_pass);
        }
    }

    #[test]
    fn test_check_probabilities_sum_to_one() {
        let exact = arr1(&[0.3, 0.2, 0.5]);
        let too_big = arr1(&[0.3, 0.2, 0.6]);
        let nearly = arr1(&[0.3, 0.2, 0.501]);
        let cases = [
            (&exact, None, true),
            (&too_big, None, false),
            (&nearly, Some(0.01), true),
            (&nearly, Some(0.0001), false),
        ];
        for (probs, tol, should_pass) in cases {
            assert_eq!(
                check_probabilities_sum_to_one(probs, "probs", tol).is_ok(),
                should_pass
            );
        }
    }

    #[test]
    fn test_check_not_empty() {
        assert!(check_not_empty(&arr1(&[1.0, 2.0, 3.0]), "array").is_ok());
        assert!(check_not_empty(&arr1(&[] as &[f64]), "array").is_err());
    }

    #[test]
    fn test_check_min_samples() {
        let data = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        for (min, should_pass) in [(2, true), (3, true), (4, false)] {
            assert_eq!(
                check_min_samples(&data, min, "array").is_ok(),
                should_pass,
                "min = {min}"
            );
        }
    }

    mod clustering_tests {
        use super::*;
        use crate::validation::clustering::*;

        #[test]
        fn test_check_n_clusters_bounds() {
            let data = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);
            let cases = [(1, true), (2, true), (3, true), (0, false), (4, false)];
            for (k, should_pass) in cases {
                assert_eq!(
                    check_n_clusters_bounds(&data, k, "test").is_ok(),
                    should_pass,
                    "n_clusters = {k}"
                );
            }
        }

        #[test]
        fn test_validate_clustering_data() {
            let data = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);
            assert!(validate_clustering_data(&data, "test", true, Some(2)).is_ok());
            assert!(validate_clustering_data(&data, "test", true, Some(4)).is_err());

            let empty_data = arr2(&[] as &[[f64; 2]; 0]);
            assert!(validate_clustering_data(&empty_data, "test", true, None).is_err());

            let inf_data = arr2(&[[1.0, f64::INFINITY], [3.0, 4.0]]);
            assert!(validate_clustering_data(&inf_data, "test", true, None).is_err());
            assert!(validate_clustering_data(&inf_data, "test", false, None).is_ok());
        }
    }

    mod parameters_tests {
        use crate::validation::parameters::*;

        #[test]
        fn test_check_iteration_params() {
            let cases = [
                (100, 1e-6, true),
                (0, 1e-6, false),
                (100, 0.0, false),
                (100, -1e-6, false),
            ];
            for (max_iter, tol, should_pass) in cases {
                assert_eq!(
                    check_iteration_params(max_iter, tol, "test").is_ok(),
                    should_pass
                );
            }
        }

        #[test]
        fn test_check_unit_interval() {
            let cases = [(0.0, true), (0.5, true), (1.0, true), (-0.1, false), (1.1, false)];
            for (value, should_pass) in cases {
                assert_eq!(
                    check_unit_interval(value, "param", "test").is_ok(),
                    should_pass,
                    "value = {value}"
                );
            }
        }

        #[test]
        fn test_checkbandwidth() {
            let cases = [(1.0, true), (0.1, true), (0.0, false), (-1.0, false)];
            for (bandwidth, should_pass) in cases {
                assert_eq!(
                    checkbandwidth(bandwidth, "test").is_ok(),
                    should_pass,
                    "bandwidth = {bandwidth}"
                );
            }
        }
    }
}
936
937pub mod custom {
939 use super::*;
940 use std::fmt;
941 use std::marker::PhantomData;
942
943 pub trait Validator<T> {
945 fn validate(&self, value: &T, name: &str) -> CoreResult<()>;
947
948 fn description(&self) -> String;
950
951 fn and<V: Validator<T>>(self, other: V) -> CompositeValidator<T, Self, V>
953 where
954 Self: Sized,
955 {
956 CompositeValidator::new(self, other)
957 }
958
959 fn when<F>(self, condition: F) -> ConditionalValidator<T, Self, F>
961 where
962 Self: Sized,
963 F: Fn(&T) -> bool,
964 {
965 ConditionalValidator::new(self, condition)
966 }
967 }
968
969 pub struct CompositeValidator<T, V1, V2> {
971 validator1: V1,
972 validator2: V2,
973 _phantom: PhantomData<T>,
974 }
975
976 impl<T, V1, V2> CompositeValidator<T, V1, V2> {
977 pub fn new(validator1: V1, validator2: V2) -> Self {
978 Self {
979 validator1,
980 validator2,
981 _phantom: PhantomData,
982 }
983 }
984 }
985
986 impl<T, V1, V2> Validator<T> for CompositeValidator<T, V1, V2>
987 where
988 V1: Validator<T>,
989 V2: Validator<T>,
990 {
991 fn validate(&self, value: &T, name: &str) -> CoreResult<()> {
992 self.validator1.validate(value, name)?;
993 self.validator2.validate(value, name)?;
994 Ok(())
995 }
996
997 fn description(&self) -> String {
998 format!(
999 "{} AND {}",
1000 self.validator1.description(),
1001 self.validator2.description()
1002 )
1003 }
1004 }
1005
1006 pub struct ConditionalValidator<T, V, F> {
1008 validator: V,
1009 condition: F,
1010 phantom: PhantomData<T>,
1011 }
1012
1013 impl<T, V, F> ConditionalValidator<T, V, F> {
1014 pub fn new(validator: V, condition: F) -> Self {
1015 Self {
1016 validator,
1017 condition,
1018 phantom: PhantomData,
1019 }
1020 }
1021 }
1022
1023 impl<T, V, F> Validator<T> for ConditionalValidator<T, V, F>
1024 where
1025 V: Validator<T>,
1026 F: Fn(&T) -> bool,
1027 {
1028 fn validate(&self, value: &T, name: &str) -> CoreResult<()> {
1029 if (self.condition)(value) {
1030 self.validator.validate(value, name)
1031 } else {
1032 Ok(())
1033 }
1034 }
1035
1036 fn description(&self) -> String {
1037 {
1038 let desc = self.validator.description();
1039 format!("IF condition THEN {desc}")
1040 }
1041 }
1042 }
1043
1044 pub struct RangeValidator<T> {
1046 min: Option<T>,
1047 max: Option<T>,
1048 min_inclusive: bool,
1049 max_inclusive: bool,
1050 }
1051
1052 impl<T> RangeValidator<T>
1053 where
1054 T: PartialOrd + Copy + fmt::Display,
1055 {
1056 pub fn new() -> Self {
1057 Self {
1058 min: None,
1059 max: None,
1060 min_inclusive: true,
1061 max_inclusive: true,
1062 }
1063 }
1064
1065 pub fn min(mut self, min: T) -> Self {
1066 self.min = Some(min);
1067 self
1068 }
1069
1070 pub fn max(mut self, max: T) -> Self {
1071 self.max = Some(max);
1072 self
1073 }
1074
1075 pub fn min_exclusive(mut self, min: T) -> Self {
1076 self.min = Some(min);
1077 self.min_inclusive = false;
1078 self
1079 }
1080
1081 pub fn max_exclusive(mut self, max: T) -> Self {
1082 self.max = Some(max);
1083 self.max_inclusive = false;
1084 self
1085 }
1086
1087 pub fn in_range(min: T, max: T) -> Self {
1088 Self::new().min(min).max(max)
1089 }
1090
1091 pub fn in_range_exclusive(min: T, max: T) -> Self {
1092 Self::new().min_exclusive(min).max_exclusive(max)
1093 }
1094 }
1095
1096 impl<T> Default for RangeValidator<T>
1097 where
1098 T: PartialOrd + Copy + fmt::Display,
1099 {
1100 fn default() -> Self {
1101 Self::new()
1102 }
1103 }
1104
1105 impl<T> Validator<T> for RangeValidator<T>
1106 where
1107 T: PartialOrd + Copy + fmt::Display,
1108 {
1109 fn validate(&self, value: &T, name: &str) -> CoreResult<()> {
1110 if let Some(min) = self.min {
1111 let valid = if self.min_inclusive {
1112 *value >= min
1113 } else {
1114 *value > min
1115 };
1116 if !valid {
1117 let op = if self.min_inclusive { ">=" } else { ">" };
1118 return Err(CoreError::ValueError(
1119 ErrorContext::new(format!("{name} must be {op} {min}, got {value}"))
1120 .with_location(ErrorLocation::new(file!(), line!())),
1121 ));
1122 }
1123 }
1124
1125 if let Some(max) = self.max {
1126 let valid = if self.max_inclusive {
1127 *value <= max
1128 } else {
1129 *value < max
1130 };
1131 if !valid {
1132 let op = if self.max_inclusive { "<=" } else { "<" };
1133 return Err(CoreError::ValueError(
1134 ErrorContext::new(format!("{name} must be {op} {max}, got {value}"))
1135 .with_location(ErrorLocation::new(file!(), line!())),
1136 ));
1137 }
1138 }
1139
1140 Ok(())
1141 }
1142
1143 fn description(&self) -> String {
1144 match (self.min, self.max) {
1145 (Some(min), Some(max)) => {
1146 let min_op = if self.min_inclusive { ">=" } else { ">" };
1147 let max_op = if self.max_inclusive { "<=" } else { "<" };
1148 format!("value {min_op} {min} and {max_op} {max}")
1149 }
1150 (Some(min), None) => {
1151 let op = if self.min_inclusive { ">=" } else { ">" };
1152 format!("value {op} {min}")
1153 }
1154 (None, Some(max)) => {
1155 let op = if self.max_inclusive { "<=" } else { "<" };
1156 format!("value {op} {max}")
1157 }
1158 (None, None) => "no range constraints".to_string(),
1159 }
1160 }
1161 }
1162
1163 type ShapeValidatorFn = Box<dyn Fn(&[usize]) -> CoreResult<()>>;
1165
1166 pub struct ArrayValidator<T, D>
1168 where
1169 D: Dimension,
1170 {
1171 shape_validator: Option<ShapeValidatorFn>,
1172 element_validator: Option<Box<dyn Validator<T>>>,
1173 size_validator: Option<RangeValidator<usize>>,
1174 phantom: PhantomData<D>,
1175 }
1176
1177 impl<T, D> ArrayValidator<T, D>
1178 where
1179 D: Dimension,
1180 {
1181 pub fn new() -> Self {
1182 Self {
1183 shape_validator: None,
1184 element_validator: None,
1185 size_validator: None,
1186 phantom: PhantomData,
1187 }
1188 }
1189
1190 pub fn withshape<F>(mut self, validator: F) -> Self
1191 where
1192 F: Fn(&[usize]) -> CoreResult<()> + 'static,
1193 {
1194 self.shape_validator = Some(Box::new(validator));
1195 self
1196 }
1197
1198 pub fn with_elements<V>(mut self, validator: V) -> Self
1199 where
1200 V: Validator<T> + 'static,
1201 {
1202 self.element_validator = Some(Box::new(validator));
1203 self
1204 }
1205
1206 pub fn with_size(mut self, validator: RangeValidator<usize>) -> Self {
1207 self.size_validator = Some(validator);
1208 self
1209 }
1210
1211 pub fn minsize(self, minsize: usize) -> Self {
1212 self.with_size(RangeValidator::new().min(minsize))
1213 }
1214
1215 pub fn maxsize(self, maxsize: usize) -> Self {
1216 self.with_size(RangeValidator::new().max(maxsize))
1217 }
1218
1219 pub fn exact_size(self, size: usize) -> Self {
1220 self.with_size(RangeValidator::new().min(size).max(size))
1221 }
1222 }
1223
1224 impl<T, D> Default for ArrayValidator<T, D>
1225 where
1226 D: Dimension,
1227 {
1228 fn default() -> Self {
1229 Self::new()
1230 }
1231 }
1232
1233 impl<S, T, D> Validator<ArrayBase<S, D>> for ArrayValidator<T, D>
1234 where
1235 S: ndarray::Data<Elem = T>,
1236 T: Clone,
1237 D: Dimension,
1238 {
1239 fn validate(&self, array: &ArrayBase<S, D>, name: &str) -> CoreResult<()> {
1240 if let Some(ref shape_validator) = self.shape_validator {
1242 shape_validator(array.shape())?;
1243 }
1244
1245 if let Some(ref size_validator) = self.size_validator {
1247 size_validator.validate(&array.len(), &format!("{name} size"))?;
1248 }
1249
1250 if let Some(ref element_validator) = self.element_validator {
1252 for (idx, element) in array.indexed_iter() {
1253 element_validator.validate(element, &format!("{name} element at {idx:?}"))?;
1254 }
1255 }
1256
1257 Ok(())
1258 }
1259
1260 fn description(&self) -> String {
1261 let mut parts = Vec::new();
1262
1263 if self.shape_validator.is_some() {
1264 parts.push("shape validation".to_string());
1265 }
1266
1267 if let Some(ref size_validator) = self.size_validator {
1268 {
1269 let desc = size_validator.description();
1270 parts.push(format!("size {desc}"));
1271 }
1272 }
1273
1274 if let Some(ref element_validator) = self.element_validator {
1275 {
1276 let desc = element_validator.description();
1277 parts.push(format!("elements {desc}"));
1278 }
1279 }
1280
1281 if parts.is_empty() {
1282 "no array constraints".to_string()
1283 } else {
1284 parts.join(" AND ")
1285 }
1286 }
1287 }
1288
1289 pub struct FunctionValidator<T, F> {
1291 func: F,
1292 description: String,
1293 phantom: PhantomData<T>,
1294 }
1295
1296 impl<T, F> FunctionValidator<T, F>
1297 where
1298 F: Fn(&T, &str) -> CoreResult<()>,
1299 {
1300 pub fn new(func: F, description: impl Into<String>) -> Self {
1301 Self {
1302 func,
1303 description: description.into(),
1304 phantom: PhantomData,
1305 }
1306 }
1307 }
1308
1309 impl<T, F> Validator<T> for FunctionValidator<T, F>
1310 where
1311 F: Fn(&T, &str) -> CoreResult<()>,
1312 {
1313 fn validate(&self, value: &T, name: &str) -> CoreResult<()> {
1314 (self.func)(value, name)
1315 }
1316
1317 fn description(&self) -> String {
1318 self.description.clone()
1319 }
1320 }
1321
1322 pub struct ValidatorBuilder<T> {
1324 validators: Vec<Box<dyn Validator<T>>>,
1325 }
1326
1327 impl<T: 'static> ValidatorBuilder<T> {
1328 pub fn new() -> Self {
1329 Self {
1330 validators: Vec::new(),
1331 }
1332 }
1333
1334 pub fn with_validator<V: Validator<T> + 'static>(mut self, validator: V) -> Self {
1335 self.validators.push(Box::new(validator));
1336 self
1337 }
1338
1339 pub fn with_function<F>(self, func: F, description: impl Into<String>) -> Self
1340 where
1341 F: Fn(&T, &str) -> CoreResult<()> + 'static,
1342 {
1343 self.with_validator(FunctionValidator::new(func, description))
1344 }
1345
1346 pub fn build(self) -> MultiValidator<T> {
1347 MultiValidator {
1348 validators: self.validators,
1349 }
1350 }
1351 }
1352
1353 impl<T: 'static> Default for ValidatorBuilder<T> {
1354 fn default() -> Self {
1355 Self::new()
1356 }
1357 }
1358
1359 pub struct MultiValidator<T> {
1361 validators: Vec<Box<dyn Validator<T>>>,
1362 }
1363
1364 impl<T: 'static> Validator<T> for MultiValidator<T> {
1365 fn validate(&self, value: &T, name: &str) -> CoreResult<()> {
1366 for validator in &self.validators {
1367 validator.validate(value, name)?;
1368 }
1369 Ok(())
1370 }
1371
1372 fn description(&self) -> String {
1373 if self.validators.is_empty() {
1374 "no validators".to_string()
1375 } else {
1376 self.validators
1377 .iter()
1378 .map(|v| v.description())
1379 .collect::<Vec<_>>()
1380 .join(" AND ")
1381 }
1382 }
1383 }
1384
1385 pub fn validate_with<T, V: Validator<T>>(
1387 value: &T,
1388 validator: &V,
1389 name: impl Into<String>,
1390 ) -> CoreResult<()> {
1391 validator.validate(value, &name.into())
1392 }
1393
1394 #[cfg(test)]
1395 mod tests {
1396 use super::*;
1397 use ndarray::arr1;
1398
1399 #[test]
1400 fn test_range_validator() {
1401 let validator = RangeValidator::in_range(0.0, 1.0);
1402
1403 assert!(validator.validate(&0.5, "value").is_ok());
1404 assert!(validator.validate(&0.0, "value").is_ok());
1405 assert!(validator.validate(&1.0, "value").is_ok());
1406 assert!(validator.validate(&-0.1, "value").is_err());
1407 assert!(validator.validate(&1.1, "value").is_err());
1408 }
1409
1410 #[test]
1411 fn test_range_validator_exclusive() {
1412 let validator = RangeValidator::in_range_exclusive(0.0, 1.0);
1413
1414 assert!(validator.validate(&0.5, "value").is_ok());
1415 assert!(validator.validate(&0.0, "value").is_err());
1416 assert!(validator.validate(&1.0, "value").is_err());
1417 }
1418
1419 #[test]
1420 fn test_composite_validator() {
1421 let positive = RangeValidator::new().min(0.0);
1422 let max_one = RangeValidator::new().max(1.0);
1423 let validator = positive.and(max_one);
1424
1425 assert!(validator.validate(&0.5, "value").is_ok());
1426 assert!(validator.validate(&-0.1, "value").is_err());
1427 assert!(validator.validate(&1.1, "value").is_err());
1428 }
1429
1430 #[test]
1431 fn test_conditional_validator() {
1432 let validator = RangeValidator::new().min(0.0).when(|x: &f64| *x > 0.0);
1433
1434 assert!(validator.validate(&0.5, "value").is_ok());
1435 assert!(validator.validate(&-0.5, "value").is_ok()); assert!(validator.validate(&0.0, "value").is_ok()); }
1438
1439 #[test]
1440 fn testarray_validator() {
1441 let element_validator = RangeValidator::in_range(0.0, 1.0);
1442 let array_validator = ArrayValidator::new()
1443 .with_elements(element_validator)
1444 .minsize(2);
1445
1446 let validarray = arr1(&[0.2, 0.8]);
1447 assert!(array_validator.validate(&validarray, "array").is_ok());
1448
1449 let invalidarray = arr1(&[0.2, 1.5]);
1450 assert!(array_validator.validate(&invalidarray, "array").is_err());
1451
1452 let too_smallarray = arr1(&[0.5]);
1453 assert!(array_validator.validate(&too_smallarray, "array").is_err());
1454 }
1455
1456 #[test]
1457 fn test_function_validator() {
1458 let validator = FunctionValidator::new(
1459 |value: &i32, name: &str| {
1460 if *value % 2 == 0 {
1461 Ok(())
1462 } else {
1463 Err(CoreError::ValueError(
1464 ErrorContext::new(format!("{name} must be even, got {value}"))
1465 .with_location(ErrorLocation::new(file!(), line!())),
1466 ))
1467 }
1468 },
1469 "value must be even",
1470 );
1471
1472 assert!(validator.validate(&4, "number").is_ok());
1473 assert!(validator.validate(&3, "number").is_err());
1474 }
1475
1476 #[test]
1477 fn test_validator_builder() {
1478 let validator = ValidatorBuilder::new()
1479 .with_validator(RangeValidator::new().min(0.0))
1480 .with_validator(RangeValidator::new().max(1.0))
1481 .with_function(
1482 |value: &f64, name: &str| {
1483 if *value != 0.5 {
1484 Ok(())
1485 } else {
1486 Err(CoreError::ValueError(
1487 ErrorContext::new(format!("{name} cannot be 0.5"))
1488 .with_location(ErrorLocation::new(file!(), line!())),
1489 ))
1490 }
1491 },
1492 "value cannot be 0.5",
1493 )
1494 .build();
1495
1496 assert!(validator.validate(&0.3, "value").is_ok());
1497 assert!(validator.validate(&0.5, "value").is_err());
1498 assert!(validator.validate(&-0.1, "value").is_err());
1499 assert!(validator.validate(&1.1, "value").is_err());
1500 }
1501 }
1502}
1503
1504pub mod production;
1506
1507pub mod cross_platform;
1509
1510#[cfg(feature = "data_validation")]
1512pub mod data;