sklears_dummy/validation/
validation_core.rs1use scirs2_core::ndarray::Array1;
2use scirs2_core::random::{thread_rng, RngCore, SeedableRng};
3use sklears_core::error::{Result, SklearsError};
4use sklears_core::types::Float;
5use std::cmp::Ordering;
6
7#[derive(Debug, Clone)]
9pub struct DummyValidationResult {
10 pub mean_score: Float,
12 pub std_score: Float,
14 pub fold_scores: Vec<Float>,
16 pub strategy: String,
18}
19
20impl DummyValidationResult {
21 pub fn new(
22 mean_score: Float,
23 std_score: Float,
24 fold_scores: Vec<Float>,
25 strategy: String,
26 ) -> Self {
27 Self {
28 mean_score,
29 std_score,
30 fold_scores,
31 strategy,
32 }
33 }
34
35 pub fn confidence_interval(&self, confidence_level: Float) -> (Float, Float) {
36 let n = self.fold_scores.len() as Float;
37 let sem = self.std_score / n.sqrt();
38
39 let t_value = match confidence_level {
41 0.90 => 1.645,
42 0.95 => 1.96,
43 0.99 => 2.576,
44 _ => 1.96, };
46
47 let margin = t_value * sem;
48 (self.mean_score - margin, self.mean_score + margin)
49 }
50
51 pub fn is_significantly_better_than(
52 &self,
53 other: &DummyValidationResult,
54 alpha: Float,
55 ) -> bool {
56 let pooled_std = ((self.std_score.powi(2) + other.std_score.powi(2)) / 2.0).sqrt();
58 let n1 = self.fold_scores.len() as Float;
59 let n2 = other.fold_scores.len() as Float;
60 let se_diff = pooled_std * ((1.0 / n1) + (1.0 / n2)).sqrt();
61
62 if se_diff == 0.0 {
63 return self.mean_score > other.mean_score;
64 }
65
66 let t_stat = (self.mean_score - other.mean_score) / se_diff;
67 let t_critical = match alpha {
68 0.01 => 2.576,
69 0.05 => 1.96,
70 0.10 => 1.645,
71 _ => 1.96,
72 };
73
74 t_stat > t_critical
75 }
76}
77
78#[derive(Debug, Clone)]
80pub struct ValidationConfig {
81 pub cv_folds: usize,
83 pub random_state: Option<u64>,
85 pub shuffle: bool,
87 pub stratify: bool,
89 pub scoring_metric: String,
91 pub bootstrap_samples: usize,
93 pub confidence_level: Float,
95}
96
97impl Default for ValidationConfig {
98 fn default() -> Self {
99 Self {
100 cv_folds: 5,
101 random_state: None,
102 shuffle: true,
103 stratify: false,
104 scoring_metric: "accuracy".to_string(),
105 bootstrap_samples: 1000,
106 confidence_level: 0.95,
107 }
108 }
109}
110
111impl ValidationConfig {
112 pub fn new() -> Self {
113 Self::default()
114 }
115
116 pub fn cv_folds(mut self, folds: usize) -> Self {
117 self.cv_folds = folds;
118 self
119 }
120
121 pub fn random_state(mut self, seed: u64) -> Self {
122 self.random_state = Some(seed);
123 self
124 }
125
126 pub fn shuffle(mut self, shuffle: bool) -> Self {
127 self.shuffle = shuffle;
128 self
129 }
130
131 pub fn stratify(mut self, stratify: bool) -> Self {
132 self.stratify = stratify;
133 self
134 }
135
136 pub fn scoring_metric(mut self, metric: String) -> Self {
137 self.scoring_metric = metric;
138 self
139 }
140
141 pub fn bootstrap_samples(mut self, samples: usize) -> Self {
142 self.bootstrap_samples = samples;
143 self
144 }
145
146 pub fn confidence_level(mut self, level: Float) -> Self {
147 self.confidence_level = level;
148 self
149 }
150}
151
152#[derive(Debug, Clone)]
154pub struct ComprehensiveValidationResult {
155 pub validation_result: DummyValidationResult,
157 pub fold_details: Vec<FoldResult>,
159 pub statistical_summary: StatisticalSummary,
161 pub config: ValidationConfig,
163}
164
165#[derive(Debug, Clone)]
167pub struct FoldResult {
168 pub fold_index: usize,
170 pub train_size: usize,
172 pub test_size: usize,
174 pub score: Float,
176 pub fit_time: Float,
178 pub predict_time: Float,
180}
181
182#[derive(Debug, Clone)]
184pub struct StatisticalSummary {
185 pub mean: Float,
187 pub std: Float,
189 pub min: Float,
191 pub max: Float,
193 pub median: Float,
195 pub q25: Float,
197 pub q75: Float,
199 pub skewness: Float,
201 pub kurtosis: Float,
203}
204
205impl StatisticalSummary {
206 pub fn from_scores(scores: &[Float]) -> Self {
207 if scores.is_empty() {
208 return Self::default();
209 }
210
211 let mut sorted_scores = scores.to_vec();
212 sorted_scores.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
213
214 let n = scores.len() as Float;
215 let mean = scores.iter().sum::<Float>() / n;
216 let variance = scores.iter().map(|&x| (x - mean).powi(2)).sum::<Float>() / n;
217 let std = variance.sqrt();
218 let min = sorted_scores[0];
219 let max = sorted_scores[sorted_scores.len() - 1];
220
221 let median = if sorted_scores.len() % 2 == 0 {
222 let mid = sorted_scores.len() / 2;
223 (sorted_scores[mid - 1] + sorted_scores[mid]) / 2.0
224 } else {
225 sorted_scores[sorted_scores.len() / 2]
226 };
227
228 let q25_idx = (sorted_scores.len() as Float * 0.25) as usize;
229 let q75_idx = (sorted_scores.len() as Float * 0.75) as usize;
230 let q25 = sorted_scores[q25_idx.min(sorted_scores.len() - 1)];
231 let q75 = sorted_scores[q75_idx.min(sorted_scores.len() - 1)];
232
233 let m3 = scores.iter().map(|&x| (x - mean).powi(3)).sum::<Float>() / n;
235 let m4 = scores.iter().map(|&x| (x - mean).powi(4)).sum::<Float>() / n;
236 let skewness = if std > 0.0 { m3 / std.powi(3) } else { 0.0 };
237 let kurtosis = if std > 0.0 {
238 m4 / std.powi(4) - 3.0
239 } else {
240 0.0
241 };
242
243 Self {
244 mean,
245 std,
246 min,
247 max,
248 median,
249 q25,
250 q75,
251 skewness,
252 kurtosis,
253 }
254 }
255}
256
257impl Default for StatisticalSummary {
258 fn default() -> Self {
259 Self {
260 mean: 0.0,
261 std: 0.0,
262 min: 0.0,
263 max: 0.0,
264 median: 0.0,
265 q25: 0.0,
266 q75: 0.0,
267 skewness: 0.0,
268 kurtosis: 0.0,
269 }
270 }
271}
272
/// Errors raised while validating estimators; each variant carries a detail message.
#[derive(Clone, Debug)]
pub enum ValidationError {
    /// Not enough data for the requested validation.
    InsufficientData(String),
    /// The fold configuration is invalid.
    InvalidFolds(String),
    /// Stratification failed.
    StratificationError(String),
    /// Scoring failed.
    ScoringError(String),
    /// A configuration value is invalid.
    ConfigurationError(String),
}

