1use scirs2_core::ndarray::{ArrayView1, ArrayView2};
5use scirs2_core::random::SeedableRng;
7use sklears_core::{
8 error::{Result as SklResult, SklearsError},
9 types::Float,
10};
11
/// Results of a model-complexity analysis produced by
/// [`analyze_model_complexity`].
#[derive(Debug, Clone)]
pub struct ComplexityAnalysisResult {
    /// Estimated effective degrees of freedom of the model.
    pub effective_degrees_freedom: Float,
    /// Weighted overall complexity score combining the other signals.
    pub complexity_score: Float,
    /// Akaike Information Criterion (lower is better).
    pub aic: Float,
    /// Bayesian Information Criterion (lower is better).
    pub bic: Float,
    /// Minimum Description Length criterion (lower is better).
    pub mdl: Float,
    /// Mean per-fold variance of absolute prediction errors.
    pub cv_complexity: Float,
    /// Mean absolute pairwise feature-interaction effect
    /// (0.0 when interaction analysis is disabled).
    pub interaction_complexity: Float,
    /// `effective_degrees_freedom` rounded to the nearest integer.
    pub n_effective_params: usize,
}
32
/// Configuration options for [`analyze_model_complexity`].
#[derive(Debug, Clone)]
pub struct ComplexityConfig {
    /// Number of contiguous folds used by the fold-variance ("CV") probe.
    pub cv_folds: usize,
    /// Whether to probe pairwise feature-interaction strength.
    pub include_interactions: bool,
    /// Multiplier applied to the final complexity score.
    pub complexity_penalty: Float,
    /// Optional RNG seed for reproducibility; `None` seeds from the
    /// thread-local RNG.
    pub random_state: Option<u64>,
}
45
impl Default for ComplexityConfig {
    /// Defaults: 5 folds, interaction analysis enabled, unit (no-op)
    /// penalty, nondeterministic seeding.
    fn default() -> Self {
        Self {
            cv_folds: 5,
            include_interactions: true,
            complexity_penalty: 1.0,
            random_state: None,
        }
    }
}
56
57pub fn analyze_model_complexity<F>(
100 predict_fn: &F,
101 X: &ArrayView2<Float>,
102 y: &ArrayView1<Float>,
103 n_params: usize,
104 config: &ComplexityConfig,
105) -> SklResult<ComplexityAnalysisResult>
106where
107 F: Fn(&ArrayView2<Float>) -> Vec<Float>,
108{
109 let (n_samples, n_features) = X.dim();
110
111 if n_samples != y.len() {
112 return Err(SklearsError::InvalidInput(
113 "X and y must have the same number of samples".to_string(),
114 ));
115 }
116
117 if n_samples == 0 || n_features == 0 {
118 return Err(SklearsError::InvalidInput(
119 "X and y must have non-zero samples and features".to_string(),
120 ));
121 }
122
123 let predictions = predict_fn(X);
125
126 let rss = compute_residual_sum_squares(y, &predictions);
128
129 let log_likelihood = compute_log_likelihood(y, &predictions, rss);
131
132 let aic = compute_aic(log_likelihood, n_params);
134 let bic = compute_bic(log_likelihood, n_params, n_samples);
135 let mdl = compute_mdl(log_likelihood, n_params, n_samples);
136
137 let effective_df = estimate_effective_degrees_freedom(predict_fn, X, y, config)?;
139
140 let cv_complexity = compute_cv_complexity(predict_fn, X, y, config)?;
142
143 let interaction_complexity = if config.include_interactions {
145 compute_interaction_complexity(predict_fn, X, y)?
146 } else {
147 0.0
148 };
149
150 let complexity_score = compute_overall_complexity_score(
152 effective_df,
153 cv_complexity,
154 interaction_complexity,
155 n_params,
156 n_features,
157 config.complexity_penalty,
158 );
159
160 let n_effective_params = effective_df.round() as usize;
161
162 Ok(ComplexityAnalysisResult {
163 effective_degrees_freedom: effective_df,
164 complexity_score,
165 aic,
166 bic,
167 mdl,
168 cv_complexity,
169 interaction_complexity,
170 n_effective_params,
171 })
172}
173
174fn compute_residual_sum_squares(y_true: &ArrayView1<Float>, y_pred: &[Float]) -> Float {
176 y_true
177 .iter()
178 .zip(y_pred.iter())
179 .map(|(&true_val, &pred_val)| (true_val - pred_val).powi(2))
180 .sum()
181}
182
183fn compute_log_likelihood(y_true: &ArrayView1<Float>, y_pred: &[Float], rss: Float) -> Float {
185 let n = y_true.len() as Float;
186 let sigma_squared = rss / n;
187
188 if sigma_squared <= 0.0 {
189 return 0.0; }
191
192 -0.5 * n * (2.0 * std::f64::consts::PI * sigma_squared).ln() - 0.5 * rss / sigma_squared
193}
194
195fn compute_aic(log_likelihood: Float, n_params: usize) -> Float {
197 -2.0 * log_likelihood + 2.0 * n_params as Float
198}
199
200fn compute_bic(log_likelihood: Float, n_params: usize, n_samples: usize) -> Float {
202 -2.0 * log_likelihood + (n_samples as Float).ln() * n_params as Float
203}
204
205fn compute_mdl(log_likelihood: Float, n_params: usize, n_samples: usize) -> Float {
207 -log_likelihood + 0.5 * n_params as Float * (n_samples as Float).ln()
209}
210
211fn estimate_effective_degrees_freedom<F>(
213 predict_fn: &F,
214 X: &ArrayView2<Float>,
215 y: &ArrayView1<Float>,
216 config: &ComplexityConfig,
217) -> SklResult<Float>
218where
219 F: Fn(&ArrayView2<Float>) -> Vec<Float>,
220{
221 use scirs2_core::random::{seq::SliceRandom, SeedableRng};
222
223 let n_samples = X.nrows();
224 let mut rng = match config.random_state {
225 Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
226 None => scirs2_core::random::rngs::StdRng::from_rng(&mut scirs2_core::random::thread_rng()),
227 };
228
229 let mut df_estimates = Vec::new();
230
231 for _ in 0..50 {
233 let mut indices: Vec<usize> = (0..n_samples).collect();
236 indices.shuffle(&mut rng);
237
238 let predictions = predict_fn(X);
240 let pred_variance = compute_prediction_variance(&predictions);
241 let noise_variance = estimate_noise_variance(y, &predictions);
242
243 let df_est = if noise_variance > 0.0 {
245 (pred_variance / noise_variance).min(n_samples as Float)
246 } else {
247 n_samples as Float
248 };
249
250 df_estimates.push(df_est);
251 }
252
253 df_estimates.sort_by(|a, b| a.partial_cmp(b).unwrap());
255 Ok(df_estimates[df_estimates.len() / 2])
256}
257
258fn compute_prediction_variance(predictions: &[Float]) -> Float {
260 let mean = predictions.iter().sum::<Float>() / predictions.len() as Float;
261 predictions
262 .iter()
263 .map(|&p| (p - mean).powi(2))
264 .sum::<Float>()
265 / predictions.len() as Float
266}
267
268fn estimate_noise_variance(y_true: &ArrayView1<Float>, y_pred: &[Float]) -> Float {
270 let residuals: Vec<Float> = y_true
271 .iter()
272 .zip(y_pred.iter())
273 .map(|(&true_val, &pred_val)| true_val - pred_val)
274 .collect();
275
276 let mean_residual = residuals.iter().sum::<Float>() / residuals.len() as Float;
277 residuals
278 .iter()
279 .map(|&r| (r - mean_residual).powi(2))
280 .sum::<Float>()
281 / residuals.len() as Float
282}
283
284fn compute_cv_complexity<F>(
286 predict_fn: &F,
287 X: &ArrayView2<Float>,
288 y: &ArrayView1<Float>,
289 config: &ComplexityConfig,
290) -> SklResult<Float>
291where
292 F: Fn(&ArrayView2<Float>) -> Vec<Float>,
293{
294 let n_samples = X.nrows();
295 let fold_size = n_samples / config.cv_folds;
296
297 if fold_size == 0 {
298 return Ok(1.0); }
300
301 let mut cv_scores = Vec::new();
302
303 for fold in 0..config.cv_folds {
305 let start_idx = fold * fold_size;
306 let end_idx = if fold == config.cv_folds - 1 {
307 n_samples
308 } else {
309 (fold + 1) * fold_size
310 };
311
312 let val_indices: Vec<usize> = (start_idx..end_idx).collect();
314
315 let predictions = predict_fn(X);
317 let val_predictions: Vec<Float> = val_indices.iter().map(|&idx| predictions[idx]).collect();
318
319 let val_y: Vec<Float> = val_indices.iter().map(|&idx| y[idx]).collect();
320
321 let score_variance = compute_score_variance(&val_y, &val_predictions);
323 cv_scores.push(score_variance);
324 }
325
326 Ok(cv_scores.iter().sum::<Float>() / cv_scores.len() as Float)
328}
329
330fn compute_score_variance(y_true: &[Float], y_pred: &[Float]) -> Float {
332 if y_true.is_empty() {
333 return 0.0;
334 }
335
336 let errors: Vec<Float> = y_true
337 .iter()
338 .zip(y_pred.iter())
339 .map(|(&true_val, &pred_val)| (true_val - pred_val).abs())
340 .collect();
341
342 let mean_error = errors.iter().sum::<Float>() / errors.len() as Float;
343 errors
344 .iter()
345 .map(|&e| (e - mean_error).powi(2))
346 .sum::<Float>()
347 / errors.len() as Float
348}
349
350fn compute_interaction_complexity<F>(
352 predict_fn: &F,
353 X: &ArrayView2<Float>,
354 y: &ArrayView1<Float>,
355) -> SklResult<Float>
356where
357 F: Fn(&ArrayView2<Float>) -> Vec<Float>,
358{
359 let n_features = X.ncols();
360
361 if n_features < 2 {
362 return Ok(0.0); }
364
365 let mut interaction_strength = 0.0;
367 let baseline_predictions = predict_fn(X);
368
369 let max_pairs = 10.min(n_features * (n_features - 1) / 2);
371 let mut pair_count = 0;
372
373 for i in 0..n_features {
374 for j in (i + 1)..n_features {
375 if pair_count >= max_pairs {
376 break;
377 }
378
379 let interaction_effect =
381 compute_pairwise_interaction(predict_fn, X, &baseline_predictions, i, j);
382
383 interaction_strength += interaction_effect.abs();
384 pair_count += 1;
385 }
386 }
387
388 Ok(interaction_strength / pair_count as Float)
389}
390
391fn compute_pairwise_interaction<F>(
393 predict_fn: &F,
394 X: &ArrayView2<Float>,
395 baseline_predictions: &[Float],
396 feature_i: usize,
397 feature_j: usize,
398) -> Float
399where
400 F: Fn(&ArrayView2<Float>) -> Vec<Float>,
401{
402 let n_samples = X.nrows();
403
404 let mut X_perturbed = X.to_owned();
406
407 let perturbation = 0.1;
409
410 for sample_idx in 0..n_samples {
411 X_perturbed[[sample_idx, feature_i]] += perturbation;
412 X_perturbed[[sample_idx, feature_j]] += perturbation;
413 }
414
415 let perturbed_predictions = predict_fn(&X_perturbed.view());
416
417 let interaction_effect: Float = perturbed_predictions
419 .iter()
420 .zip(baseline_predictions.iter())
421 .map(|(&perturbed, &baseline)| (perturbed - baseline).abs())
422 .sum::<Float>()
423 / n_samples as Float;
424
425 interaction_effect
426}
427
428fn compute_overall_complexity_score(
430 effective_df: Float,
431 cv_complexity: Float,
432 interaction_complexity: Float,
433 n_params: usize,
434 n_features: usize,
435 penalty: Float,
436) -> Float {
437 let df_component = effective_df / n_features as Float;
439 let cv_component = cv_complexity;
440 let interaction_component = interaction_complexity;
441 let param_component = n_params as Float / n_features as Float;
442
443 let complexity = 0.3 * df_component
445 + 0.3 * cv_component
446 + 0.2 * interaction_component
447 + 0.2 * param_component;
448
449 complexity * penalty
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    // `ArrayView1` was imported but never named in the tests (unused-import
    // warning); only `array` and `ArrayView2` are needed.
    use scirs2_core::ndarray::{array, ArrayView2};

    /// End-to-end smoke test: a linear row-sum model over a tiny matrix
    /// should yield positive, finite complexity diagnostics.
    #[test]
    #[allow(non_snake_case)]
    fn test_complexity_analysis() {
        let predict_fn = |x: &ArrayView2<Float>| -> Vec<Float> {
            x.rows().into_iter().map(|row| row.iter().sum()).collect()
        };

        let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
        let y = array![3.0, 7.0, 11.0, 15.0];
        let n_params = 3;

        let result = analyze_model_complexity(
            &predict_fn,
            &X.view(),
            &y.view(),
            n_params,
            &ComplexityConfig::default(),
        )
        .unwrap();

        assert!(result.complexity_score > 0.0);
        assert!(result.effective_degrees_freedom > 0.0);
        assert!(!result.aic.is_infinite());
        assert!(!result.bic.is_infinite());
        assert!(!result.mdl.is_infinite());
    }

    /// Spot-check the closed-form information criteria.
    #[test]
    fn test_information_criteria() {
        let log_likelihood = -10.0;
        let n_params = 3;
        let n_samples = 100;

        let aic = compute_aic(log_likelihood, n_params);
        let bic = compute_bic(log_likelihood, n_params, n_samples);
        let mdl = compute_mdl(log_likelihood, n_params, n_samples);

        // AIC = -2·(-10) + 2·3 = 26.
        assert_eq!(aic, 26.0);
        // ln(100) > 2, so BIC penalizes parameters harder than AIC here.
        assert!(bic > aic);
        assert!(mdl > 0.0);
    }

    /// Mismatched or degenerate inputs must be rejected with an error.
    #[test]
    #[allow(non_snake_case)]
    fn test_complexity_analysis_errors() {
        let predict_fn = |x: &ArrayView2<Float>| -> Vec<Float> {
            x.rows().into_iter().map(|row| row.iter().sum()).collect()
        };

        // X has 2 samples but y has only 1.
        let X = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![3.0];
        let result = analyze_model_complexity(
            &predict_fn,
            &X.view(),
            &y.view(),
            2,
            &ComplexityConfig::default(),
        );
        assert!(result.is_err());

        // Zero-feature X with a zero-length y is also invalid.
        let X_empty = array![[], []];
        let y_empty = array![];
        let result = analyze_model_complexity(
            &predict_fn,
            &X_empty.view(),
            &y_empty.view(),
            2,
            &ComplexityConfig::default(),
        );
        assert!(result.is_err());
    }
}