elastic_net

Function elastic_net 

pub fn elastic_net<F>(
    x: &ArrayView2<'_, F>,
    y: &ArrayView1<'_, F>,
    alpha: Option<F>,
    l1_ratio: Option<F>,
    fit_intercept: Option<bool>,
    normalize: Option<bool>,
    tol: Option<F>,
    max_iter: Option<usize>,
    conf_level: Option<F>,
) -> StatsResult<RegressionResults<F>>
where
    F: Float
        + Sum<F>
        + Div<Output = F>
        + Debug
        + Display
        + 'static
        + NumAssign
        + One
        + ScalarOperand
        + Send
        + Sync,

Perform elastic net regression (L1 + L2 regularization).

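Elastic net combines L1 and L2 penalties, offering a compromise between lasso and ridge regression. The L1 term can set some coefficients exactly to zero (variable selection). The L2 term shrinks coefficients smoothly and keeps the fit stable when predictors are strongly correlated. The l1_ratio argument controls the balance between the two.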
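A minimal sketch of the combined penalty, assuming the scikit-learn-style parameterization that the alpha/l1_ratio argument names suggest. This page does not state the exact scaling scirs2_stats uses, and elastic_net_penalty is a hypothetical helper, not part of the crate. It shows how l1_ratio = 1 reduces to a pure lasso penalty and l1_ratio = 0 to a pure ridge penalty.

```rust
/// Elastic net penalty in the scikit-learn-style parameterization (an assumption):
/// alpha * (l1_ratio * ||w||_1 + 0.5 * (1 - l1_ratio) * ||w||_2^2)
fn elastic_net_penalty(w: &[f64], alpha: f64, l1_ratio: f64) -> f64 {
    // L1 norm: sum of absolute coefficient values (the lasso part)
    let l1: f64 = w.iter().map(|c| c.abs()).sum();
    // Squared L2 norm: sum of squared coefficients (the ridge part)
    let l2_sq: f64 = w.iter().map(|c| c * c).sum();
    alpha * (l1_ratio * l1 + 0.5 * (1.0 - l1_ratio) * l2_sq)
}

fn main() {
    let w = [1.0, -2.0];
    // l1_ratio = 1: pure lasso, 1.0 * (1 + 2) = 3
    println!("lasso: {}", elastic_net_penalty(&w, 1.0, 1.0));
    // l1_ratio = 0: pure ridge, 0.5 * (1 + 4) = 2.5
    println!("ridge: {}", elastic_net_penalty(&w, 1.0, 0.0));
    // l1_ratio = 0.5: 0.5 * 3 + 0.25 * 5 = 2.75
    println!("mixed: {}", elastic_net_penalty(&w, 1.0, 0.5));
}
```

For any l1_ratio strictly between 0 and 1, both terms contribute, which is the compromise described above.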

§Arguments

  • x - Independent variables (design matrix)
  • y - Dependent variable
  • alpha - Total regularization strength (default: 1.0)
  • l1_ratio - Fraction of the total penalty assigned to the L1 term, between 0 and 1 (default: 0.5; 0 gives ridge, 1 gives lasso)
  • fit_intercept - Whether to fit an intercept term (default: true)
  • normalize - Whether to normalize the data before fitting (default: false)
  • tol - Convergence tolerance (default: 1e-4)
  • max_iter - Maximum number of iterations (default: 1000)
  • conf_level - Confidence level for confidence intervals (default: 0.95)
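Every argument after y is an Option, and passing None selects the default listed above. The sketch below illustrates that resolution with Option::unwrap_or, using f64 in place of the generic F. The Settings struct and resolve function are hypothetical illustrations of the documented defaults, not the crate's internals.

```rust
/// Hypothetical resolved settings; field names mirror the documented arguments.
#[derive(Debug, PartialEq)]
struct Settings {
    alpha: f64,
    l1_ratio: f64,
    fit_intercept: bool,
    normalize: bool,
    tol: f64,
    max_iter: usize,
    conf_level: f64,
}

/// Replace each `None` with the default listed in the argument documentation.
fn resolve(
    alpha: Option<f64>,
    l1_ratio: Option<f64>,
    fit_intercept: Option<bool>,
    normalize: Option<bool>,
    tol: Option<f64>,
    max_iter: Option<usize>,
    conf_level: Option<f64>,
) -> Settings {
    Settings {
        alpha: alpha.unwrap_or(1.0),
        l1_ratio: l1_ratio.unwrap_or(0.5),
        fit_intercept: fit_intercept.unwrap_or(true),
        normalize: normalize.unwrap_or(false),
        tol: tol.unwrap_or(1e-4),
        max_iter: max_iter.unwrap_or(1000),
        conf_level: conf_level.unwrap_or(0.95),
    }
}

fn main() {
    // Same pattern as the example call: only alpha and l1_ratio are given.
    let s = resolve(Some(0.1), Some(0.5), None, None, None, None, None);
    println!("{:?}", s);
}
```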

§Returns

A StatsResult containing a RegressionResults struct that describes the fitted model, including the estimated coefficients.

§Examples

use ndarray::{array, Array2};
use scirs2_stats::elastic_net;

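// Design matrix: 10 observations of 5 variables.
// Every column is an affine function of the row index, so the columns are
// perfectly collinear, a setting where the L2 term keeps the fit stable.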
let x = Array2::from_shape_vec((10, 5), vec![
    1.0, 2.0, 0.1, 0.2, 0.3,
    2.0, 3.0, 0.2, 0.3, 0.4,
    3.0, 4.0, 0.3, 0.4, 0.5,
    4.0, 5.0, 0.4, 0.5, 0.6,
    5.0, 6.0, 0.5, 0.6, 0.7,
    6.0, 7.0, 0.6, 0.7, 0.8,
    7.0, 8.0, 0.7, 0.8, 0.9,
    8.0, 9.0, 0.8, 0.9, 1.0,
    9.0, 10.0, 0.9, 1.0, 1.1,
    10.0, 11.0, 1.0, 1.1, 1.2,
]).unwrap();

// Target values
// Target values: y = 3 * (first column) + 2
let y = array![5.0, 8.0, 11.0, 14.0, 17.0, 20.0, 23.0, 26.0, 29.0, 32.0];

// Perform elastic net regression with alpha=0.1 and l1_ratio=0.5
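let result = elastic_net(
    &x.view(),
    &y.view(),
    Some(0.1), // alpha
    Some(0.5), // l1_ratio
    None,      // fit_intercept (default: true)
    None,      // normalize (default: false)
    None,      // tol (default: 1e-4)
    None,      // max_iter (default: 1000)
    None,      // conf_level (default: 0.95)
)
.unwrap();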

// Check that the fit produced coefficients
assert!(!result.coefficients.is_empty());
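The tol and max_iter arguments suggest an iterative solver. Cyclic coordinate descent with soft-thresholding is the standard choice for elastic net, so the dependency-free sketch below illustrates that technique; whether scirs2_stats uses it is an assumption. It also assumes the scikit-learn-style objective (squared error scaled by 1/(2n)), handles the intercept by centering x and y, and does not model normalize or conf_level. The names fit_elastic_net_cd and soft_threshold are hypothetical.

```rust
/// Soft-thresholding operator S(a, t) = sign(a) * max(|a| - t, 0).
fn soft_threshold(a: f64, t: f64) -> f64 {
    if a > t {
        a - t
    } else if a < -t {
        a + t
    } else {
        0.0
    }
}

/// Hypothetical cyclic coordinate-descent fit (a sketch, not the crate's code). Minimizes
/// (1/(2n))||y - b - Xw||^2 + alpha*l1_ratio*||w||_1 + 0.5*alpha*(1 - l1_ratio)*||w||_2^2.
/// `x` is row-major and non-empty: x[i][j] is observation i, feature j.
/// Returns (intercept, weights).
fn fit_elastic_net_cd(
    x: &[Vec<f64>],
    y: &[f64],
    alpha: f64,
    l1_ratio: f64,
    tol: f64,
    max_iter: usize,
) -> (f64, Vec<f64>) {
    let nf = y.len() as f64;
    let p = x[0].len();

    // Center columns and target so the intercept can be recovered afterwards.
    let x_mean: Vec<f64> = (0..p)
        .map(|j| x.iter().map(|row| row[j]).sum::<f64>() / nf)
        .collect();
    let y_mean = y.iter().sum::<f64>() / nf;
    let xc: Vec<Vec<f64>> = x
        .iter()
        .map(|row| row.iter().zip(&x_mean).map(|(v, m)| v - m).collect::<Vec<f64>>())
        .collect();

    // All weights start at zero, so residuals start at the centered target.
    let mut resid: Vec<f64> = y.iter().map(|v| v - y_mean).collect();
    let mut w = vec![0.0; p];
    // Per-feature curvature: (1/n) * sum_i xc_ij^2.
    let z: Vec<f64> = (0..p)
        .map(|j| xc.iter().map(|row| row[j] * row[j]).sum::<f64>() / nf)
        .collect();

    for _ in 0..max_iter {
        let mut max_step: f64 = 0.0;
        for j in 0..p {
            if z[j] == 0.0 {
                continue; // constant column: leave its weight at zero
            }
            // Correlation of feature j with the residual that excludes feature j.
            let rho = xc
                .iter()
                .zip(&resid)
                .map(|(row, r)| row[j] * (r + row[j] * w[j]))
                .sum::<f64>()
                / nf;
            let new_wj = soft_threshold(rho, alpha * l1_ratio) / (z[j] + alpha * (1.0 - l1_ratio));
            let delta = new_wj - w[j];
            // Keep the residuals in sync with the updated weight.
            for (row, r) in xc.iter().zip(resid.iter_mut()) {
                *r -= row[j] * delta;
            }
            w[j] = new_wj;
            max_step = max_step.max(delta.abs());
        }
        if max_step < tol {
            break; // largest coefficient change fell below the tolerance
        }
    }

    // Undo the centering: b = mean(y) - mean(x) . w
    let intercept = y_mean - x_mean.iter().zip(&w).map(|(m, wj)| m * wj).sum::<f64>();
    (intercept, w)
}

fn main() {
    // Single-feature data with y = 2x + 1 exactly; alpha = 0 recovers least squares.
    let x: Vec<Vec<f64>> = (1..=5).map(|i| vec![i as f64]).collect();
    let y: Vec<f64> = (1..=5).map(|i| 2.0 * i as f64 + 1.0).collect();
    let (b, w) = fit_elastic_net_cd(&x, &y, 0.0, 0.5, 1e-10, 1000);
    println!("intercept = {b:.4}, weight = {:.4}", w[0]);
}
```

Raising alpha with l1_ratio = 1 shrinks the weight toward zero, and a large enough alpha sets it exactly to zero, which is the variable-selection behavior of the L1 term.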