use crate::error::{StatsError, StatsResult};
use crate::regression::utils::*;
use crate::regression::RegressionResults;
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::Float;
use scirs2_linalg::lstsq;
use std::collections::HashSet;

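/// Direction of the stepwise variable-selection procedure.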
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum StepwiseDirection {
    /// Start from an empty model and add variables one at a time.
    Forward,
    /// Start from the full model and remove variables one at a time.
    Backward,
    /// Allow variables to be both added and removed; the search starts from the full model.
    Both,
}

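/// Criterion used to compare candidate models during stepwise selection.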
#[derive(Debug, Clone, Copy)]
pub enum StepwiseCriterion {
    /// Akaike information criterion (smaller is better).
    AIC,
    /// Bayesian information criterion (smaller is better).
    BIC,
    /// Adjusted R-squared (larger is better).
    AdjR2,
    /// Overall F statistic of the model (larger is better).
    F,
    /// Minimum absolute t value across coefficients (larger is better).
    T,
}

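/// Results of a stepwise regression run.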
pub struct StepwiseResults<F>
where
    F: Float + std::fmt::Debug + std::fmt::Display + 'static,
{
    /// The final fitted regression model.
    pub final_model: RegressionResults<F>,

    /// Column indices of the predictors retained in the final model (in arbitrary order).
    pub selected_indices: Vec<usize>,

    /// Sequence of selection steps as `(variable index, added)` pairs, where
    /// `added` is `true` for an addition and `false` for a removal.
    pub sequence: Vec<(usize, bool)>,

    /// Criterion value recorded at each step in `sequence`.
    pub criteria_values: Vec<F>,
}

impl<F> StepwiseResults<F>
where
    F: Float + std::fmt::Debug + std::fmt::Display + 'static,
{
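    /// Returns a human-readable summary of the selected variables, the
    /// entry/exit sequence, and the final model.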
    pub fn summary(&self) -> String {
        let mut summary = String::new();

        summary.push_str("=== Stepwise Regression Results ===\n\n");

        summary.push_str("Selected variables: ");
        for (i, &idx) in self.selected_indices.iter().enumerate() {
            if i > 0 {
                summary.push_str(", ");
            }
            summary.push_str(&format!("X{}", idx));
        }
        summary.push_str("\n\n");

        summary.push_str("Sequence of variable entry/exit:\n");
        for (i, &(idx, is_entry)) in self.sequence.iter().enumerate() {
            summary.push_str(&format!(
                "Step {}: {} X{} (criterion value: {})\n",
                i + 1,
                if is_entry { "Added" } else { "Removed" },
                idx,
                self.criteria_values[i]
            ));
        }
        summary.push('\n');

        summary.push_str("Final Model:\n");
        summary.push_str(&self.final_model.summary());

        summary
    }
}

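/// Performs stepwise variable selection for linear regression.
///
/// Depending on `direction`, variables are added (forward), removed (backward),
/// or both, until no move improves the chosen `criterion`, a candidate fails the
/// `p_enter` / `p_remove` thresholds, or `max_steps` is reached.
///
/// # Arguments
///
/// * `x` - Predictor matrix with one column per candidate variable
/// * `y` - Response vector
/// * `direction` - Forward, backward, or bidirectional selection
/// * `criterion` - Model comparison criterion (AIC, BIC, adjusted R-squared, F, or t)
/// * `p_enter` - p-value threshold to add a variable (default 0.05)
/// * `p_remove` - p-value threshold to remove a variable (default 0.1)
/// * `max_steps` - Maximum number of steps (default `2 * x.ncols()`)
/// * `include_intercept` - Whether to include an intercept column
///
/// # Example
///
/// A minimal usage sketch (illustrative only; assumes an `array!` macro is in scope
/// and a caller that can propagate `StatsError`):
///
/// ```ignore
/// let x = array![[1.0, 2.0], [2.0, 1.0], [3.0, 4.0], [4.0, 3.0], [5.0, 6.0]];
/// let y = array![2.0, 3.0, 5.0, 7.0, 11.0];
/// let results = stepwise_regression(
///     &x.view(),
///     &y.view(),
///     StepwiseDirection::Forward,
///     StepwiseCriterion::AIC,
///     None, // p_enter: default 0.05
///     None, // p_remove: default 0.1
///     None, // max_steps: default 2 * number of predictors
///     true, // include an intercept
/// )?;
/// println!("{}", results.summary());
/// ```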
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub fn stepwise_regression<F>(
    x: &ArrayView2<F>,
    y: &ArrayView1<F>,
    direction: StepwiseDirection,
    criterion: StepwiseCriterion,
    p_enter: Option<F>,
    p_remove: Option<F>,
    max_steps: Option<usize>,
    include_intercept: bool,
) -> StatsResult<StepwiseResults<F>>
where
    F: Float
        + std::iter::Sum<F>
        + std::ops::Div<Output = F>
        + std::fmt::Debug
        + std::fmt::Display
        + 'static
        + scirs2_core::numeric::NumAssign
        + scirs2_core::numeric::One
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync,
{
    if x.nrows() != y.len() {
        return Err(StatsError::DimensionMismatch(format!(
            "Input x has {} rows but y has length {}",
            x.nrows(),
            y.len()
        )));
    }

    let n = x.nrows();
    let p = x.ncols();

    if n < 3 {
        return Err(StatsError::InvalidArgument(
            "At least 3 observations required for stepwise regression".to_string(),
        ));
    }

    // Default thresholds: a variable enters at p <= 0.05 and leaves at p > 0.10.
    let p_enter = p_enter.unwrap_or_else(|| F::from(0.05).unwrap());
    let p_remove = p_remove.unwrap_or_else(|| F::from(0.1).unwrap());

    let max_steps = max_steps.unwrap_or(p * 2);

    // Start with no variables selected (forward) or all variables selected
    // (backward / both).
    let mut selected_indices = match direction {
        StepwiseDirection::Forward => HashSet::new(),
        StepwiseDirection::Backward | StepwiseDirection::Both => {
            let mut indices = HashSet::new();
            for i in 0..p {
                indices.insert(i);
            }
            indices
        }
    };

    let mut sequence = Vec::new();
    let mut criteria_values = Vec::new();

    // Design matrix for the current model: intercept only (or empty) for forward
    // selection, the full matrix for backward / bidirectional selection.
    let mut current_x = match direction {
        StepwiseDirection::Forward => {
            if include_intercept {
                Array2::<F>::ones((n, 1))
            } else {
                Array2::<F>::zeros((n, 0))
            }
        }
        StepwiseDirection::Backward | StepwiseDirection::Both => {
            if include_intercept {
                let mut x_full = Array2::<F>::zeros((n, p + 1));
                x_full.slice_mut(s![.., 0]).fill(F::one());
                for i in 0..p {
                    x_full.slice_mut(s![.., i + 1]).assign(&x.slice(s![.., i]));
                }
                x_full
            } else {
                x.to_owned()
            }
        }
    };

    let mut step = 0;
    let mut criterion_improved = true;

    while step < max_steps && criterion_improved {
        criterion_improved = false;

        // Forward step: evaluate each candidate not yet in the model and keep the
        // one with the best criterion value.
        if direction == StepwiseDirection::Forward || direction == StepwiseDirection::Both {
            let mut best_var = None;
            let mut best_criterion = F::infinity();

            for i in 0..p {
                if selected_indices.contains(&i) {
                    continue;
                }

                let mut test_x = create_model_matrix(x, &selected_indices, include_intercept);
                let var_col = x.slice(s![.., i]).to_owned();
                test_x
                    .push_column(var_col.view())
                    .expect("Failed to push column");

                if let Ok(model) = linear_regression(&test_x.view(), y) {
                    let crit_value =
                        calculate_criterion(&model, n, model.coefficients.len(), criterion);

                    if is_criterion_better(crit_value, best_criterion, criterion) {
                        best_var = Some(i);
                        best_criterion = crit_value;
                    }
                }
            }

            // Add the best candidate only if its coefficient is significant at `p_enter`.
            if let Some(var_idx) = best_var {
                let mut test_x = create_model_matrix(x, &selected_indices, include_intercept);
                let var_col = x.slice(s![.., var_idx]).to_owned();
                test_x
                    .push_column(var_col.view())
                    .expect("Failed to push column");

                if let Ok(model) = linear_regression(&test_x.view(), y) {
                    let var_pos = test_x.ncols() - 1;
                    let _t_value = model.t_values[var_pos];
                    let p_value = model.p_values[var_pos];

                    if p_value <= p_enter {
                        selected_indices.insert(var_idx);
                        current_x = test_x;
                        sequence.push((var_idx, true));
                        criteria_values.push(best_criterion);
                        criterion_improved = true;
                    }
                }
            }
        }

        // Backward step: only attempted if the forward step made no change. Evaluate
        // removing each selected variable and keep the removal with the best criterion.
        if (direction == StepwiseDirection::Backward || direction == StepwiseDirection::Both)
            && !criterion_improved
            && !selected_indices.is_empty()
        {
            let mut worst_var = None;
            let mut worst_criterion = F::infinity();

            for &var_idx in &selected_indices {
                let mut test_indices = selected_indices.clone();
                test_indices.remove(&var_idx);

                let test_x = create_model_matrix(x, &test_indices, include_intercept);

                if let Ok(model) = linear_regression(&test_x.view(), y) {
                    let crit_value =
                        calculate_criterion(&model, n, model.coefficients.len(), criterion);

                    if is_criterion_better(crit_value, worst_criterion, criterion) {
                        worst_var = Some(var_idx);
                        worst_criterion = crit_value;
                    }
                }
            }

            // Remove the candidate only if its coefficient is not significant at `p_remove`.
            if let Some(var_idx) = worst_var {
                let var_pos = find_var_position(&current_x, x, var_idx, include_intercept);

                if let Ok(model) = linear_regression(&current_x.view(), y) {
                    let p_value = model.p_values[var_pos];

                    if p_value > p_remove {
                        selected_indices.remove(&var_idx);
                        current_x = create_model_matrix(x, &selected_indices, include_intercept);
                        sequence.push((var_idx, false));
                        criteria_values.push(worst_criterion);
                        criterion_improved = true;
                    }
                }
            }
        }

        step += 1;
    }

    let final_model = linear_regression(&current_x.view(), y)?;

    let selected_indices = selected_indices.into_iter().collect();

    Ok(StepwiseResults {
        final_model,
        selected_indices,
        sequence,
        criteria_values,
    })
}

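/// Builds a design matrix from the columns of `x` listed in `indices`, with an
/// intercept column of ones prepended when `include_intercept` is true.
///
/// Note that the `HashSet` iteration order determines the column order of the
/// resulting matrix.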
#[allow(dead_code)]
fn create_model_matrix<F>(
    x: &ArrayView2<F>,
    indices: &HashSet<usize>,
    include_intercept: bool,
) -> Array2<F>
where
    F: Float + 'static + std::iter::Sum<F> + std::fmt::Display,
{
    let n = x.nrows();
    let p = indices.len();

    let cols = if include_intercept { p + 1 } else { p };
    let mut x_model = Array2::<F>::zeros((n, cols));

    if include_intercept {
        x_model.slice_mut(s![.., 0]).fill(F::one());
    }

    let offset = if include_intercept { 1 } else { 0 };

    for (i, &idx) in indices.iter().enumerate() {
        x_model
            .slice_mut(s![.., i + offset])
            .assign(&x.slice(s![.., idx]));
    }

    x_model
}

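/// Finds the column of `current_x` that matches column `var_idx` of `x` by
/// element-wise comparison, falling back to the last column if no match is found.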
#[allow(dead_code)]
fn find_var_position<F>(
    current_x: &Array2<F>,
    x: &ArrayView2<F>,
    var_idx: usize,
    include_intercept: bool,
) -> usize
where
    F: Float + 'static + std::iter::Sum<F> + std::fmt::Display,
{
    let offset = if include_intercept { 1 } else { 0 };

    for i in offset..current_x.ncols() {
        let col = current_x.slice(s![.., i]);
        let x_col = x.slice(s![.., var_idx]);

        if col
            .iter()
            .zip(x_col.iter())
            .all(|(&a, &b)| (a - b).abs() < F::epsilon())
        {
            return i;
        }
    }

    current_x.ncols() - 1
}

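/// Computes the selection criterion for a fitted model with `p` coefficients on
/// `n` observations.
///
/// Every criterion is returned on a "smaller is better" scale: AIC and BIC are
/// used directly, while adjusted R-squared, the F statistic, and the minimum
/// absolute t value are negated.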
#[allow(dead_code)]
fn calculate_criterion<F>(
    model: &RegressionResults<F>,
    n: usize,
    p: usize,
    criterion: StepwiseCriterion,
) -> F
where
    F: Float + 'static + std::iter::Sum<F> + std::fmt::Debug + std::fmt::Display,
{
    match criterion {
        StepwiseCriterion::AIC => {
            let rss: F = model
                .residuals
                .iter()
                .map(|&r| scirs2_core::numeric::Float::powi(r, 2))
                .sum();
            let n_f = F::from(n).unwrap();
            let k_f = F::from(p).unwrap();
            // AIC = n * ln(RSS / n) + 2k (up to an additive constant).
            n_f * scirs2_core::numeric::Float::ln(rss / n_f) + F::from(2.0).unwrap() * k_f
        }
        StepwiseCriterion::BIC => {
            let rss: F = model
                .residuals
                .iter()
                .map(|&r| scirs2_core::numeric::Float::powi(r, 2))
                .sum();
            let n_f = F::from(n).unwrap();
            let k_f = F::from(p).unwrap();
            // BIC = n * ln(RSS / n) + k * ln(n) (up to an additive constant).
            n_f * scirs2_core::numeric::Float::ln(rss / n_f)
                + k_f * scirs2_core::numeric::Float::ln(n_f)
        }
        // The remaining criteria are negated so that smaller values are better,
        // matching the orientation of AIC/BIC.
        StepwiseCriterion::AdjR2 => -model.adj_r_squared,
        StepwiseCriterion::F => -model.f_statistic,
        StepwiseCriterion::T => {
            let min_t = model
                .t_values
                .iter()
                .map(|&t| t.abs())
                .fold(F::infinity(), |a, b| a.min(b));
            -min_t
        }
    }
}

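/// Returns `true` when `new_value` improves on `old_value` for the given criterion.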
#[allow(dead_code)]
fn is_criterion_better<F>(new_value: F, old_value: F, criterion: StepwiseCriterion) -> bool
where
    F: Float + std::fmt::Display,
{
    match criterion {
        // AIC and BIC are minimized directly.
        StepwiseCriterion::AIC | StepwiseCriterion::BIC => new_value < old_value,
        // Adjusted R-squared, F, and minimum |t| are negated in `calculate_criterion`,
        // so smaller values are better here as well.
        StepwiseCriterion::AdjR2 | StepwiseCriterion::F | StepwiseCriterion::T => {
            new_value < old_value
        }
    }
}

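/// Fits an ordinary least squares model of `y` on the columns of `x` (which must
/// already contain any intercept column) and assembles a `RegressionResults`.
///
/// The p-values and confidence intervals use normal-based approximations rather
/// than exact t distributions, and the F-statistic p-value is left at zero.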
#[allow(dead_code)]
fn linear_regression<F>(x: &ArrayView2<F>, y: &ArrayView1<F>) -> StatsResult<RegressionResults<F>>
where
    F: Float
        + std::iter::Sum<F>
        + std::ops::Div<Output = F>
        + std::fmt::Debug
        + std::fmt::Display
        + 'static
        + scirs2_core::numeric::NumAssign
        + scirs2_core::numeric::One
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync,
{
    let n = x.nrows();
    let p = x.ncols();

    if n <= p {
        return Err(StatsError::InvalidArgument(format!(
            "Number of observations ({}) must be greater than number of predictors ({})",
            n, p
        )));
    }

    let coefficients = match lstsq(x, y, None) {
        Ok(result) => result.x,
        Err(e) => {
            return Err(StatsError::ComputationError(format!(
                "Least squares computation failed: {:?}",
                e
            )))
        }
    };

    let fitted_values = x.dot(&coefficients);
    let residuals = y.to_owned() - &fitted_values;

    // Degrees of freedom; df_model assumes the design matrix includes an intercept column.
    let df_model = p - 1;
    let df_residuals = n - p;

    let (_y_mean, ss_total, ss_residual, ss_explained) =
        calculate_sum_of_squares(y, &residuals.view());

    let r_squared = ss_explained / ss_total;
    let adj_r_squared = F::one()
        - (F::one() - r_squared) * F::from(n - 1).unwrap() / F::from(df_residuals).unwrap();

    let mse = ss_residual / F::from(df_residuals).unwrap();
    let residual_std_error = scirs2_core::numeric::Float::sqrt(mse);

    let std_errors = match calculate_std_errors(x, &residuals.view(), df_residuals) {
        Ok(se) => se,
        Err(_) => Array1::<F>::zeros(p),
    };

    let t_values = calculate_t_values(&coefficients, &std_errors);

    // Two-sided p-values from a rough closed-form approximation based on
    // |t| / sqrt(df + t^2); this is not an exact Student's t CDF.
    let p_values = t_values.mapv(|t| {
        let t_abs = scirs2_core::numeric::Float::abs(t);
        let df_f = F::from(df_residuals).unwrap();
        F::from(2.0).unwrap()
            * (F::one() - t_abs / scirs2_core::numeric::Float::sqrt(df_f + t_abs * t_abs))
    });

    // Approximate 95% confidence intervals using the standard normal quantile 1.96.
    let mut conf_intervals = Array2::<F>::zeros((p, 2));
    for i in 0..p {
        let margin = std_errors[i] * F::from(1.96).unwrap();
        conf_intervals[[i, 0]] = coefficients[i] - margin;
        conf_intervals[[i, 1]] = coefficients[i] + margin;
    }

    let f_statistic = if df_model > 0 && df_residuals > 0 {
        (ss_explained / F::from(df_model).unwrap()) / (ss_residual / F::from(df_residuals).unwrap())
    } else {
        F::infinity()
    };

    // Placeholder: the p-value of the F statistic is not computed here.
    let f_p_value = F::zero();

    Ok(RegressionResults {
        coefficients,
        std_errors,
        t_values,
        p_values,
        conf_intervals,
        r_squared,
        adj_r_squared,
        f_statistic,
        f_p_value,
        residual_std_error,
        df_residuals,
        residuals,
        fitted_values,
        // OLS uses all observations, so every point is marked as an inlier.
        inlier_mask: vec![true; n],
    })
}