Skip to main content

swarm_engine_core/validation/
validator.rs

//! Validator - 学習結果の性能検証
//!
//! Learn Pipeline の外で使用。
//! データ分割 → 検証 のプロセスを実行。
5
6use super::result::ValidationResult;
7use super::strategy::ValidationStrategy;
8
9/// Validator - 学習結果の性能検証
10///
11/// データを Train/Test に分割し、Test データで性能を検証する。
12/// Learn Pipeline の外で使用される。
13///
14/// # Type Parameter
15///
16/// * `T` - 検証対象のデータ型(Episode, ActionRecord 等)
17///
18/// # Example
19///
20/// ```ignore
21/// use swarm_engine_core::validation::{Validator, NoRegression};
22///
23/// let validator = Validator::new(0.8, Box::new(NoRegression::new()));
24/// let result = validator.validate(&episodes, |test| {
25///     // test データで成功率を計算
26///     compute_success_rate(test)
27/// });
28/// ```
29pub struct Validator<T> {
30    /// Train/Test 分割比率 (0.8 = 80% train, 20% test)
31    split_ratio: f64,
32    /// 検証戦略
33    strategy: Box<dyn ValidationStrategy>,
34    /// データ型マーカー
35    _marker: std::marker::PhantomData<T>,
36}
37
38impl<T> Validator<T> {
39    /// 新しい Validator を作成
40    ///
41    /// # Arguments
42    ///
43    /// * `split_ratio` - Train/Test 分割比率 (0.0-1.0)
44    /// * `strategy` - 検証戦略
45    pub fn new(split_ratio: f64, strategy: Box<dyn ValidationStrategy>) -> Self {
46        assert!(
47            (0.0..=1.0).contains(&split_ratio),
48            "split_ratio must be between 0.0 and 1.0"
49        );
50        Self {
51            split_ratio,
52            strategy,
53            _marker: std::marker::PhantomData,
54        }
55    }
56
57    /// 8:2 分割で作成
58    pub fn with_80_20_split(strategy: Box<dyn ValidationStrategy>) -> Self {
59        Self::new(0.8, strategy)
60    }
61
62    /// 7:3 分割で作成
63    pub fn with_70_30_split(strategy: Box<dyn ValidationStrategy>) -> Self {
64        Self::new(0.7, strategy)
65    }
66
67    /// 検証を実行
68    ///
69    /// # Arguments
70    ///
71    /// * `data` - 全データ
72    /// * `baseline_fn` - Train データからベースライン成績を計算する関数
73    /// * `evaluate_fn` - Test データから検証成績を計算する関数
74    ///
75    /// # Returns
76    ///
77    /// 検証結果
78    pub fn validate<F, G>(&self, data: &[T], baseline_fn: F, evaluate_fn: G) -> ValidationResult
79    where
80        F: FnOnce(&[T]) -> f64,
81        G: FnOnce(&[T]) -> f64,
82    {
83        let (train, test) = self.split(data);
84
85        let baseline = baseline_fn(train);
86        let current = evaluate_fn(test);
87
88        self.strategy.evaluate(baseline, current, test.len())
89    }
90
91    /// 検証を実行(ベースラインを外部から指定)
92    ///
93    /// Bootstrap で既に計算済みのベースラインを使用する場合。
94    pub fn validate_with_baseline<F>(
95        &self,
96        data: &[T],
97        baseline: f64,
98        evaluate_fn: F,
99    ) -> ValidationResult
100    where
101        F: FnOnce(&[T]) -> f64,
102    {
103        let (_, test) = self.split(data);
104        let current = evaluate_fn(test);
105        self.strategy.evaluate(baseline, current, test.len())
106    }
107
108    /// データを Train/Test に分割
109    fn split<'a>(&self, data: &'a [T]) -> (&'a [T], &'a [T]) {
110        let split_idx = (data.len() as f64 * self.split_ratio) as usize;
111        let split_idx = split_idx.min(data.len());
112        (&data[..split_idx], &data[split_idx..])
113    }
114
115    /// 戦略名を取得
116    pub fn strategy_name(&self) -> &str {
117        self.strategy.name()
118    }
119
120    /// 分割比率を取得
121    pub fn split_ratio(&self) -> f64 {
122        self.split_ratio
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::super::strategy::{Absolute, Improvement, NoRegression};
    use super::*;

    /// Mean of a slice of 0.0/1.0 outcomes, i.e. the success rate.
    fn success_rate(samples: &[f64]) -> f64 {
        samples.iter().sum::<f64>() / samples.len() as f64
    }

    #[test]
    fn test_validator_split() {
        let data: Vec<i32> = (0..100).collect();
        let validator: Validator<i32> = Validator::new(0.8, Box::new(NoRegression::new()));

        let (train, test) = validator.split(&data);
        assert_eq!(train.len(), 80);
        assert_eq!(test.len(), 20);
    }

    #[test]
    fn test_validator_validate() {
        // 100 samples: 80 train, 20 test.
        let mut data: Vec<f64> = Vec::with_capacity(100);
        // Train: 80 samples, 64 successes (80%).
        data.extend((0..80).map(|i| if i < 64 { 1.0 } else { 0.0 }));
        // Test: 20 samples, 18 successes (90%).
        data.extend((0..20).map(|i| if i < 18 { 1.0 } else { 0.0 }));

        let validator = Validator::with_80_20_split(Box::new(NoRegression::new()));

        let result = validator.validate(&data, success_rate, success_rate);

        assert!(result.passed);
        assert!((result.baseline - 0.8).abs() < 0.01);
        assert!((result.current - 0.9).abs() < 0.01);
    }

    #[test]
    fn test_validator_with_baseline() {
        // Indices 0..85 succeed (1.0); indices 85..100 fail (0.0).
        let data: Vec<f64> = (0..100).map(|i| if i < 85 { 1.0 } else { 0.0 }).collect();

        let validator = Validator::with_80_20_split(Box::new(Improvement::ten_percent()));

        // Use an external baseline of 0.7.
        let result = validator.validate_with_baseline(&data, 0.7, success_rate);

        // The test split is indices 80..100: 80..85 succeed (5) and 85..100 fail (15),
        // so current = 5 / 20 = 0.25, which is below 0.7 * 1.1 = 0.77, so validation fails.
        assert!((result.current - 0.25).abs() < 1e-9);
        assert!(!result.passed);
    }

    #[test]
    fn test_validator_absolute_strategy() {
        // Indices 0..90 succeed (1.0); indices 90..100 fail (0.0).
        let data: Vec<f64> = (0..100).map(|i| if i < 90 { 1.0 } else { 0.0 }).collect();

        let validator = Validator::with_80_20_split(Box::new(Absolute::eighty_percent()));

        let result = validator.validate_with_baseline(&data, 0.5, success_rate);

        // The test split is indices 80..100: 10 successes / 20 = 0.5 < 0.8, so it fails.
        assert!((result.current - 0.5).abs() < 1e-9);
        assert!(!result.passed);
    }

    #[test]
    fn test_validator_empty_data() {
        let data: Vec<i32> = vec![];
        let validator: Validator<i32> = Validator::new(0.8, Box::new(NoRegression::new()));

        let (train, test) = validator.split(&data);
        assert!(train.is_empty());
        assert!(test.is_empty());
    }

    #[test]
    fn test_validator_full_train_split() {
        let data: Vec<i32> = (0..10).collect();
        let validator: Validator<i32> = Validator::new(1.0, Box::new(NoRegression::new()));

        let (train, test) = validator.split(&data);
        assert_eq!(train.len(), 10);
        assert!(test.is_empty());
    }

    #[test]
    #[should_panic(expected = "split_ratio must be between")]
    fn test_validator_invalid_ratio() {
        let _: Validator<i32> = Validator::new(1.5, Box::new(NoRegression::new()));
    }

    #[test]
    #[should_panic(expected = "split_ratio must be between")]
    fn test_validator_nan_ratio() {
        let _: Validator<i32> = Validator::new(f64::NAN, Box::new(NoRegression::new()));
    }
}