1use crate::error::{StatsError, StatsResult};
4use crate::regression::utils::*;
5use crate::regression::RegressionResults;
6use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
7use scirs2_core::numeric::Float;
8use scirs2_linalg::lstsq;
9use std::collections::HashSet;
10
/// Direction in which [`stepwise_regression`] searches the space of candidate predictors.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StepwiseDirection {
    /// Start from the model with no predictors and add one variable per step.
    Forward,
    /// Start from the model with every predictor and remove one variable per step.
    Backward,
    /// Start from the model with every predictor; each step first tries to (re-)add a
    /// variable and only tries a removal when no addition was accepted.
    Both,
}
21
/// Criterion used by [`stepwise_regression`] to rank candidate models.
///
/// Every criterion is evaluated so that a *lower* value means a better model:
/// AdjR2, F and T are negated before comparison.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StepwiseCriterion {
    /// Akaike information criterion: `n * ln(RSS / n) + 2k`, where `k` is the number
    /// of fitted coefficients.
    AIC,
    /// Bayesian information criterion: `n * ln(RSS / n) + k * ln(n)`.
    BIC,
    /// Adjusted R², negated so that smaller is better.
    AdjR2,
    /// Overall F statistic, negated so that smaller is better.
    F,
    /// Smallest absolute t statistic among the coefficients, negated so that smaller
    /// is better.
    T,
}
36
37pub struct StepwiseResults<F>
39where
40 F: Float + std::fmt::Debug + std::fmt::Display + 'static,
41{
42 pub final_model: RegressionResults<F>,
44
45 pub selected_indices: Vec<usize>,
47
48 pub sequence: Vec<(usize, bool)>, pub criteria_values: Vec<F>,
53}
54
55impl<F> StepwiseResults<F>
56where
57 F: Float + std::fmt::Debug + std::fmt::Display + 'static,
58{
59 pub fn summary(&self) -> String {
61 let mut summary = String::new();
62
63 summary.push_str("=== Stepwise Regression Results ===\n\n");
64
65 summary.push_str("Selected variables: ");
67 for (i, &idx) in self.selected_indices.iter().enumerate() {
68 if i > 0 {
69 summary.push_str(", ");
70 }
71 summary.push_str(&format!("X{}", idx));
72 }
73 summary.push_str("\n\n");
74
75 summary.push_str("Sequence of variable entry/exit:\n");
77 for (i, &(idx, is_entry)) in self.sequence.iter().enumerate() {
78 summary.push_str(&format!(
79 "Step {}: {} X{} (criterion value: {})\n",
80 i + 1,
81 if is_entry { "Added" } else { "Removed" },
82 idx,
83 self.criteria_values[i]
84 ));
85 }
86 summary.push('\n');
87
88 summary.push_str("Final Model:\n");
90 summary.push_str(&self.final_model.summary());
91
92 summary
93 }
94}
95
96#[allow(clippy::too_many_arguments)]
154#[allow(dead_code)]
155pub fn stepwise_regression<F>(
156 x: &ArrayView2<F>,
157 y: &ArrayView1<F>,
158 direction: StepwiseDirection,
159 criterion: StepwiseCriterion,
160 p_enter: Option<F>,
161 p_remove: Option<F>,
162 max_steps: Option<usize>,
163 include_intercept: bool,
164) -> StatsResult<StepwiseResults<F>>
165where
166 F: Float
167 + std::iter::Sum<F>
168 + std::ops::Div<Output = F>
169 + std::fmt::Debug
170 + std::fmt::Display
171 + 'static
172 + scirs2_core::numeric::NumAssign
173 + scirs2_core::numeric::One
174 + scirs2_core::ndarray::ScalarOperand
175 + Send
176 + Sync,
177{
178 if x.nrows() != y.len() {
180 return Err(StatsError::DimensionMismatch(format!(
181 "Input x has {} rows but y has length {}",
182 x.nrows(),
183 y.len()
184 )));
185 }
186
187 let n = x.nrows();
188 let p = x.ncols();
189
190 if n < 3 {
192 return Err(StatsError::InvalidArgument(
193 "At least 3 observations required for stepwise regression".to_string(),
194 ));
195 }
196
197 let p_enter =
199 p_enter.unwrap_or_else(|| F::from(0.05).expect("Failed to convert constant to float"));
200 let p_remove =
201 p_remove.unwrap_or_else(|| F::from(0.1).expect("Failed to convert constant to float"));
202
203 let max_steps = max_steps.unwrap_or(p * 2);
205
206 let mut selected_indices = match direction {
208 StepwiseDirection::Forward => HashSet::new(),
209 StepwiseDirection::Backward | StepwiseDirection::Both => {
210 let mut indices = HashSet::new();
212 for i in 0..p {
213 indices.insert(i);
214 }
215 indices
216 }
217 };
218
219 let mut sequence = Vec::new();
221 let mut criteria_values = Vec::new();
222
223 let mut current_x = match direction {
225 StepwiseDirection::Forward => {
226 if include_intercept {
228 Array2::<F>::ones((n, 1))
229 } else {
230 Array2::<F>::zeros((n, 0))
231 }
232 }
233 StepwiseDirection::Backward | StepwiseDirection::Both => {
234 if include_intercept {
236 let mut x_full = Array2::<F>::zeros((n, p + 1));
237 x_full.slice_mut(s![.., 0]).fill(F::one());
238 for i in 0..p {
239 x_full.slice_mut(s![.., i + 1]).assign(&x.slice(s![.., i]));
240 }
241 x_full
242 } else {
243 x.to_owned()
244 }
245 }
246 };
247
248 let mut step = 0;
250 let mut criterion_improved = true;
251
252 while step < max_steps && criterion_improved {
253 criterion_improved = false;
254
255 if direction == StepwiseDirection::Forward || direction == StepwiseDirection::Both {
257 let mut best_var = None;
259 let mut best_criterion = F::infinity();
260
261 for i in 0..p {
262 if selected_indices.contains(&i) {
264 continue;
265 }
266
267 let mut test_x = create_model_matrix(x, &selected_indices, include_intercept);
269 let var_col = x.slice(s![.., i]).to_owned();
270 test_x
271 .push_column(var_col.view())
272 .expect("Failed to push column");
273
274 if let Ok(model) = linear_regression(&test_x.view(), y) {
276 let crit_value =
277 calculate_criterion(&model, n, model.coefficients.len(), criterion);
278
279 if is_criterion_better(crit_value, best_criterion, criterion) {
280 best_var = Some(i);
281 best_criterion = crit_value;
282 }
283 }
284 }
285
286 if let Some(var_idx) = best_var {
288 let mut test_x = create_model_matrix(x, &selected_indices, include_intercept);
289 let var_col = x.slice(s![.., var_idx]).to_owned();
290 test_x
291 .push_column(var_col.view())
292 .expect("Failed to push column");
293
294 if let Ok(model) = linear_regression(&test_x.view(), y) {
295 let var_pos = test_x.ncols() - 1;
296 let _t_value = model.t_values[var_pos];
297 let p_value = model.p_values[var_pos];
298
299 if p_value <= p_enter {
300 selected_indices.insert(var_idx);
301 current_x = test_x;
302 sequence.push((var_idx, true));
303 criteria_values.push(best_criterion);
304 criterion_improved = true;
305 }
306 }
307 }
308 }
309
310 if (direction == StepwiseDirection::Backward || direction == StepwiseDirection::Both)
312 && !criterion_improved
313 && !selected_indices.is_empty()
314 {
315 let mut worst_var = None;
317 let mut worst_criterion = F::infinity();
318
319 for &var_idx in &selected_indices {
320 let mut test_indices = selected_indices.clone();
322 test_indices.remove(&var_idx);
323
324 let test_x = create_model_matrix(x, &test_indices, include_intercept);
325
326 if let Ok(model) = linear_regression(&test_x.view(), y) {
328 let crit_value =
329 calculate_criterion(&model, n, model.coefficients.len(), criterion);
330
331 if is_criterion_better(crit_value, worst_criterion, criterion) {
332 worst_var = Some(var_idx);
333 worst_criterion = crit_value;
334 }
335 }
336 }
337
338 if let Some(var_idx) = worst_var {
340 let var_pos = find_var_position(¤t_x, x, var_idx, include_intercept);
341
342 if let Ok(model) = linear_regression(¤t_x.view(), y) {
343 let p_value = model.p_values[var_pos];
344
345 if p_value > p_remove {
346 selected_indices.remove(&var_idx);
347 current_x = create_model_matrix(x, &selected_indices, include_intercept);
348 sequence.push((var_idx, false));
349 criteria_values.push(worst_criterion);
350 criterion_improved = true;
351 }
352 }
353 }
354 }
355
356 step += 1;
357 }
358
359 let final_model = linear_regression(¤t_x.view(), y)?;
361
362 let selected_indices = selected_indices.into_iter().collect();
364
365 Ok(StepwiseResults {
366 final_model,
367 selected_indices,
368 sequence,
369 criteria_values,
370 })
371}
372
373#[allow(dead_code)]
375fn create_model_matrix<F>(
376 x: &ArrayView2<F>,
377 indices: &HashSet<usize>,
378 include_intercept: bool,
379) -> Array2<F>
380where
381 F: Float + 'static + std::iter::Sum<F> + std::fmt::Display,
382{
383 let n = x.nrows();
384 let p = indices.len();
385
386 let cols = if include_intercept { p + 1 } else { p };
387 let mut x_model = Array2::<F>::zeros((n, cols));
388
389 if include_intercept {
390 x_model.slice_mut(s![.., 0]).fill(F::one());
391 }
392
393 let offset = if include_intercept { 1 } else { 0 };
394
395 for (i, &idx) in indices.iter().enumerate() {
396 x_model
397 .slice_mut(s![.., i + offset])
398 .assign(&x.slice(s![.., idx]));
399 }
400
401 x_model
402}
403
404#[allow(dead_code)]
405fn find_var_position<F>(
406 current_x: &Array2<F>,
407 x: &ArrayView2<F>,
408 var_idx: usize,
409 include_intercept: bool,
410) -> usize
411where
412 F: Float + 'static + std::iter::Sum<F> + std::fmt::Display,
413{
414 let offset = if include_intercept { 1 } else { 0 };
415
416 for i in offset..current_x.ncols() {
417 let col = current_x.slice(s![.., i]);
418 let x_col = x.slice(s![.., var_idx]);
419
420 if col
421 .iter()
422 .zip(x_col.iter())
423 .all(|(&a, &b)| (a - b).abs() < F::epsilon())
424 {
425 return i;
426 }
427 }
428
429 current_x.ncols() - 1
431}
432
433#[allow(dead_code)]
434fn calculate_criterion<F>(
435 model: &RegressionResults<F>,
436 n: usize,
437 p: usize,
438 criterion: StepwiseCriterion,
439) -> F
440where
441 F: Float + 'static + std::iter::Sum<F> + std::fmt::Debug + std::fmt::Display,
442{
443 match criterion {
444 StepwiseCriterion::AIC => {
445 let rss: F = model
446 .residuals
447 .iter()
448 .map(|&r| scirs2_core::numeric::Float::powi(r, 2))
449 .sum();
450 let n_f = F::from(n).expect("Failed to convert to float");
451 let k_f = F::from(p).expect("Failed to convert to float");
452 n_f * scirs2_core::numeric::Float::ln(rss / n_f)
453 + F::from(2.0).expect("Failed to convert constant to float") * k_f
454 }
455 StepwiseCriterion::BIC => {
456 let rss: F = model
457 .residuals
458 .iter()
459 .map(|&r| scirs2_core::numeric::Float::powi(r, 2))
460 .sum();
461 let n_f = F::from(n).expect("Failed to convert to float");
462 let k_f = F::from(p).expect("Failed to convert to float");
463 n_f * scirs2_core::numeric::Float::ln(rss / n_f)
464 + k_f * scirs2_core::numeric::Float::ln(n_f)
465 }
466 StepwiseCriterion::AdjR2 => {
467 -model.adj_r_squared }
469 StepwiseCriterion::F => {
470 -model.f_statistic }
472 StepwiseCriterion::T => {
473 let min_t = model
475 .t_values
476 .iter()
477 .map(|&t| t.abs())
478 .fold(F::infinity(), |a, b| a.min(b));
479 -min_t }
481 }
482}
483
484#[allow(dead_code)]
485fn is_criterion_better<F>(_new_value: F, oldvalue: F, criterion: StepwiseCriterion) -> bool
486where
487 F: Float + std::fmt::Display,
488{
489 match criterion {
490 StepwiseCriterion::AIC | StepwiseCriterion::BIC => _new_value < oldvalue,
492
493 StepwiseCriterion::AdjR2 | StepwiseCriterion::F | StepwiseCriterion::T => {
495 _new_value < oldvalue
496 }
497 }
498}
499
500#[allow(dead_code)]
502fn linear_regression<F>(x: &ArrayView2<F>, y: &ArrayView1<F>) -> StatsResult<RegressionResults<F>>
503where
504 F: Float
505 + std::iter::Sum<F>
506 + std::ops::Div<Output = F>
507 + std::fmt::Debug
508 + std::fmt::Display
509 + 'static
510 + scirs2_core::numeric::NumAssign
511 + scirs2_core::numeric::One
512 + scirs2_core::ndarray::ScalarOperand
513 + Send
514 + Sync,
515{
516 let n = x.nrows();
517 let p = x.ncols();
518
519 if n <= p {
521 return Err(StatsError::InvalidArgument(format!(
522 "Number of observations ({}) must be greater than number of predictors ({})",
523 n, p
524 )));
525 }
526
527 let coefficients = match lstsq(x, y, None) {
529 Ok(result) => result.x,
530 Err(e) => {
531 return Err(StatsError::ComputationError(format!(
532 "Least squares computation failed: {:?}",
533 e
534 )))
535 }
536 };
537
538 let fitted_values = x.dot(&coefficients);
540 let residuals = y.to_owned() - &fitted_values;
541
542 let df_model = p - 1; let df_residuals = n - p;
545
546 let (_y_mean, ss_total, ss_residual, ss_explained) =
548 calculate_sum_of_squares(y, &residuals.view());
549
550 let r_squared = ss_explained / ss_total;
552 let adj_r_squared = F::one()
553 - (F::one() - r_squared) * F::from(n - 1).expect("Failed to convert to float")
554 / F::from(df_residuals).expect("Failed to convert to float");
555
556 let mse = ss_residual / F::from(df_residuals).expect("Failed to convert to float");
558 let residual_std_error = scirs2_core::numeric::Float::sqrt(mse);
559
560 let std_errors = match calculate_std_errors(x, &residuals.view(), df_residuals) {
562 Ok(se) => se,
563 Err(_) => Array1::<F>::zeros(p),
564 };
565
566 let t_values = calculate_t_values(&coefficients, &std_errors);
568
569 let p_values = t_values.mapv(|t| {
572 let t_abs = scirs2_core::numeric::Float::abs(t);
573 let df_f = F::from(df_residuals).expect("Failed to convert to float");
574 F::from(2.0).expect("Failed to convert constant to float")
575 * (F::one() - t_abs / scirs2_core::numeric::Float::sqrt(df_f + t_abs * t_abs))
576 });
577
578 let mut conf_intervals = Array2::<F>::zeros((p, 2));
580 for i in 0..p {
581 let margin = std_errors[i] * F::from(1.96).expect("Failed to convert constant to float"); conf_intervals[[i, 0]] = coefficients[i] - margin;
583 conf_intervals[[i, 1]] = coefficients[i] + margin;
584 }
585
586 let f_statistic = if df_model > 0 && df_residuals > 0 {
588 (ss_explained / F::from(df_model).expect("Failed to convert to float"))
589 / (ss_residual / F::from(df_residuals).expect("Failed to convert to float"))
590 } else {
591 F::infinity()
592 };
593
594 let f_p_value = F::zero(); Ok(RegressionResults {
599 coefficients,
600 std_errors,
601 t_values,
602 p_values,
603 conf_intervals,
604 r_squared,
605 adj_r_squared,
606 f_statistic,
607 f_p_value,
608 residual_std_error,
609 df_residuals,
610 residuals,
611 fitted_values,
612 inlier_mask: vec![true; n], })
614}