1use ::ndarray::{ArrayBase, Dimension, ScalarOperand};
7use num_traits::{Float, One, Zero};
8
9use crate::error::{CoreError, CoreResult, ErrorContext, ErrorLocation};
10
11pub fn check_in_bounds<T, S>(value: T, min: T, max: T, name: S) -> CoreResult<T>
29where
30 T: PartialOrd + std::fmt::Display + Copy,
31 S: Into<String>,
32{
33 if value < min || value > max {
34 return Err(CoreError::ValueError(
35 ErrorContext::new(format!(
36 "{} must be between {min} and {max}, got {value}",
37 name.into()
38 ))
39 .with_location(ErrorLocation::new(file!(), line!())),
40 ));
41 }
42 Ok(value)
43}
44
45pub fn check_positive<T, S>(value: T, name: S) -> CoreResult<T>
61where
62 T: PartialOrd + std::fmt::Display + Copy + Zero,
63 S: Into<String>,
64{
65 if value <= T::zero() {
66 return Err(CoreError::ValueError(
67 ErrorContext::new({
68 let name_str = name.into();
69 format!("{name_str} must be positive, got {value}")
70 })
71 .with_location(ErrorLocation::new(file!(), line!())),
72 ));
73 }
74 Ok(value)
75}
76
77pub fn check_non_negative<T, S>(value: T, name: S) -> CoreResult<T>
93where
94 T: PartialOrd + std::fmt::Display + Copy + Zero,
95 S: Into<String>,
96{
97 if value < T::zero() {
98 return Err(CoreError::ValueError(
99 ErrorContext::new({
100 let name_str = name.into();
101 format!("{name_str} must be non-negative, got {value}")
102 })
103 .with_location(ErrorLocation::new(file!(), line!())),
104 ));
105 }
106 Ok(value)
107}
108
109pub fn check_finite<T, S>(value: T, name: S) -> CoreResult<T>
125where
126 T: Float + std::fmt::Display + Copy,
127 S: Into<String>,
128{
129 if !value.is_finite() {
130 return Err(CoreError::ValueError(
131 ErrorContext::new({
132 let name_str = name.into();
133 format!("{name_str} must be finite, got {value}")
134 })
135 .with_location(ErrorLocation::new(file!(), line!())),
136 ));
137 }
138 Ok(value)
139}
140
141pub fn checkarray_finite<S, A, D>(array: &ArrayBase<S, D>, name: A) -> CoreResult<()>
157where
158 S: crate::ndarray::Data,
159 D: Dimension,
160 S::Elem: Float + std::fmt::Display,
161 A: Into<String>,
162{
163 let name = name.into();
164 for (idx, &value) in array.indexed_iter() {
165 if !value.is_finite() {
166 return Err(CoreError::ValueError(
167 ErrorContext::new(format!(
168 "{name} must contain only finite values, got {value} at {idx:?}"
169 ))
170 .with_location(ErrorLocation::new(file!(), line!())),
171 ));
172 }
173 }
174 Ok(())
175}
176
177pub fn checkshape<S, D, A>(
194 array: &ArrayBase<S, D>,
195 expectedshape: &[usize],
196 name: A,
197) -> CoreResult<()>
198where
199 S: crate::ndarray::Data,
200 D: Dimension,
201 A: Into<String>,
202{
203 let actualshape = array.shape();
204 if actualshape != expectedshape {
205 return Err(CoreError::ShapeError(
206 ErrorContext::new(format!(
207 "{} has incorrect shape: expected {expectedshape:?}, got {actualshape:?}",
208 name.into()
209 ))
210 .with_location(ErrorLocation::new(file!(), line!())),
211 ));
212 }
213 Ok(())
214}
215
216pub fn check_1d<S, D, A>(array: &ArrayBase<S, D>, name: A) -> CoreResult<()>
232where
233 S: crate::ndarray::Data,
234 D: Dimension,
235 A: Into<String>,
236{
237 if array.ndim() != 1 {
238 return Err(CoreError::ShapeError(
239 ErrorContext::new({
240 let name_str = name.into();
241 let ndim = array.ndim();
242 format!("{name_str} must be 1D, got {ndim}D")
243 })
244 .with_location(ErrorLocation::new(file!(), line!())),
245 ));
246 }
247 Ok(())
248}
249
250pub fn check_2d<S, D, A>(array: &ArrayBase<S, D>, name: A) -> CoreResult<()>
266where
267 S: crate::ndarray::Data,
268 D: Dimension,
269 A: Into<String>,
270{
271 if array.ndim() != 2 {
272 return Err(CoreError::ShapeError(
273 ErrorContext::new({
274 let name_str = name.into();
275 let ndim = array.ndim();
276 format!("{name_str} must be 2D, got {ndim}D")
277 })
278 .with_location(ErrorLocation::new(file!(), line!())),
279 ));
280 }
281 Ok(())
282}
283
284pub fn check_sameshape<S1, S2, D1, D2, A, B>(
302 a: &ArrayBase<S1, D1>,
303 a_name: A,
304 b: &ArrayBase<S2, D2>,
305 b_name: B,
306) -> CoreResult<()>
307where
308 S1: crate::ndarray::Data,
309 S2: crate::ndarray::Data,
310 D1: Dimension,
311 D2: Dimension,
312 A: Into<String>,
313 B: Into<String>,
314{
315 let ashape = a.shape();
316 let bshape = b.shape();
317 if ashape != bshape {
318 return Err(CoreError::ShapeError(
319 ErrorContext::new(format!(
320 "{} and {} must have the same shape, got {:?} and {:?}",
321 a_name.into(),
322 b_name.into(),
323 ashape,
324 bshape
325 ))
326 .with_location(ErrorLocation::new(file!(), line!())),
327 ));
328 }
329 Ok(())
330}
331
332pub fn check_square<S, D, A>(matrix: &ArrayBase<S, D>, name: A) -> CoreResult<()>
348where
349 S: crate::ndarray::Data,
350 D: Dimension,
351 A: Into<String> + std::string::ToString,
352{
353 check_2d(matrix, name.to_string())?;
354 let shape = matrix.shape();
355 if shape[0] != shape[1] {
356 return Err(CoreError::ShapeError(
357 ErrorContext::new(format!(
358 "{} must be square, got shape {:?}",
359 name.into(),
360 shape
361 ))
362 .with_location(ErrorLocation::new(file!(), line!())),
363 ));
364 }
365 Ok(())
366}
367
368pub fn check_probability<T, S>(p: T, name: S) -> CoreResult<T>
384where
385 T: Float + std::fmt::Display + Copy,
386 S: Into<String>,
387{
388 if p < T::zero() || p > T::one() {
389 return Err(CoreError::ValueError(
390 ErrorContext::new(format!(
391 "{} must be between 0 and 1, got {}",
392 name.into(),
393 p
394 ))
395 .with_location(ErrorLocation::new(file!(), line!())),
396 ));
397 }
398 Ok(p)
399}
400
401pub fn check_probabilities<S, D, A>(probs: &ArrayBase<S, D>, name: A) -> CoreResult<()>
417where
418 S: crate::ndarray::Data,
419 D: Dimension,
420 S::Elem: Float + std::fmt::Display,
421 A: Into<String>,
422{
423 let name = name.into();
424 for (idx, &p) in probs.indexed_iter() {
425 if p < S::Elem::zero() || p > S::Elem::one() {
426 return Err(CoreError::ValueError(
427 ErrorContext::new(format!(
428 "{name} must contain only values between 0 and 1, got {p} at {idx:?}"
429 ))
430 .with_location(ErrorLocation::new(file!(), line!())),
431 ));
432 }
433 }
434 Ok(())
435}
436
437pub fn check_probabilities_sum_to_one<S, D, A>(
450 probs: &ArrayBase<S, D>,
451 name: A,
452 tol: Option<S::Elem>,
453) -> CoreResult<()>
454where
455 S: crate::ndarray::Data,
456 S::Elem: Float,
457 D: Dimension,
458 S::Elem: Float + std::fmt::Display + ScalarOperand,
459 A: Into<String> + std::string::ToString,
460{
461 let tol = tol.unwrap_or_else(|| {
462 let eps: f64 = 1e-10;
463 num_traits::cast(eps).unwrap_or_else(|| {
464 S::Elem::epsilon()
466 })
467 });
468
469 check_probabilities(probs, name.to_string())?;
470
471 let sum = probs.sum();
472 let one = S::Elem::one();
473
474 if (sum - one).abs() > tol {
475 return Err(CoreError::ValueError(
476 ErrorContext::new({
477 let name_str = name.into();
478 format!("{name_str} must sum to 1, got sum = {sum}")
479 })
480 .with_location(ErrorLocation::new(file!(), line!())),
481 ));
482 }
483
484 Ok(())
485}
486
487pub fn check_not_empty<S, D, A>(array: &ArrayBase<S, D>, name: A) -> CoreResult<()>
499where
500 S: crate::ndarray::Data,
501 D: Dimension,
502 A: Into<String>,
503{
504 if array.is_empty() {
505 return Err(CoreError::ValueError(
506 ErrorContext::new({
507 let name_str = name.into();
508 format!("{name_str} cannot be empty")
509 })
510 .with_location(ErrorLocation::new(file!(), line!())),
511 ));
512 }
513 Ok(())
514}
515
516pub fn check_min_samples<S, D, A>(
529 array: &ArrayBase<S, D>,
530 min_samples: usize,
531 name: A,
532) -> CoreResult<()>
533where
534 S: crate::ndarray::Data,
535 D: Dimension,
536 A: Into<String>,
537{
538 let n_samples = array.shape()[0];
539 if n_samples < min_samples {
540 return Err(CoreError::ValueError(
541 ErrorContext::new(format!(
542 "{} must have at least {} samples, got {}",
543 name.into(),
544 min_samples,
545 n_samples
546 ))
547 .with_location(ErrorLocation::new(file!(), line!())),
548 ));
549 }
550 Ok(())
551}
552
553pub mod clustering {
555 use super::*;
556
557 pub fn check_n_clusters_bounds<S, D>(
570 data: &ArrayBase<S, D>,
571 n_clusters: usize,
572 operation: &str,
573 ) -> CoreResult<()>
574 where
575 S: crate::ndarray::Data,
576 D: Dimension,
577 {
578 let n_samples = data.shape()[0];
579
580 if n_clusters == 0 {
581 return Err(CoreError::ValueError(
582 ErrorContext::new(format!(
583 "{operation}: number of _clusters must be > 0, got {n_clusters}"
584 ))
585 .with_location(ErrorLocation::new(file!(), line!())),
586 ));
587 }
588
589 if n_clusters > n_samples {
590 return Err(CoreError::ValueError(
591 ErrorContext::new(format!(
592 "{operation}: number of _clusters ({n_clusters}) cannot exceed number of samples ({n_samples})"
593 ))
594 .with_location(ErrorLocation::new(file!(), line!())),
595 ));
596 }
597
598 Ok(())
599 }
600
601 pub fn validate_clustering_data<S, D>(
615 data: &ArrayBase<S, D>,
616 _operation: &str,
617 check_finite: bool,
618 min_samples: Option<usize>,
619 ) -> CoreResult<()>
620 where
621 S: crate::ndarray::Data,
622 D: Dimension,
623 S::Elem: Float + std::fmt::Display,
624 {
625 check_not_empty(data, "data")?;
627
628 check_2d(data, "data")?;
630
631 if let Some(min) = min_samples {
633 check_min_samples(data, min, "data")?;
634 }
635
636 if check_finite {
638 checkarray_finite(data, "data")?;
639 }
640
641 Ok(())
642 }
643}
644
645pub mod parameters {
647 use super::*;
648
649 pub fn check_iteration_params<T>(
662 max_iter: usize,
663 tolerance: T,
664 operation: &str,
665 ) -> CoreResult<()>
666 where
667 T: Float + std::fmt::Display + Copy,
668 {
669 if max_iter == 0 {
670 return Err(CoreError::ValueError(
671 ErrorContext::new(format!("{operation}: max_iter must be > 0, got {max_iter}"))
672 .with_location(ErrorLocation::new(file!(), line!())),
673 ));
674 }
675
676 check_positive(tolerance, format!("{operation} tolerance"))?;
677
678 Ok(())
679 }
680
681 pub fn check_unit_interval<T>(value: T, name: &str, operation: &str) -> CoreResult<T>
694 where
695 T: Float + std::fmt::Display + Copy,
696 {
697 if value < T::zero() || value > T::one() {
698 return Err(CoreError::ValueError(
699 ErrorContext::new(format!(
700 "{operation}: {name} must be in [0, 1], got {value}"
701 ))
702 .with_location(ErrorLocation::new(file!(), line!())),
703 ));
704 }
705 Ok(value)
706 }
707
708 pub fn checkbandwidth<T>(bandwidth: T, operation: &str) -> CoreResult<T>
720 where
721 T: Float + std::fmt::Display + Copy,
722 {
723 check_positive(bandwidth, format!("{operation} bandwidth"))
724 }
725}
726
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2};

    #[test]
    fn test_check_in_bounds() {
        for accepted in [0, 5, 10] {
            assert!(check_in_bounds(accepted, 0, 10, "param").is_ok());
        }
        for rejected in [-1, 11] {
            assert!(check_in_bounds(rejected, 0, 10, "param").is_err());
        }
    }

    #[test]
    fn test_check_positive() {
        assert!(check_positive(5, "param").is_ok());
        assert!(check_positive(0.1, "param").is_ok());
        for rejected in [0, -1] {
            assert!(check_positive(rejected, "param").is_err());
        }
    }

    #[test]
    fn test_check_non_negative() {
        for accepted in [5, 0] {
            assert!(check_non_negative(accepted, "param").is_ok());
        }
        assert!(check_non_negative(-0.1, "param").is_err());
        assert!(check_non_negative(-1, "param").is_err());
    }

    #[test]
    fn test_check_finite() {
        for accepted in [5.0, 0.0, -1.0] {
            assert!(check_finite(accepted, "param").is_ok());
        }
        for rejected in [f64::INFINITY, f64::NEG_INFINITY, f64::NAN] {
            assert!(check_finite(rejected, "param").is_err());
        }
    }

    #[test]
    fn test_checkarray_finite() {
        let finite = arr1(&[1.0, 2.0, 3.0]);
        assert!(checkarray_finite(&finite, "array").is_ok());

        for bad in [f64::INFINITY, f64::NAN] {
            let with_bad = arr1(&[1.0, bad, 3.0]);
            assert!(checkarray_finite(&with_bad, "array").is_err());
        }
    }

    #[test]
    fn test_checkshape() {
        let square = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        assert!(checkshape(&square, &[2, 2], "array").is_ok());
        assert!(checkshape(&square, &[2, 3], "array").is_err());
    }

    #[test]
    fn test_check_1d() {
        let vector = arr1(&[1.0, 2.0, 3.0]);
        let matrix = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        assert!(check_1d(&vector, "array").is_ok());
        assert!(check_1d(&matrix, "array").is_err());
    }

    #[test]
    fn test_check_2d() {
        let matrix = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let vector = arr1(&[1.0, 2.0, 3.0]);
        assert!(check_2d(&matrix, "array").is_ok());
        assert!(check_2d(&vector, "array").is_err());
    }

    #[test]
    fn test_check_sameshape() {
        let a = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let same = arr2(&[[5.0, 6.0], [7.0, 8.0]]);
        let wider = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);

        assert!(check_sameshape(&a, "a", &same, "b").is_ok());
        assert!(check_sameshape(&a, "a", &wider, "c").is_err());
    }

    #[test]
    fn test_check_square() {
        let square = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let rectangular = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let vector = arr1(&[1.0, 2.0, 3.0]);

        assert!(check_square(&square, "matrix").is_ok());
        assert!(check_square(&rectangular, "matrix").is_err());
        assert!(check_square(&vector, "matrix").is_err());
    }

    #[test]
    fn test_check_probability() {
        for accepted in [0.0, 0.5, 1.0] {
            assert!(check_probability(accepted, "p").is_ok());
        }
        for rejected in [-0.1, 1.1] {
            assert!(check_probability(rejected, "p").is_err());
        }
    }

    #[test]
    fn test_check_probabilities() {
        let valid = arr1(&[0.0, 0.5, 1.0]);
        let too_large = arr1(&[0.0, 0.5, 1.1]);
        let negative = arr1(&[-0.1, 0.5, 1.0]);

        assert!(check_probabilities(&valid, "probs").is_ok());
        assert!(check_probabilities(&too_large, "probs").is_err());
        assert!(check_probabilities(&negative, "probs").is_err());
    }

    #[test]
    fn test_check_probabilities_sum_to_one() {
        let exact = arr1(&[0.3, 0.2, 0.5]);
        let over = arr1(&[0.3, 0.2, 0.6]);
        let slightly_over = arr1(&[0.3, 0.2, 0.501]);

        assert!(check_probabilities_sum_to_one(&exact, "probs", None).is_ok());
        assert!(check_probabilities_sum_to_one(&over, "probs", None).is_err());

        // The 0.001 excess fits a 0.01 tolerance but not a 0.0001 one.
        assert!(check_probabilities_sum_to_one(&slightly_over, "probs", Some(0.01)).is_ok());
        assert!(check_probabilities_sum_to_one(&slightly_over, "probs", Some(0.0001)).is_err());
    }

    #[test]
    fn test_check_not_empty() {
        let populated = arr1(&[1.0, 2.0, 3.0]);
        let empty = arr1(&[] as &[f64]);

        assert!(check_not_empty(&populated, "array").is_ok());
        assert!(check_not_empty(&empty, "array").is_err());
    }

    #[test]
    fn test_check_min_samples() {
        let three_rows = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        for satisfied in [2, 3] {
            assert!(check_min_samples(&three_rows, satisfied, "array").is_ok());
        }
        assert!(check_min_samples(&three_rows, 4, "array").is_err());
    }

    mod clustering_tests {
        use super::*;
        use crate::validation::clustering::*;

        #[test]
        fn test_check_n_clusters_bounds() {
            let data = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);

            for valid in 1..=3 {
                assert!(check_n_clusters_bounds(&data, valid, "test").is_ok());
            }
            for invalid in [0, 4] {
                assert!(check_n_clusters_bounds(&data, invalid, "test").is_err());
            }
        }

        #[test]
        fn test_validate_clustering_data() {
            let data = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);
            assert!(validate_clustering_data(&data, "test", true, Some(2)).is_ok());
            assert!(validate_clustering_data(&data, "test", true, Some(4)).is_err());

            let empty_data = arr2(&[] as &[[f64; 2]; 0]);
            assert!(validate_clustering_data(&empty_data, "test", true, None).is_err());

            // Non-finite values are only rejected when the finiteness check is on.
            let inf_data = arr2(&[[1.0, f64::INFINITY], [3.0, 4.0]]);
            assert!(validate_clustering_data(&inf_data, "test", true, None).is_err());
            assert!(validate_clustering_data(&inf_data, "test", false, None).is_ok());
        }
    }

    mod parameters_tests {
        use crate::validation::parameters::*;

        #[test]
        fn test_check_iteration_params() {
            assert!(check_iteration_params(100, 1e-6, "test").is_ok());
            assert!(check_iteration_params(0, 1e-6, "test").is_err());
            for bad_tolerance in [0.0, -1e-6] {
                assert!(check_iteration_params(100, bad_tolerance, "test").is_err());
            }
        }

        #[test]
        fn test_check_unit_interval() {
            for accepted in [0.0, 0.5, 1.0] {
                assert!(check_unit_interval(accepted, "param", "test").is_ok());
            }
            for rejected in [-0.1, 1.1] {
                assert!(check_unit_interval(rejected, "param", "test").is_err());
            }
        }

        #[test]
        fn test_checkbandwidth() {
            for accepted in [1.0, 0.1] {
                assert!(checkbandwidth(accepted, "test").is_ok());
            }
            for rejected in [0.0, -1.0] {
                assert!(checkbandwidth(rejected, "test").is_err());
            }
        }
    }
}
937
938pub mod custom {
940 use super::*;
941 use std::fmt;
942 use std::marker::PhantomData;
943
944 pub trait Validator<T> {
946 fn validate(&self, value: &T, name: &str) -> CoreResult<()>;
948
949 fn description(&self) -> String;
951
952 fn and<V: Validator<T>>(self, other: V) -> CompositeValidator<T, Self, V>
954 where
955 Self: Sized,
956 {
957 CompositeValidator::new(self, other)
958 }
959
960 fn when<F>(self, condition: F) -> ConditionalValidator<T, Self, F>
962 where
963 Self: Sized,
964 F: Fn(&T) -> bool,
965 {
966 ConditionalValidator::new(self, condition)
967 }
968 }
969
970 pub struct CompositeValidator<T, V1, V2> {
972 validator1: V1,
973 validator2: V2,
974 _phantom: PhantomData<T>,
975 }
976
977 impl<T, V1, V2> CompositeValidator<T, V1, V2> {
978 pub fn new(validator1: V1, validator2: V2) -> Self {
979 Self {
980 validator1,
981 validator2,
982 _phantom: PhantomData,
983 }
984 }
985 }
986
987 impl<T, V1, V2> Validator<T> for CompositeValidator<T, V1, V2>
988 where
989 V1: Validator<T>,
990 V2: Validator<T>,
991 {
992 fn validate(&self, value: &T, name: &str) -> CoreResult<()> {
993 self.validator1.validate(value, name)?;
994 self.validator2.validate(value, name)?;
995 Ok(())
996 }
997
998 fn description(&self) -> String {
999 format!(
1000 "{} AND {}",
1001 self.validator1.description(),
1002 self.validator2.description()
1003 )
1004 }
1005 }
1006
1007 pub struct ConditionalValidator<T, V, F> {
1009 validator: V,
1010 condition: F,
1011 phantom: PhantomData<T>,
1012 }
1013
1014 impl<T, V, F> ConditionalValidator<T, V, F> {
1015 pub fn new(validator: V, condition: F) -> Self {
1016 Self {
1017 validator,
1018 condition,
1019 phantom: PhantomData,
1020 }
1021 }
1022 }
1023
1024 impl<T, V, F> Validator<T> for ConditionalValidator<T, V, F>
1025 where
1026 V: Validator<T>,
1027 F: Fn(&T) -> bool,
1028 {
1029 fn validate(&self, value: &T, name: &str) -> CoreResult<()> {
1030 if (self.condition)(value) {
1031 self.validator.validate(value, name)
1032 } else {
1033 Ok(())
1034 }
1035 }
1036
1037 fn description(&self) -> String {
1038 {
1039 let desc = self.validator.description();
1040 format!("IF condition THEN {desc}")
1041 }
1042 }
1043 }
1044
1045 pub struct RangeValidator<T> {
1047 min: Option<T>,
1048 max: Option<T>,
1049 min_inclusive: bool,
1050 max_inclusive: bool,
1051 }
1052
1053 impl<T> RangeValidator<T>
1054 where
1055 T: PartialOrd + Copy + fmt::Display,
1056 {
1057 pub fn new() -> Self {
1058 Self {
1059 min: None,
1060 max: None,
1061 min_inclusive: true,
1062 max_inclusive: true,
1063 }
1064 }
1065
1066 pub fn min(mut self, min: T) -> Self {
1067 self.min = Some(min);
1068 self
1069 }
1070
1071 pub fn max(mut self, max: T) -> Self {
1072 self.max = Some(max);
1073 self
1074 }
1075
1076 pub fn min_exclusive(mut self, min: T) -> Self {
1077 self.min = Some(min);
1078 self.min_inclusive = false;
1079 self
1080 }
1081
1082 pub fn max_exclusive(mut self, max: T) -> Self {
1083 self.max = Some(max);
1084 self.max_inclusive = false;
1085 self
1086 }
1087
1088 pub fn in_range(min: T, max: T) -> Self {
1089 Self::new().min(min).max(max)
1090 }
1091
1092 pub fn in_range_exclusive(min: T, max: T) -> Self {
1093 Self::new().min_exclusive(min).max_exclusive(max)
1094 }
1095 }
1096
1097 impl<T> Default for RangeValidator<T>
1098 where
1099 T: PartialOrd + Copy + fmt::Display,
1100 {
1101 fn default() -> Self {
1102 Self::new()
1103 }
1104 }
1105
1106 impl<T> Validator<T> for RangeValidator<T>
1107 where
1108 T: PartialOrd + Copy + fmt::Display,
1109 {
1110 fn validate(&self, value: &T, name: &str) -> CoreResult<()> {
1111 if let Some(min) = self.min {
1112 let valid = if self.min_inclusive {
1113 *value >= min
1114 } else {
1115 *value > min
1116 };
1117 if !valid {
1118 let op = if self.min_inclusive { ">=" } else { ">" };
1119 return Err(CoreError::ValueError(
1120 ErrorContext::new(format!("{name} must be {op} {min}, got {value}"))
1121 .with_location(ErrorLocation::new(file!(), line!())),
1122 ));
1123 }
1124 }
1125
1126 if let Some(max) = self.max {
1127 let valid = if self.max_inclusive {
1128 *value <= max
1129 } else {
1130 *value < max
1131 };
1132 if !valid {
1133 let op = if self.max_inclusive { "<=" } else { "<" };
1134 return Err(CoreError::ValueError(
1135 ErrorContext::new(format!("{name} must be {op} {max}, got {value}"))
1136 .with_location(ErrorLocation::new(file!(), line!())),
1137 ));
1138 }
1139 }
1140
1141 Ok(())
1142 }
1143
1144 fn description(&self) -> String {
1145 match (self.min, self.max) {
1146 (Some(min), Some(max)) => {
1147 let min_op = if self.min_inclusive { ">=" } else { ">" };
1148 let max_op = if self.max_inclusive { "<=" } else { "<" };
1149 format!("value {min_op} {min} and {max_op} {max}")
1150 }
1151 (Some(min), None) => {
1152 let op = if self.min_inclusive { ">=" } else { ">" };
1153 format!("value {op} {min}")
1154 }
1155 (None, Some(max)) => {
1156 let op = if self.max_inclusive { "<=" } else { "<" };
1157 format!("value {op} {max}")
1158 }
1159 (None, None) => "no range constraints".to_string(),
1160 }
1161 }
1162 }
1163
1164 type ShapeValidatorFn = Box<dyn Fn(&[usize]) -> CoreResult<()>>;
1166
1167 pub struct ArrayValidator<T, D>
1169 where
1170 D: Dimension,
1171 {
1172 shape_validator: Option<ShapeValidatorFn>,
1173 element_validator: Option<Box<dyn Validator<T>>>,
1174 size_validator: Option<RangeValidator<usize>>,
1175 phantom: PhantomData<D>,
1176 }
1177
1178 impl<T, D> ArrayValidator<T, D>
1179 where
1180 D: Dimension,
1181 {
1182 pub fn new() -> Self {
1183 Self {
1184 shape_validator: None,
1185 element_validator: None,
1186 size_validator: None,
1187 phantom: PhantomData,
1188 }
1189 }
1190
1191 pub fn withshape<F>(mut self, validator: F) -> Self
1192 where
1193 F: Fn(&[usize]) -> CoreResult<()> + 'static,
1194 {
1195 self.shape_validator = Some(Box::new(validator));
1196 self
1197 }
1198
1199 pub fn with_elements<V>(mut self, validator: V) -> Self
1200 where
1201 V: Validator<T> + 'static,
1202 {
1203 self.element_validator = Some(Box::new(validator));
1204 self
1205 }
1206
1207 pub fn with_size(mut self, validator: RangeValidator<usize>) -> Self {
1208 self.size_validator = Some(validator);
1209 self
1210 }
1211
1212 pub fn minsize(self, minsize: usize) -> Self {
1213 self.with_size(RangeValidator::new().min(minsize))
1214 }
1215
1216 pub fn maxsize(self, maxsize: usize) -> Self {
1217 self.with_size(RangeValidator::new().max(maxsize))
1218 }
1219
1220 pub fn exact_size(self, size: usize) -> Self {
1221 self.with_size(RangeValidator::new().min(size).max(size))
1222 }
1223 }
1224
1225 impl<T, D> Default for ArrayValidator<T, D>
1226 where
1227 D: Dimension,
1228 {
1229 fn default() -> Self {
1230 Self::new()
1231 }
1232 }
1233
1234 impl<S, T, D> Validator<ArrayBase<S, D>> for ArrayValidator<T, D>
1235 where
1236 S: crate::ndarray::Data<Elem = T>,
1237 T: Clone,
1238 D: Dimension,
1239 {
1240 fn validate(&self, array: &ArrayBase<S, D>, name: &str) -> CoreResult<()> {
1241 if let Some(ref shape_validator) = self.shape_validator {
1243 shape_validator(array.shape())?;
1244 }
1245
1246 if let Some(ref size_validator) = self.size_validator {
1248 size_validator.validate(&array.len(), &format!("{name} size"))?;
1249 }
1250
1251 if let Some(ref element_validator) = self.element_validator {
1253 for (idx, element) in array.indexed_iter() {
1254 element_validator.validate(element, &format!("{name} element at {idx:?}"))?;
1255 }
1256 }
1257
1258 Ok(())
1259 }
1260
1261 fn description(&self) -> String {
1262 let mut parts = Vec::new();
1263
1264 if self.shape_validator.is_some() {
1265 parts.push("shape validation".to_string());
1266 }
1267
1268 if let Some(ref size_validator) = self.size_validator {
1269 {
1270 let desc = size_validator.description();
1271 parts.push(format!("size {desc}"));
1272 }
1273 }
1274
1275 if let Some(ref element_validator) = self.element_validator {
1276 {
1277 let desc = element_validator.description();
1278 parts.push(format!("elements {desc}"));
1279 }
1280 }
1281
1282 if parts.is_empty() {
1283 "no array constraints".to_string()
1284 } else {
1285 parts.join(" AND ")
1286 }
1287 }
1288 }
1289
1290 pub struct FunctionValidator<T, F> {
1292 func: F,
1293 description: String,
1294 phantom: PhantomData<T>,
1295 }
1296
1297 impl<T, F> FunctionValidator<T, F>
1298 where
1299 F: Fn(&T, &str) -> CoreResult<()>,
1300 {
1301 pub fn new(func: F, description: impl Into<String>) -> Self {
1302 Self {
1303 func,
1304 description: description.into(),
1305 phantom: PhantomData,
1306 }
1307 }
1308 }
1309
1310 impl<T, F> Validator<T> for FunctionValidator<T, F>
1311 where
1312 F: Fn(&T, &str) -> CoreResult<()>,
1313 {
1314 fn validate(&self, value: &T, name: &str) -> CoreResult<()> {
1315 (self.func)(value, name)
1316 }
1317
1318 fn description(&self) -> String {
1319 self.description.clone()
1320 }
1321 }
1322
1323 pub struct ValidatorBuilder<T> {
1325 validators: Vec<Box<dyn Validator<T>>>,
1326 }
1327
1328 impl<T: 'static> ValidatorBuilder<T> {
1329 pub fn new() -> Self {
1330 Self {
1331 validators: Vec::new(),
1332 }
1333 }
1334
1335 pub fn with_validator<V: Validator<T> + 'static>(mut self, validator: V) -> Self {
1336 self.validators.push(Box::new(validator));
1337 self
1338 }
1339
1340 pub fn with_function<F>(self, func: F, description: impl Into<String>) -> Self
1341 where
1342 F: Fn(&T, &str) -> CoreResult<()> + 'static,
1343 {
1344 self.with_validator(FunctionValidator::new(func, description))
1345 }
1346
1347 pub fn build(self) -> MultiValidator<T> {
1348 MultiValidator {
1349 validators: self.validators,
1350 }
1351 }
1352 }
1353
1354 impl<T: 'static> Default for ValidatorBuilder<T> {
1355 fn default() -> Self {
1356 Self::new()
1357 }
1358 }
1359
1360 pub struct MultiValidator<T> {
1362 validators: Vec<Box<dyn Validator<T>>>,
1363 }
1364
1365 impl<T: 'static> Validator<T> for MultiValidator<T> {
1366 fn validate(&self, value: &T, name: &str) -> CoreResult<()> {
1367 for validator in &self.validators {
1368 validator.validate(value, name)?;
1369 }
1370 Ok(())
1371 }
1372
1373 fn description(&self) -> String {
1374 if self.validators.is_empty() {
1375 "no validators".to_string()
1376 } else {
1377 self.validators
1378 .iter()
1379 .map(|v| v.description())
1380 .collect::<Vec<_>>()
1381 .join(" AND ")
1382 }
1383 }
1384 }
1385
1386 pub fn validate_with<T, V: Validator<T>>(
1388 value: &T,
1389 validator: &V,
1390 name: impl Into<String>,
1391 ) -> CoreResult<()> {
1392 validator.validate(value, &name.into())
1393 }
1394
1395 #[cfg(test)]
1396 mod tests {
1397 use super::*;
1398 use ::ndarray::arr1;
1399
1400 #[test]
1401 fn test_range_validator() {
1402 let validator = RangeValidator::in_range(0.0, 1.0);
1403
1404 assert!(validator.validate(&0.5, "value").is_ok());
1405 assert!(validator.validate(&0.0, "value").is_ok());
1406 assert!(validator.validate(&1.0, "value").is_ok());
1407 assert!(validator.validate(&-0.1, "value").is_err());
1408 assert!(validator.validate(&1.1, "value").is_err());
1409 }
1410
1411 #[test]
1412 fn test_range_validator_exclusive() {
1413 let validator = RangeValidator::in_range_exclusive(0.0, 1.0);
1414
1415 assert!(validator.validate(&0.5, "value").is_ok());
1416 assert!(validator.validate(&0.0, "value").is_err());
1417 assert!(validator.validate(&1.0, "value").is_err());
1418 }
1419
1420 #[test]
1421 fn test_composite_validator() {
1422 let positive = RangeValidator::new().min(0.0);
1423 let max_one = RangeValidator::new().max(1.0);
1424 let validator = positive.and(max_one);
1425
1426 assert!(validator.validate(&0.5, "value").is_ok());
1427 assert!(validator.validate(&-0.1, "value").is_err());
1428 assert!(validator.validate(&1.1, "value").is_err());
1429 }
1430
1431 #[test]
1432 fn test_conditional_validator() {
1433 let validator = RangeValidator::new().min(0.0).when(|x: &f64| *x > 0.0);
1434
1435 assert!(validator.validate(&0.5, "value").is_ok());
1436 assert!(validator.validate(&-0.5, "value").is_ok()); assert!(validator.validate(&0.0, "value").is_ok()); }
1439
1440 #[test]
1441 fn testarray_validator() {
1442 let element_validator = RangeValidator::in_range(0.0, 1.0);
1443 let array_validator = ArrayValidator::new()
1444 .with_elements(element_validator)
1445 .minsize(2);
1446
1447 let validarray = arr1(&[0.2, 0.8]);
1448 assert!(array_validator.validate(&validarray, "array").is_ok());
1449
1450 let invalidarray = arr1(&[0.2, 1.5]);
1451 assert!(array_validator.validate(&invalidarray, "array").is_err());
1452
1453 let too_smallarray = arr1(&[0.5]);
1454 assert!(array_validator.validate(&too_smallarray, "array").is_err());
1455 }
1456
1457 #[test]
1458 fn test_function_validator() {
1459 let validator = FunctionValidator::new(
1460 |value: &i32, name: &str| {
1461 if *value % 2 == 0 {
1462 Ok(())
1463 } else {
1464 Err(CoreError::ValueError(
1465 ErrorContext::new(format!("{name} must be even, got {value}"))
1466 .with_location(ErrorLocation::new(file!(), line!())),
1467 ))
1468 }
1469 },
1470 "value must be even",
1471 );
1472
1473 assert!(validator.validate(&4, "number").is_ok());
1474 assert!(validator.validate(&3, "number").is_err());
1475 }
1476
1477 #[test]
1478 fn test_validator_builder() {
1479 let validator = ValidatorBuilder::new()
1480 .with_validator(RangeValidator::new().min(0.0))
1481 .with_validator(RangeValidator::new().max(1.0))
1482 .with_function(
1483 |value: &f64, name: &str| {
1484 if *value != 0.5 {
1485 Ok(())
1486 } else {
1487 Err(CoreError::ValueError(
1488 ErrorContext::new(format!("{name} cannot be 0.5"))
1489 .with_location(ErrorLocation::new(file!(), line!())),
1490 ))
1491 }
1492 },
1493 "value cannot be 0.5",
1494 )
1495 .build();
1496
1497 assert!(validator.validate(&0.3, "value").is_ok());
1498 assert!(validator.validate(&0.5, "value").is_err());
1499 assert!(validator.validate(&-0.1, "value").is_err());
1500 assert!(validator.validate(&1.1, "value").is_err());
1501 }
1502 }
1503}
1504
1505pub mod production;
1507
1508pub mod cross_platform;
1510
1511#[cfg(feature = "data_validation")]
1513pub mod data;