swarm_engine_core/validation/
validator.rs1use super::result::ValidationResult;
7use super::strategy::ValidationStrategy;
8
9pub struct Validator<T> {
30 split_ratio: f64,
32 strategy: Box<dyn ValidationStrategy>,
34 _marker: std::marker::PhantomData<T>,
36}
37
38impl<T> Validator<T> {
39 pub fn new(split_ratio: f64, strategy: Box<dyn ValidationStrategy>) -> Self {
46 assert!(
47 (0.0..=1.0).contains(&split_ratio),
48 "split_ratio must be between 0.0 and 1.0"
49 );
50 Self {
51 split_ratio,
52 strategy,
53 _marker: std::marker::PhantomData,
54 }
55 }
56
57 pub fn with_80_20_split(strategy: Box<dyn ValidationStrategy>) -> Self {
59 Self::new(0.8, strategy)
60 }
61
62 pub fn with_70_30_split(strategy: Box<dyn ValidationStrategy>) -> Self {
64 Self::new(0.7, strategy)
65 }
66
67 pub fn validate<F, G>(&self, data: &[T], baseline_fn: F, evaluate_fn: G) -> ValidationResult
79 where
80 F: FnOnce(&[T]) -> f64,
81 G: FnOnce(&[T]) -> f64,
82 {
83 let (train, test) = self.split(data);
84
85 let baseline = baseline_fn(train);
86 let current = evaluate_fn(test);
87
88 self.strategy.evaluate(baseline, current, test.len())
89 }
90
91 pub fn validate_with_baseline<F>(
95 &self,
96 data: &[T],
97 baseline: f64,
98 evaluate_fn: F,
99 ) -> ValidationResult
100 where
101 F: FnOnce(&[T]) -> f64,
102 {
103 let (_, test) = self.split(data);
104 let current = evaluate_fn(test);
105 self.strategy.evaluate(baseline, current, test.len())
106 }
107
108 fn split<'a>(&self, data: &'a [T]) -> (&'a [T], &'a [T]) {
110 let split_idx = (data.len() as f64 * self.split_ratio) as usize;
111 let split_idx = split_idx.min(data.len());
112 (&data[..split_idx], &data[split_idx..])
113 }
114
115 pub fn strategy_name(&self) -> &str {
117 self.strategy.name()
118 }
119
120 pub fn split_ratio(&self) -> f64 {
122 self.split_ratio
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::super::strategy::{Absolute, Improvement, NoRegression};
    use super::*;

    /// Arithmetic mean of a slice of scores, used as the scoring function in
    /// the tests below.
    fn mean(values: &[f64]) -> f64 {
        values.iter().sum::<f64>() / values.len() as f64
    }

    /// 100 scores: ones for every index below `ones`, zeros afterwards.
    fn leading_ones(ones: usize) -> Vec<f64> {
        (0..100).map(|idx| if idx < ones { 1.0 } else { 0.0 }).collect()
    }

    #[test]
    fn test_validator_split() {
        let samples: Vec<i32> = (0..100).collect();
        let validator = Validator::<i32>::new(0.8, Box::new(NoRegression::new()));

        let (train_part, test_part) = validator.split(&samples);

        assert_eq!(train_part.len(), 80);
        assert_eq!(test_part.len(), 20);
    }

    #[test]
    fn test_validator_validate() {
        // First 80 entries: 64 ones (mean 0.8). Last 20 entries: 18 ones (mean 0.9).
        let train_scores = (0..80).map(|idx| if idx < 64 { 1.0 } else { 0.0 });
        let test_scores = (0..20).map(|idx| if idx < 18 { 1.0 } else { 0.0 });
        let scores: Vec<f64> = train_scores.chain(test_scores).collect();

        let validator = Validator::with_80_20_split(Box::new(NoRegression::new()));

        let outcome = validator.validate(&scores, mean, mean);

        assert!(outcome.passed);
        assert!((outcome.baseline - 0.8).abs() < 0.01);
        assert!((outcome.current - 0.9).abs() < 0.01);
    }

    #[test]
    fn test_validator_with_baseline() {
        // Only indices 80..85 of the test part are ones, so the test mean is 0.25,
        // well below the supplied 0.7 baseline.
        let scores = leading_ones(85);

        let validator = Validator::with_80_20_split(Box::new(Improvement::ten_percent()));

        let outcome = validator.validate_with_baseline(&scores, 0.7, mean);

        assert!(!outcome.passed);
    }

    #[test]
    fn test_validator_absolute_strategy() {
        // Indices 80..90 of the test part are ones, so the test mean is 0.5.
        let scores = leading_ones(90);

        let validator = Validator::with_80_20_split(Box::new(Absolute::eighty_percent()));

        let outcome = validator.validate_with_baseline(&scores, 0.5, mean);

        assert!(!outcome.passed);
    }

    #[test]
    fn test_validator_empty_data() {
        let samples: Vec<i32> = Vec::new();
        let validator = Validator::<i32>::new(0.8, Box::new(NoRegression::new()));

        let (train_part, test_part) = validator.split(&samples);

        assert!(train_part.is_empty());
        assert!(test_part.is_empty());
    }

    #[test]
    #[should_panic(expected = "split_ratio must be between")]
    fn test_validator_invalid_ratio() {
        // 1.5 is outside 0.0..=1.0, so construction must panic.
        let _: Validator<i32> = Validator::new(1.5, Box::new(NoRegression::new()));
    }
}