// scirs2_stats/regression/stepwise.rs

1//! Stepwise regression implementations
2
3use crate::error::{StatsError, StatsResult};
4use crate::regression::utils::*;
5use crate::regression::RegressionResults;
6use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
7use scirs2_core::numeric::Float;
8use scirs2_linalg::lstsq;
9use std::collections::HashSet;
10
/// Direction for stepwise regression
///
/// Derives `Eq` alongside `PartialEq` since equality between unit variants is
/// total (clippy: `derive_partial_eq_without_eq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StepwiseDirection {
    /// Forward selection (start with no variables and add)
    Forward,
    /// Backward elimination (start with all variables and remove)
    Backward,
    /// Bidirectional selection (both add and remove)
    Both,
}
21
22/// Criterion for selecting variables in stepwise regression
23#[derive(Debug, Clone, Copy)]
24pub enum StepwiseCriterion {
25    /// Akaike Information Criterion (AIC)
26    AIC,
27    /// Bayesian Information Criterion (BIC)
28    BIC,
29    /// Adjusted R-squared
30    AdjR2,
31    /// F-test significance
32    F,
33    /// t-test significance
34    T,
35}
36
/// Results from stepwise regression
pub struct StepwiseResults<F>
where
    F: Float + std::fmt::Debug + std::fmt::Display + 'static,
{
    /// The final regression model, fitted on the selected variables only
    pub final_model: RegressionResults<F>,

    /// Column indices (into the original design matrix) of the variables
    /// retained in the final model.
    /// NOTE(review): collected from a `HashSet`, so the order is unspecified.
    pub selected_indices: Vec<usize>,

    /// Variable entry/exit sequence, in the order the steps occurred
    pub sequence: Vec<(usize, bool)>, // (index, is_entry: true = added, false = removed)

    /// Criterion value recorded at each step; runs parallel to `sequence`
    pub criteria_values: Vec<F>,
}
54
55impl<F> StepwiseResults<F>
56where
57    F: Float + std::fmt::Debug + std::fmt::Display + 'static,
58{
59    /// Returns a summary of the stepwise regression process
60    pub fn summary(&self) -> String {
61        let mut summary = String::new();
62
63        summary.push_str("=== Stepwise Regression Results ===\n\n");
64
65        // Selected variables
66        summary.push_str("Selected variables: ");
67        for (i, &idx) in self.selected_indices.iter().enumerate() {
68            if i > 0 {
69                summary.push_str(", ");
70            }
71            summary.push_str(&format!("X{}", idx));
72        }
73        summary.push_str("\n\n");
74
75        // Sequence of entry/exit
76        summary.push_str("Sequence of variable entry/exit:\n");
77        for (i, &(idx, is_entry)) in self.sequence.iter().enumerate() {
78            summary.push_str(&format!(
79                "Step {}: {} X{} (criterion value: {})\n",
80                i + 1,
81                if is_entry { "Added" } else { "Removed" },
82                idx,
83                self.criteria_values[i]
84            ));
85        }
86        summary.push('\n');
87
88        // Final model summary
89        summary.push_str("Final Model:\n");
90        summary.push_str(&self.final_model.summary());
91
92        summary
93    }
94}
95
/// Perform stepwise regression using various criteria and directions.
///
/// # Arguments
///
/// * `x` - Independent variables (design matrix)
/// * `y` - Dependent variable
/// * `direction` - Direction for stepwise regression (Forward, Backward, or Both)
/// * `criterion` - Criterion for variable selection
/// * `p_enter` - p-value threshold for entering variables (for F or T criteria);
///   defaults to 0.05 when `None`
/// * `p_remove` - p-value threshold for removing variables (for F or T criteria);
///   defaults to 0.1 when `None`
/// * `max_steps` - Maximum number of steps to perform; defaults to `2 * x.ncols()` when `None`
/// * `include_intercept` - Whether to include an intercept term
///
/// # Returns
///
/// A StepwiseResults struct with the final model and selection details.
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::{array, Array2};
/// use scirs2_stats::{stepwise_regression, StepwiseDirection, StepwiseCriterion};
///
/// // Create a design matrix with 3 variables (independent)
/// let x = Array2::from_shape_vec((10, 3), vec![
///     1.0, 0.0, 0.0,
///     0.0, 1.0, 0.0,
///     0.0, 0.0, 1.0,
///     1.0, 1.0, 0.0,
///     1.0, 0.0, 1.0,
///     0.0, 1.0, 1.0,
///     1.0, 1.0, 1.0,
///     2.0, 0.0, 0.0,
///     0.0, 2.0, 0.0,
///     0.0, 0.0, 2.0,
/// ]).unwrap();
///
/// // Target values: y = 2.0*x0 + 3.0*x1 + small noise (clearly depends on first two variables)
/// let y = array![
///     2.0, 3.0, 0.1, 5.0, 2.1, 3.1, 5.1, 4.0, 6.0, 0.2
/// ];
///
/// // Perform forward stepwise regression using AIC with relaxed p-value threshold
/// let results = stepwise_regression(
///     &x.view(),
///     &y.view(),
///     StepwiseDirection::Forward,
///     StepwiseCriterion::AIC,
///     Some(0.5), // More relaxed entry threshold
///     Some(0.6), // More relaxed removal threshold
///     None,
///     true
/// ).unwrap();
///
/// // Check that the algorithm selected at least one variable
/// assert!(!results.selected_indices.is_empty());
/// ```
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub fn stepwise_regression<F>(
    x: &ArrayView2<F>,
    y: &ArrayView1<F>,
    direction: StepwiseDirection,
    criterion: StepwiseCriterion,
    p_enter: Option<F>,
    p_remove: Option<F>,
    max_steps: Option<usize>,
    include_intercept: bool,
) -> StatsResult<StepwiseResults<F>>
where
    F: Float
        + std::iter::Sum<F>
        + std::ops::Div<Output = F>
        + std::fmt::Debug
        + std::fmt::Display
        + 'static
        + scirs2_core::numeric::NumAssign
        + scirs2_core::numeric::One
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync,
{
    // Check input dimensions: one response value per design-matrix row.
    if x.nrows() != y.len() {
        return Err(StatsError::DimensionMismatch(format!(
            "Input x has {} rows but y has length {}",
            x.nrows(),
            y.len()
        )));
    }

    let n = x.nrows();
    let p = x.ncols();

    // Need at least 3 observations for meaningful regression
    if n < 3 {
        return Err(StatsError::InvalidArgument(
            "At least 3 observations required for stepwise regression".to_string(),
        ));
    }

    // Default significance thresholds for variable entry/removal.
    let p_enter = p_enter.unwrap_or_else(|| F::from(0.05).unwrap());
    let p_remove = p_remove.unwrap_or_else(|| F::from(0.1).unwrap());

    // Default step budget: twice the number of candidate variables.
    let max_steps = max_steps.unwrap_or(p * 2);

    // Variables currently in the model. Forward starts empty; Backward/Both
    // start with every candidate included.
    // NOTE(review): this is a HashSet, so iteration order below (backward
    // scan, final selected_indices vector) is unspecified.
    let mut selected_indices = match direction {
        StepwiseDirection::Forward => HashSet::new(),
        StepwiseDirection::Backward | StepwiseDirection::Both => {
            // Start with all variables
            let mut indices = HashSet::new();
            for i in 0..p {
                indices.insert(i);
            }
            indices
        }
    };

    // Track variable entry/exit sequence and criteria values
    let mut sequence = Vec::new();
    let mut criteria_values = Vec::new();

    // Design matrix of the current model, kept in sync with selected_indices.
    let mut current_x = match direction {
        StepwiseDirection::Forward => {
            // Start with no variables (just the intercept column if requested)
            if include_intercept {
                Array2::<F>::ones((n, 1))
            } else {
                Array2::<F>::zeros((n, 0))
            }
        }
        StepwiseDirection::Backward | StepwiseDirection::Both => {
            // Start with all variables (intercept column first if requested)
            if include_intercept {
                let mut x_full = Array2::<F>::zeros((n, p + 1));
                x_full.slice_mut(s![.., 0]).fill(F::one());
                for i in 0..p {
                    x_full.slice_mut(s![.., i + 1]).assign(&x.slice(s![.., i]));
                }
                x_full
            } else {
                x.to_owned()
            }
        }
    };

    // Main loop: alternate add/remove attempts until no move improves the
    // criterion or the step budget is exhausted.
    let mut step = 0;
    let mut criterion_improved = true;

    while step < max_steps && criterion_improved {
        criterion_improved = false;

        // Forward selection step (if direction is Forward or Both)
        if direction == StepwiseDirection::Forward || direction == StepwiseDirection::Both {
            // Find the candidate whose addition yields the best criterion.
            let mut best_var = None;
            let mut best_criterion = F::infinity();

            for i in 0..p {
                // Skip if already in model
                if selected_indices.contains(&i) {
                    continue;
                }

                // Add this variable to model temporarily, as the last column.
                let mut test_x = create_model_matrix(x, &selected_indices, include_intercept);
                let var_col = x.slice(s![.., i]).to_owned();
                test_x
                    .push_column(var_col.view())
                    .expect("Failed to push column");

                // Evaluate model; candidates whose fit fails are skipped.
                if let Ok(model) = linear_regression(&test_x.view(), y) {
                    let crit_value =
                        calculate_criterion(&model, n, model.coefficients.len(), criterion);

                    if is_criterion_better(crit_value, best_criterion, criterion) {
                        best_var = Some(i);
                        best_criterion = crit_value;
                    }
                }
            }

            // Add best variable if it meets entry criterion
            if let Some(var_idx) = best_var {
                // Refit with the chosen variable appended as the last column.
                let mut test_x = create_model_matrix(x, &selected_indices, include_intercept);
                let var_col = x.slice(s![.., var_idx]).to_owned();
                test_x
                    .push_column(var_col.view())
                    .expect("Failed to push column");

                if let Ok(model) = linear_regression(&test_x.view(), y) {
                    // The candidate is the last column, so its statistics are
                    // at the final coefficient position.
                    let var_pos = test_x.ncols() - 1;
                    let _t_value = model.t_values[var_pos];
                    let p_value = model.p_values[var_pos];

                    if p_value <= p_enter {
                        selected_indices.insert(var_idx);
                        current_x = test_x;
                        sequence.push((var_idx, true));
                        criteria_values.push(best_criterion);
                        criterion_improved = true;
                    }
                }
            }
        }

        // Backward elimination step (if direction is Backward or Both).
        // Only attempted when the forward step made no change this iteration.
        if (direction == StepwiseDirection::Backward || direction == StepwiseDirection::Both)
            && !criterion_improved
            && !selected_indices.is_empty()
        {
            // Find the variable whose removal yields the best criterion.
            let mut worst_var = None;
            let mut worst_criterion = F::infinity();

            for &var_idx in &selected_indices {
                // Create model without this variable
                let mut test_indices = selected_indices.clone();
                test_indices.remove(&var_idx);

                let test_x = create_model_matrix(x, &test_indices, include_intercept);

                // Evaluate model
                if let Ok(model) = linear_regression(&test_x.view(), y) {
                    let crit_value =
                        calculate_criterion(&model, n, model.coefficients.len(), criterion);

                    if is_criterion_better(crit_value, worst_criterion, criterion) {
                        worst_var = Some(var_idx);
                        worst_criterion = crit_value;
                    }
                }
            }

            // Remove worst variable if it meets removal criterion
            if let Some(var_idx) = worst_var {
                // Locate the variable's column in the current design matrix by
                // value matching (see find_var_position).
                let var_pos = find_var_position(&current_x, x, var_idx, include_intercept);

                if let Ok(model) = linear_regression(&current_x.view(), y) {
                    let p_value = model.p_values[var_pos];

                    // Drop the variable only if it is not significant enough
                    // to stay (p-value above the removal threshold).
                    if p_value > p_remove {
                        selected_indices.remove(&var_idx);
                        current_x = create_model_matrix(x, &selected_indices, include_intercept);
                        sequence.push((var_idx, false));
                        criteria_values.push(worst_criterion);
                        criterion_improved = true;
                    }
                }
            }
        }

        step += 1;
    }

    // Calculate final model
    let final_model = linear_regression(&current_x.view(), y)?;

    // Create results; ordering of selected_indices is unspecified (HashSet).
    let selected_indices = selected_indices.into_iter().collect();

    Ok(StepwiseResults {
        final_model,
        selected_indices,
        sequence,
        criteria_values,
    })
}
370
371// Helper functions
372#[allow(dead_code)]
373fn create_model_matrix<F>(
374    x: &ArrayView2<F>,
375    indices: &HashSet<usize>,
376    include_intercept: bool,
377) -> Array2<F>
378where
379    F: Float + 'static + std::iter::Sum<F> + std::fmt::Display,
380{
381    let n = x.nrows();
382    let p = indices.len();
383
384    let cols = if include_intercept { p + 1 } else { p };
385    let mut x_model = Array2::<F>::zeros((n, cols));
386
387    if include_intercept {
388        x_model.slice_mut(s![.., 0]).fill(F::one());
389    }
390
391    let offset = if include_intercept { 1 } else { 0 };
392
393    for (i, &idx) in indices.iter().enumerate() {
394        x_model
395            .slice_mut(s![.., i + offset])
396            .assign(&x.slice(s![.., idx]));
397    }
398
399    x_model
400}
401
/// Locates the column index in `current_x` that holds original variable
/// `var_idx` from the full design matrix `x`.
///
/// Columns are matched by value: a column counts as variable `var_idx` when
/// every element agrees with `x[.., var_idx]` to within `F::epsilon()`.
/// NOTE(review): if two candidate variables have identical values the first
/// matching column wins, and an unmatched variable silently falls back to the
/// last column — confirm callers can tolerate both cases.
#[allow(dead_code)]
fn find_var_position<F>(
    current_x: &Array2<F>,
    x: &ArrayView2<F>,
    var_idx: usize,
    include_intercept: bool,
) -> usize
where
    F: Float + 'static + std::iter::Sum<F> + std::fmt::Display,
{
    // Skip the intercept column (position 0) when present.
    let offset = if include_intercept { 1 } else { 0 };

    for i in offset..current_x.ncols() {
        let col = current_x.slice(s![.., i]);
        let x_col = x.slice(s![.., var_idx]);

        // Element-wise comparison with a small tolerance.
        if col
            .iter()
            .zip(x_col.iter())
            .all(|(&a, &b)| (a - b).abs() < F::epsilon())
        {
            return i;
        }
    }

    // Default to last column if not found
    current_x.ncols() - 1
}
430
431#[allow(dead_code)]
432fn calculate_criterion<F>(
433    model: &RegressionResults<F>,
434    n: usize,
435    p: usize,
436    criterion: StepwiseCriterion,
437) -> F
438where
439    F: Float + 'static + std::iter::Sum<F> + std::fmt::Debug + std::fmt::Display,
440{
441    match criterion {
442        StepwiseCriterion::AIC => {
443            let rss: F = model
444                .residuals
445                .iter()
446                .map(|&r| scirs2_core::numeric::Float::powi(r, 2))
447                .sum();
448            let n_f = F::from(n).unwrap();
449            let k_f = F::from(p).unwrap();
450            n_f * scirs2_core::numeric::Float::ln(rss / n_f) + F::from(2.0).unwrap() * k_f
451        }
452        StepwiseCriterion::BIC => {
453            let rss: F = model
454                .residuals
455                .iter()
456                .map(|&r| scirs2_core::numeric::Float::powi(r, 2))
457                .sum();
458            let n_f = F::from(n).unwrap();
459            let k_f = F::from(p).unwrap();
460            n_f * scirs2_core::numeric::Float::ln(rss / n_f)
461                + k_f * scirs2_core::numeric::Float::ln(n_f)
462        }
463        StepwiseCriterion::AdjR2 => {
464            -model.adj_r_squared // Negative because we want to maximize adj R^2
465        }
466        StepwiseCriterion::F => {
467            -model.f_statistic // Negative because we want to maximize F
468        }
469        StepwiseCriterion::T => {
470            // Use minimum absolute t-value
471            let min_t = model
472                .t_values
473                .iter()
474                .map(|&t| t.abs())
475                .fold(F::infinity(), |a, b| a.min(b));
476            -min_t // Negative because we want to maximize min |t|
477        }
478    }
479}
480
481#[allow(dead_code)]
482fn is_criterion_better<F>(_new_value: F, oldvalue: F, criterion: StepwiseCriterion) -> bool
483where
484    F: Float + std::fmt::Display,
485{
486    match criterion {
487        // For AIC and BIC, lower is better
488        StepwiseCriterion::AIC | StepwiseCriterion::BIC => _new_value < oldvalue,
489
490        // For Adj R^2, F, and T, we stored negative values, so lower is better
491        StepwiseCriterion::AdjR2 | StepwiseCriterion::F | StepwiseCriterion::T => {
492            _new_value < oldvalue
493        }
494    }
495}
496
497// Internal helper function for linear regression
498#[allow(dead_code)]
499fn linear_regression<F>(x: &ArrayView2<F>, y: &ArrayView1<F>) -> StatsResult<RegressionResults<F>>
500where
501    F: Float
502        + std::iter::Sum<F>
503        + std::ops::Div<Output = F>
504        + std::fmt::Debug
505        + std::fmt::Display
506        + 'static
507        + scirs2_core::numeric::NumAssign
508        + scirs2_core::numeric::One
509        + scirs2_core::ndarray::ScalarOperand
510        + Send
511        + Sync,
512{
513    let n = x.nrows();
514    let p = x.ncols();
515
516    // We need at least p+1 observations for inference
517    if n <= p {
518        return Err(StatsError::InvalidArgument(format!(
519            "Number of observations ({}) must be greater than number of predictors ({})",
520            n, p
521        )));
522    }
523
524    // Solve least squares problem
525    let coefficients = match lstsq(x, y, None) {
526        Ok(result) => result.x,
527        Err(e) => {
528            return Err(StatsError::ComputationError(format!(
529                "Least squares computation failed: {:?}",
530                e
531            )))
532        }
533    };
534
535    // Calculate fitted values and residuals
536    let fitted_values = x.dot(&coefficients);
537    let residuals = y.to_owned() - &fitted_values;
538
539    // Calculate degrees of freedom
540    let df_model = p - 1; // Subtract 1 if intercept included
541    let df_residuals = n - p;
542
543    // Calculate sum of squares
544    let (_y_mean, ss_total, ss_residual, ss_explained) =
545        calculate_sum_of_squares(y, &residuals.view());
546
547    // Calculate R-squared and adjusted R-squared
548    let r_squared = ss_explained / ss_total;
549    let adj_r_squared = F::one()
550        - (F::one() - r_squared) * F::from(n - 1).unwrap() / F::from(df_residuals).unwrap();
551
552    // Calculate mean squared error and residual standard error
553    let mse = ss_residual / F::from(df_residuals).unwrap();
554    let residual_std_error = scirs2_core::numeric::Float::sqrt(mse);
555
556    // Calculate standard errors for coefficients
557    let std_errors = match calculate_std_errors(x, &residuals.view(), df_residuals) {
558        Ok(se) => se,
559        Err(_) => Array1::<F>::zeros(p),
560    };
561
562    // Calculate t-values
563    let t_values = calculate_t_values(&coefficients, &std_errors);
564
565    // Calculate p-values (simplified)
566    // In a real implementation, we would use a proper t-distribution function
567    let p_values = t_values.mapv(|t| {
568        let t_abs = scirs2_core::numeric::Float::abs(t);
569        let df_f = F::from(df_residuals).unwrap();
570        F::from(2.0).unwrap()
571            * (F::one() - t_abs / scirs2_core::numeric::Float::sqrt(df_f + t_abs * t_abs))
572    });
573
574    // Calculate confidence intervals
575    let mut conf_intervals = Array2::<F>::zeros((p, 2));
576    for i in 0..p {
577        let margin = std_errors[i] * F::from(1.96).unwrap(); // Approximate 95% CI
578        conf_intervals[[i, 0]] = coefficients[i] - margin;
579        conf_intervals[[i, 1]] = coefficients[i] + margin;
580    }
581
582    // Calculate F-statistic
583    let f_statistic = if df_model > 0 && df_residuals > 0 {
584        (ss_explained / F::from(df_model).unwrap()) / (ss_residual / F::from(df_residuals).unwrap())
585    } else {
586        F::infinity()
587    };
588
589    // Calculate p-value for F-statistic (simplified)
590    let f_p_value = F::zero(); // In a real implementation, use F-distribution
591
592    // Create and return the results structure
593    Ok(RegressionResults {
594        coefficients,
595        std_errors,
596        t_values,
597        p_values,
598        conf_intervals,
599        r_squared,
600        adj_r_squared,
601        f_statistic,
602        f_p_value,
603        residual_std_error,
604        df_residuals,
605        residuals,
606        fitted_values,
607        inlier_mask: vec![true; n], // All points are inliers in stepwise regression
608    })
609}