// scirs2_stats/regression/stepwise.rs
1//! Stepwise regression implementations
2
3use crate::error::{StatsError, StatsResult};
4use crate::regression::utils::*;
5use crate::regression::RegressionResults;
6use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
7use scirs2_core::numeric::Float;
8use scirs2_linalg::lstsq;
9use std::collections::HashSet;
10
/// Direction in which stepwise regression searches the model space.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum StepwiseDirection {
    /// Forward selection: start with no variables and add them one at a time.
    Forward,
    /// Backward elimination: start with all variables and remove them one at a time.
    Backward,
    /// Bidirectional selection: alternate between adding and removing variables.
    Both,
}
21
/// Criterion for selecting variables in stepwise regression.
///
/// Internally all criteria are evaluated so that a lower value is better:
/// AIC and BIC are minimized directly, while the AdjR2, F, and T criteria
/// are negated before comparison (see `calculate_criterion`).
#[derive(Debug, Clone, Copy)]
pub enum StepwiseCriterion {
    /// Akaike Information Criterion (AIC); lower is better.
    AIC,
    /// Bayesian Information Criterion (BIC); lower is better.
    BIC,
    /// Adjusted R-squared; higher is better (negated internally).
    AdjR2,
    /// F-test statistic; higher is better (negated internally).
    F,
    /// Minimum absolute t-value across coefficients; higher is better (negated internally).
    T,
}
36
/// Results from stepwise regression.
pub struct StepwiseResults<F>
where
    F: Float + std::fmt::Debug + std::fmt::Display + 'static,
{
    /// The final regression model fitted on the selected variables.
    pub final_model: RegressionResults<F>,

    /// Indices (columns of the original design matrix) of selected variables.
    pub selected_indices: Vec<usize>,

    /// Variable entry/exit sequence, in step order.
    /// Each element is `(column index, is_entry)`: `true` means the variable
    /// was added at that step, `false` means it was removed.
    pub sequence: Vec<(usize, bool)>, // (index, is_entry)

    /// Criterion value recorded at each step.
    /// Kept parallel to `sequence`: element `i` is the criterion value for step `i`.
    pub criteria_values: Vec<F>,
}
54
55impl<F> StepwiseResults<F>
56where
57    F: Float + std::fmt::Debug + std::fmt::Display + 'static,
58{
59    /// Returns a summary of the stepwise regression process
60    pub fn summary(&self) -> String {
61        let mut summary = String::new();
62
63        summary.push_str("=== Stepwise Regression Results ===\n\n");
64
65        // Selected variables
66        summary.push_str("Selected variables: ");
67        for (i, &idx) in self.selected_indices.iter().enumerate() {
68            if i > 0 {
69                summary.push_str(", ");
70            }
71            summary.push_str(&format!("X{}", idx));
72        }
73        summary.push_str("\n\n");
74
75        // Sequence of entry/exit
76        summary.push_str("Sequence of variable entry/exit:\n");
77        for (i, &(idx, is_entry)) in self.sequence.iter().enumerate() {
78            summary.push_str(&format!(
79                "Step {}: {} X{} (criterion value: {})\n",
80                i + 1,
81                if is_entry { "Added" } else { "Removed" },
82                idx,
83                self.criteria_values[i]
84            ));
85        }
86        summary.push('\n');
87
88        // Final model summary
89        summary.push_str("Final Model:\n");
90        summary.push_str(&self.final_model.summary());
91
92        summary
93    }
94}
95
/// Perform stepwise regression using various criteria and directions.
///
/// Starting from an empty model (Forward) or the full model (Backward/Both),
/// variables are added and/or removed one at a time until the criterion stops
/// improving or `max_steps` is reached.
///
/// # Arguments
///
/// * `x` - Independent variables (design matrix), one column per candidate variable
/// * `y` - Dependent variable
/// * `direction` - Direction for stepwise regression (Forward, Backward, or Both)
/// * `criterion` - Criterion for variable selection
/// * `p_enter` - p-value threshold for entering variables (for F or T criteria); defaults to 0.05
/// * `p_remove` - p-value threshold for removing variables (for F or T criteria); defaults to 0.1
/// * `max_steps` - Maximum number of steps to perform; defaults to twice the number of variables
/// * `include_intercept` - Whether to include an intercept term
///
/// # Returns
///
/// A StepwiseResults struct with the final model and selection details.
///
/// # Errors
///
/// * `StatsError::DimensionMismatch` if `x` and `y` disagree on the number of observations.
/// * `StatsError::InvalidArgument` if fewer than 3 observations are provided.
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::{array, Array2};
/// use scirs2_stats::{stepwise_regression, StepwiseDirection, StepwiseCriterion};
///
/// // Create a design matrix with 3 variables (independent)
/// let x = Array2::from_shape_vec((10, 3), vec![
///     1.0, 0.0, 0.0,
///     0.0, 1.0, 0.0,
///     0.0, 0.0, 1.0,
///     1.0, 1.0, 0.0,
///     1.0, 0.0, 1.0,
///     0.0, 1.0, 1.0,
///     1.0, 1.0, 1.0,
///     2.0, 0.0, 0.0,
///     0.0, 2.0, 0.0,
///     0.0, 0.0, 2.0,
/// ]).expect("Operation failed");
///
/// // Target values: y = 2.0*x0 + 3.0*x1 + small noise (clearly depends on first two variables)
/// let y = array![
///     2.0, 3.0, 0.1, 5.0, 2.1, 3.1, 5.1, 4.0, 6.0, 0.2
/// ];
///
/// // Perform forward stepwise regression using AIC with relaxed p-value threshold
/// let results = stepwise_regression(
///     &x.view(),
///     &y.view(),
///     StepwiseDirection::Forward,
///     StepwiseCriterion::AIC,
///     Some(0.5), // More relaxed entry threshold
///     Some(0.6), // More relaxed removal threshold
///     None,
///     true
/// ).expect("Operation failed");
///
/// // Check that the algorithm selected at least one variable
/// assert!(!results.selected_indices.is_empty());
/// ```
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub fn stepwise_regression<F>(
    x: &ArrayView2<F>,
    y: &ArrayView1<F>,
    direction: StepwiseDirection,
    criterion: StepwiseCriterion,
    p_enter: Option<F>,
    p_remove: Option<F>,
    max_steps: Option<usize>,
    include_intercept: bool,
) -> StatsResult<StepwiseResults<F>>
where
    F: Float
        + std::iter::Sum<F>
        + std::ops::Div<Output = F>
        + std::fmt::Debug
        + std::fmt::Display
        + 'static
        + scirs2_core::numeric::NumAssign
        + scirs2_core::numeric::One
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync,
{
    // Check input dimensions: x rows and y length must agree (one row per observation).
    if x.nrows() != y.len() {
        return Err(StatsError::DimensionMismatch(format!(
            "Input x has {} rows but y has length {}",
            x.nrows(),
            y.len()
        )));
    }

    let n = x.nrows();
    let p = x.ncols();

    // Need at least 3 observations for meaningful regression
    if n < 3 {
        return Err(StatsError::InvalidArgument(
            "At least 3 observations required for stepwise regression".to_string(),
        ));
    }

    // Default thresholds for entry/removal (0.05 to enter, 0.1 to remove).
    let p_enter =
        p_enter.unwrap_or_else(|| F::from(0.05).expect("Failed to convert constant to float"));
    let p_remove =
        p_remove.unwrap_or_else(|| F::from(0.1).expect("Failed to convert constant to float"));

    // Default maximum steps: twice the number of candidate variables.
    let max_steps = max_steps.unwrap_or(p * 2);

    // Track selected variables.
    // NOTE(review): `create_model_matrix` iterates this HashSet directly, so the
    // column order of rebuilt model matrices is unspecified; criterion values do
    // not depend on column order, but coefficient order in the final model may
    // vary between runs — confirm whether deterministic ordering is required.
    let mut selected_indices = match direction {
        StepwiseDirection::Forward => HashSet::new(),
        StepwiseDirection::Backward | StepwiseDirection::Both => {
            // Start with all variables
            let mut indices = HashSet::new();
            for i in 0..p {
                indices.insert(i);
            }
            indices
        }
    };

    // Track variable entry/exit sequence and criteria values.
    // These two vectors are kept parallel: one push into each per committed step.
    let mut sequence = Vec::new();
    let mut criteria_values = Vec::new();

    // Keep track of the current model's design matrix.
    let mut current_x = match direction {
        StepwiseDirection::Forward => {
            // Start with no variables (just an intercept column if requested)
            if include_intercept {
                Array2::<F>::ones((n, 1))
            } else {
                Array2::<F>::zeros((n, 0))
            }
        }
        StepwiseDirection::Backward | StepwiseDirection::Both => {
            // Start with all variables (intercept column first if requested)
            if include_intercept {
                let mut x_full = Array2::<F>::zeros((n, p + 1));
                x_full.slice_mut(s![.., 0]).fill(F::one());
                for i in 0..p {
                    x_full.slice_mut(s![.., i + 1]).assign(&x.slice(s![.., i]));
                }
                x_full
            } else {
                x.to_owned()
            }
        }
    };

    // Perform stepwise regression: loop until no step improves the criterion
    // or the step budget is exhausted.
    let mut step = 0;
    let mut criterion_improved = true;

    while step < max_steps && criterion_improved {
        criterion_improved = false;

        // Forward selection step (if direction is Forward or Both)
        if direction == StepwiseDirection::Forward || direction == StepwiseDirection::Both {
            // Scan every candidate not yet in the model and find the one whose
            // addition yields the best (lowest) criterion value.
            let mut best_var = None;
            let mut best_criterion = F::infinity();

            for i in 0..p {
                // Skip if already in model
                if selected_indices.contains(&i) {
                    continue;
                }

                // Add this variable to the model temporarily
                let mut test_x = create_model_matrix(x, &selected_indices, include_intercept);
                let var_col = x.slice(s![.., i]).to_owned();
                test_x
                    .push_column(var_col.view())
                    .expect("Failed to push column");

                // Evaluate model; candidates whose fit fails are silently skipped.
                if let Ok(model) = linear_regression(&test_x.view(), y) {
                    let crit_value =
                        calculate_criterion(&model, n, model.coefficients.len(), criterion);

                    if is_criterion_better(crit_value, best_criterion, criterion) {
                        best_var = Some(i);
                        best_criterion = crit_value;
                    }
                }
            }

            // Add best variable only if its t-test p-value meets the entry threshold.
            // NOTE(review): this refits the same candidate model a second time to
            // read its p-value — presumably acceptable cost; confirm.
            if let Some(var_idx) = best_var {
                let mut test_x = create_model_matrix(x, &selected_indices, include_intercept);
                let var_col = x.slice(s![.., var_idx]).to_owned();
                test_x
                    .push_column(var_col.view())
                    .expect("Failed to push column");

                if let Ok(model) = linear_regression(&test_x.view(), y) {
                    // The candidate is always the last (just-pushed) column.
                    let var_pos = test_x.ncols() - 1;
                    let _t_value = model.t_values[var_pos];
                    let p_value = model.p_values[var_pos];

                    if p_value <= p_enter {
                        selected_indices.insert(var_idx);
                        current_x = test_x;
                        sequence.push((var_idx, true));
                        criteria_values.push(best_criterion);
                        criterion_improved = true;
                    }
                }
            }
        }

        // Backward elimination step (if direction is Backward or Both).
        // Only attempted when the forward step made no change this iteration.
        if (direction == StepwiseDirection::Backward || direction == StepwiseDirection::Both)
            && !criterion_improved
            && !selected_indices.is_empty()
        {
            // Find the variable whose removal yields the best (lowest) criterion.
            let mut worst_var = None;
            let mut worst_criterion = F::infinity();

            for &var_idx in &selected_indices {
                // Create model without this variable
                let mut test_indices = selected_indices.clone();
                test_indices.remove(&var_idx);

                let test_x = create_model_matrix(x, &test_indices, include_intercept);

                // Evaluate model
                if let Ok(model) = linear_regression(&test_x.view(), y) {
                    let crit_value =
                        calculate_criterion(&model, n, model.coefficients.len(), criterion);

                    if is_criterion_better(crit_value, worst_criterion, criterion) {
                        worst_var = Some(var_idx);
                        worst_criterion = crit_value;
                    }
                }
            }

            // Remove the variable only if its p-value in the CURRENT model
            // exceeds the removal threshold.
            if let Some(var_idx) = worst_var {
                // Locate the variable's column in current_x by content comparison.
                let var_pos = find_var_position(&current_x, x, var_idx, include_intercept);

                if let Ok(model) = linear_regression(&current_x.view(), y) {
                    let p_value = model.p_values[var_pos];

                    if p_value > p_remove {
                        selected_indices.remove(&var_idx);
                        current_x = create_model_matrix(x, &selected_indices, include_intercept);
                        sequence.push((var_idx, false));
                        criteria_values.push(worst_criterion);
                        criterion_improved = true;
                    }
                }
            }
        }

        step += 1;
    }

    // Calculate final model on the surviving design matrix.
    let final_model = linear_regression(&current_x.view(), y)?;

    // Create results (HashSet -> Vec; order is unspecified, see note above).
    let selected_indices = selected_indices.into_iter().collect();

    Ok(StepwiseResults {
        final_model,
        selected_indices,
        sequence,
        criteria_values,
    })
}
372
373// Helper functions
374#[allow(dead_code)]
375fn create_model_matrix<F>(
376    x: &ArrayView2<F>,
377    indices: &HashSet<usize>,
378    include_intercept: bool,
379) -> Array2<F>
380where
381    F: Float + 'static + std::iter::Sum<F> + std::fmt::Display,
382{
383    let n = x.nrows();
384    let p = indices.len();
385
386    let cols = if include_intercept { p + 1 } else { p };
387    let mut x_model = Array2::<F>::zeros((n, cols));
388
389    if include_intercept {
390        x_model.slice_mut(s![.., 0]).fill(F::one());
391    }
392
393    let offset = if include_intercept { 1 } else { 0 };
394
395    for (i, &idx) in indices.iter().enumerate() {
396        x_model
397            .slice_mut(s![.., i + offset])
398            .assign(&x.slice(s![.., idx]));
399    }
400
401    x_model
402}
403
404#[allow(dead_code)]
405fn find_var_position<F>(
406    current_x: &Array2<F>,
407    x: &ArrayView2<F>,
408    var_idx: usize,
409    include_intercept: bool,
410) -> usize
411where
412    F: Float + 'static + std::iter::Sum<F> + std::fmt::Display,
413{
414    let offset = if include_intercept { 1 } else { 0 };
415
416    for i in offset..current_x.ncols() {
417        let col = current_x.slice(s![.., i]);
418        let x_col = x.slice(s![.., var_idx]);
419
420        if col
421            .iter()
422            .zip(x_col.iter())
423            .all(|(&a, &b)| (a - b).abs() < F::epsilon())
424        {
425            return i;
426        }
427    }
428
429    // Default to last column if not found
430    current_x.ncols() - 1
431}
432
433#[allow(dead_code)]
434fn calculate_criterion<F>(
435    model: &RegressionResults<F>,
436    n: usize,
437    p: usize,
438    criterion: StepwiseCriterion,
439) -> F
440where
441    F: Float + 'static + std::iter::Sum<F> + std::fmt::Debug + std::fmt::Display,
442{
443    match criterion {
444        StepwiseCriterion::AIC => {
445            let rss: F = model
446                .residuals
447                .iter()
448                .map(|&r| scirs2_core::numeric::Float::powi(r, 2))
449                .sum();
450            let n_f = F::from(n).expect("Failed to convert to float");
451            let k_f = F::from(p).expect("Failed to convert to float");
452            n_f * scirs2_core::numeric::Float::ln(rss / n_f)
453                + F::from(2.0).expect("Failed to convert constant to float") * k_f
454        }
455        StepwiseCriterion::BIC => {
456            let rss: F = model
457                .residuals
458                .iter()
459                .map(|&r| scirs2_core::numeric::Float::powi(r, 2))
460                .sum();
461            let n_f = F::from(n).expect("Failed to convert to float");
462            let k_f = F::from(p).expect("Failed to convert to float");
463            n_f * scirs2_core::numeric::Float::ln(rss / n_f)
464                + k_f * scirs2_core::numeric::Float::ln(n_f)
465        }
466        StepwiseCriterion::AdjR2 => {
467            -model.adj_r_squared // Negative because we want to maximize adj R^2
468        }
469        StepwiseCriterion::F => {
470            -model.f_statistic // Negative because we want to maximize F
471        }
472        StepwiseCriterion::T => {
473            // Use minimum absolute t-value
474            let min_t = model
475                .t_values
476                .iter()
477                .map(|&t| t.abs())
478                .fold(F::infinity(), |a, b| a.min(b));
479            -min_t // Negative because we want to maximize min |t|
480        }
481    }
482}
483
484#[allow(dead_code)]
485fn is_criterion_better<F>(_new_value: F, oldvalue: F, criterion: StepwiseCriterion) -> bool
486where
487    F: Float + std::fmt::Display,
488{
489    match criterion {
490        // For AIC and BIC, lower is better
491        StepwiseCriterion::AIC | StepwiseCriterion::BIC => _new_value < oldvalue,
492
493        // For Adj R^2, F, and T, we stored negative values, so lower is better
494        StepwiseCriterion::AdjR2 | StepwiseCriterion::F | StepwiseCriterion::T => {
495            _new_value < oldvalue
496        }
497    }
498}
499
// Internal helper: ordinary least squares with basic inference.
//
// Fits `y ~ x` via `lstsq` and fills a `RegressionResults` with coefficients,
// standard errors, t-values, approximate p-values, approximate confidence
// intervals, R², adjusted R², and the F-statistic. The p-values, confidence
// intervals, and F p-value are simplified approximations (see notes inline).
#[allow(dead_code)]
fn linear_regression<F>(x: &ArrayView2<F>, y: &ArrayView1<F>) -> StatsResult<RegressionResults<F>>
where
    F: Float
        + std::iter::Sum<F>
        + std::ops::Div<Output = F>
        + std::fmt::Debug
        + std::fmt::Display
        + 'static
        + scirs2_core::numeric::NumAssign
        + scirs2_core::numeric::One
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync,
{
    let n = x.nrows();
    let p = x.ncols();

    // We need at least p+1 observations for inference
    if n <= p {
        return Err(StatsError::InvalidArgument(format!(
            "Number of observations ({}) must be greater than number of predictors ({})",
            n, p
        )));
    }

    // Solve least squares problem
    let coefficients = match lstsq(x, y, None) {
        Ok(result) => result.x,
        Err(e) => {
            return Err(StatsError::ComputationError(format!(
                "Least squares computation failed: {:?}",
                e
            )))
        }
    };

    // Calculate fitted values and residuals
    let fitted_values = x.dot(&coefficients);
    let residuals = y.to_owned() - &fitted_values;

    // Calculate degrees of freedom.
    // NOTE(review): `p - 1` assumes the design matrix contains an intercept
    // column; it also underflows (usize) if p == 0. The `n <= p` guard above
    // does not exclude p == 0 — confirm callers never pass an empty matrix.
    let df_model = p - 1; // Subtract 1 if intercept included
    let df_residuals = n - p;

    // Calculate sum of squares (total, residual, explained) from the residuals.
    let (_y_mean, ss_total, ss_residual, ss_explained) =
        calculate_sum_of_squares(y, &residuals.view());

    // Calculate R-squared and adjusted R-squared
    let r_squared = ss_explained / ss_total;
    let adj_r_squared = F::one()
        - (F::one() - r_squared) * F::from(n - 1).expect("Failed to convert to float")
            / F::from(df_residuals).expect("Failed to convert to float");

    // Calculate mean squared error and residual standard error
    let mse = ss_residual / F::from(df_residuals).expect("Failed to convert to float");
    let residual_std_error = scirs2_core::numeric::Float::sqrt(mse);

    // Calculate standard errors for coefficients.
    // NOTE(review): failure is silently mapped to all-zero standard errors,
    // which makes downstream t-values infinite/NaN — confirm this fallback.
    let std_errors = match calculate_std_errors(x, &residuals.view(), df_residuals) {
        Ok(se) => se,
        Err(_) => Array1::<F>::zeros(p),
    };

    // Calculate t-values
    let t_values = calculate_t_values(&coefficients, &std_errors);

    // Calculate p-values (simplified closed-form approximation of the
    // two-sided t-test, NOT an actual t-distribution CDF).
    // In a real implementation, we would use a proper t-distribution function
    let p_values = t_values.mapv(|t| {
        let t_abs = scirs2_core::numeric::Float::abs(t);
        let df_f = F::from(df_residuals).expect("Failed to convert to float");
        F::from(2.0).expect("Failed to convert constant to float")
            * (F::one() - t_abs / scirs2_core::numeric::Float::sqrt(df_f + t_abs * t_abs))
    });

    // Calculate confidence intervals using the normal-approximation
    // critical value 1.96 rather than the exact t quantile.
    let mut conf_intervals = Array2::<F>::zeros((p, 2));
    for i in 0..p {
        let margin = std_errors[i] * F::from(1.96).expect("Failed to convert constant to float"); // Approximate 95% CI
        conf_intervals[[i, 0]] = coefficients[i] - margin;
        conf_intervals[[i, 1]] = coefficients[i] + margin;
    }

    // Calculate F-statistic (infinity when either degrees-of-freedom is zero).
    let f_statistic = if df_model > 0 && df_residuals > 0 {
        (ss_explained / F::from(df_model).expect("Failed to convert to float"))
            / (ss_residual / F::from(df_residuals).expect("Failed to convert to float"))
    } else {
        F::infinity()
    };

    // Calculate p-value for F-statistic (simplified placeholder: always zero).
    let f_p_value = F::zero(); // In a real implementation, use F-distribution

    // Create and return the results structure
    Ok(RegressionResults {
        coefficients,
        std_errors,
        t_values,
        p_values,
        conf_intervals,
        r_squared,
        adj_r_squared,
        f_statistic,
        f_p_value,
        residual_std_error,
        df_residuals,
        residuals,
        fitted_values,
        inlier_mask: vec![true; n], // All points are inliers in stepwise regression
    })
}
614}