impl std::fmt::Display for ValidationError {
    /// Renders as `"<category>: <detail>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InsufficientData(detail) => write!(f, "Insufficient data: {}", detail),
            Self::InvalidFolds(detail) => write!(f, "Invalid fold configuration: {}", detail),
            Self::StratificationError(detail) => write!(f, "Stratification error: {}", detail),
            Self::ScoringError(detail) => write!(f, "Scoring error: {}", detail),
            Self::ConfigurationError(detail) => write!(f, "Configuration error: {}", detail),
        }
    }
}

impl std::error::Error for ValidationError {}
301
302pub fn validate_cv_params(n_samples: usize, cv_folds: usize) -> Result<()> {
304 if cv_folds < 2 {
305 return Err(SklearsError::InvalidInput(
306 "Cross-validation folds must be at least 2".to_string(),
307 ));
308 }
309
310 if n_samples < cv_folds {
311 return Err(SklearsError::InvalidInput(
312 "Number of samples must be at least equal to cv folds".to_string(),
313 ));
314 }
315
316 Ok(())
317}
318
319pub fn is_classification_task(y: &Array1<Float>) -> bool {
321 if y.is_empty() {
322 return false;
323 }
324
325 if y.iter().any(|&val| val.is_nan() || val.is_infinite()) {
327 return false;
328 }
329
330 let all_integers = y.iter().all(|&val| val.fract() == 0.0);
332 if !all_integers {
333 return false;
334 }
335
336 let mut unique_values: Vec<Float> = y.iter().copied().collect();
338 unique_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
339 unique_values.dedup();
340
341 unique_values.len() > 1 && unique_values.len() < 50
344}
345
346pub fn create_rng(random_state: Option<u64>) -> Box<dyn RngCore> {
348 match random_state {
349 Some(seed) => Box::new(scirs2_core::random::rngs::StdRng::seed_from_u64(seed)),
350 None => Box::new(thread_rng()),
351 }
352}