1use std::marker::PhantomData;
4
5use scirs2_core::ndarray::{Array1, Array2, Axis};
6use scirs2_linalg::compat::ArrayLinalgExt;
7use sklears_core::{
9 error::{validate, Result, SklearsError},
10 traits::{Estimator, Fit, Predict, Score, Trained, Untrained},
11 types::Float,
12};
13
/// Configuration for the [`Lars`] (Least Angle Regression) estimator.
#[derive(Debug, Clone)]
pub struct LarsConfig {
    /// Whether to estimate an intercept term (default: `true`).
    pub fit_intercept: bool,
    /// Whether to scale each feature column to unit L2 norm before fitting
    /// (default: `true`); fitted coefficients are mapped back to the original
    /// feature scale afterwards.
    pub normalize: bool,
    /// Upper bound on the number of non-zero coefficients; `None` means no
    /// limit beyond the number of features (default: `None`).
    pub n_nonzero_coefs: Option<usize>,
    /// Numerical tolerance used for the convergence check and to guard
    /// divisions (default: `Float::EPSILON.sqrt()`).
    pub eps: Float,
}
26
27impl Default for LarsConfig {
28 fn default() -> Self {
29 Self {
30 fit_intercept: true,
31 normalize: true,
32 n_nonzero_coefs: None,
33 eps: Float::EPSILON.sqrt(),
34 }
35 }
36}
37
/// Least Angle Regression (LARS) estimator.
///
/// Uses the typestate pattern: `Lars<Untrained>` exposes the builder and
/// `fit` API, while `Lars<Trained>` exposes fitted parameters, `predict`,
/// and `score`.
#[derive(Debug, Clone)]
pub struct Lars<State = Untrained> {
    config: LarsConfig,
    // Zero-sized marker carrying the Untrained/Trained typestate.
    state: PhantomData<State>,
    // Fitted coefficients on the original feature scale; `None` until trained.
    coef_: Option<Array1<Float>>,
    // Fitted intercept; `None` when untrained or when `fit_intercept` is false.
    intercept_: Option<Float>,
    // Number of features seen during `fit`.
    n_features_: Option<usize>,
    // Indices of features added to the active set, in insertion order.
    active_: Option<Vec<usize>>,
    // Maximum absolute correlation observed at each LARS iteration.
    alphas_: Option<Array1<Float>>,
    // Number of LARS iterations performed during `fit`.
    n_iter_: Option<usize>,
}
51
52impl Lars<Untrained> {
53 pub fn new() -> Self {
55 Self {
56 config: LarsConfig::default(),
57 state: PhantomData,
58 coef_: None,
59 intercept_: None,
60 n_features_: None,
61 active_: None,
62 alphas_: None,
63 n_iter_: None,
64 }
65 }
66
67 pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
69 self.config.fit_intercept = fit_intercept;
70 self
71 }
72
73 pub fn normalize(mut self, normalize: bool) -> Self {
75 self.config.normalize = normalize;
76 self
77 }
78
79 pub fn n_nonzero_coefs(mut self, n_nonzero_coefs: usize) -> Self {
81 self.config.n_nonzero_coefs = Some(n_nonzero_coefs);
82 self
83 }
84
85 pub fn eps(mut self, eps: Float) -> Self {
87 self.config.eps = eps;
88 self
89 }
90}
91
92impl Default for Lars<Untrained> {
93 fn default() -> Self {
94 Self::new()
95 }
96}
97
impl Estimator for Lars<Untrained> {
    type Config = LarsConfig;
    type Error = SklearsError;
    type Float = Float;

    /// Returns the estimator's current configuration.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
107
108impl Fit<Array2<Float>, Array1<Float>> for Lars<Untrained> {
109 type Fitted = Lars<Trained>;
110
111 fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
112 validate::check_consistent_length(x, y)?;
114
115 let n_samples = x.nrows();
116 let n_features = x.ncols();
117
118 let x_mean = x.mean_axis(Axis(0)).ok_or_else(|| {
120 SklearsError::NumericalError(
121 "mean computation should succeed for non-empty array".into(),
122 )
123 })?;
124 let mut x_centered = x - &x_mean;
125
126 let y_mean = if self.config.fit_intercept {
127 y.mean().unwrap_or(0.0)
128 } else {
129 0.0
130 };
131 let y_centered = y - y_mean;
132
133 let x_scale = if self.config.normalize {
135 let mut scale = Array1::zeros(n_features);
136 for j in 0..n_features {
137 let col = x_centered.column(j);
138 scale[j] = col.dot(&col).sqrt();
139 if scale[j] > self.config.eps {
140 x_centered.column_mut(j).mapv_inplace(|x| x / scale[j]);
141 } else {
142 scale[j] = 1.0;
143 }
144 }
145 scale
146 } else {
147 Array1::ones(n_features)
148 };
149
150 let mut coef = Array1::zeros(n_features);
152 let mut active: Vec<usize> = Vec::new();
153 let mut alphas = Vec::new();
154
155 let max_features = self
157 .config
158 .n_nonzero_coefs
159 .unwrap_or(n_features)
160 .min(n_features);
161
162 let mut residual = y_centered.clone();
164 let mut correlations = x_centered.t().dot(&residual);
165
166 let mut n_iter = 0;
167
168 while active.len() < max_features {
169 let mut max_corr = 0.0;
171 let mut best_idx = 0;
172
173 for j in 0..n_features {
174 if !active.contains(&j) {
175 let corr = correlations[j].abs();
176 if corr > max_corr {
177 max_corr = corr;
178 best_idx = j;
179 }
180 }
181 }
182
183 if max_corr < self.config.eps {
185 break;
186 }
187
188 active.push(best_idx);
190 alphas.push(max_corr);
191
192 let n_active = active.len();
194 let mut x_active = Array2::zeros((n_samples, n_active));
195 for (i, &j) in active.iter().enumerate() {
196 x_active.column_mut(i).assign(&x_centered.column(j));
197 }
198
199 let gram = x_active.t().dot(&x_active);
201
202 let ones = Array1::ones(n_active);
204
205 let mut gram_reg = gram.clone();
207 for i in 0..n_active {
208 gram_reg[[i, i]] += 1e-10;
209 }
210
211 let gram_inv_ones = &gram_reg
212 .solve(&ones)
213 .map_err(|e| SklearsError::NumericalError(format!("Failed to solve: {}", e)))?;
214
215 let normalization = 1.0 / ones.dot(gram_inv_ones).sqrt();
216 let direction = gram_inv_ones * normalization;
217
218 let equiangular = x_active.dot(&direction);
220
221 let mut gamma = max_corr;
223
224 for j in 0..n_features {
226 if !active.contains(&j) {
227 let a_j = x_centered.column(j).dot(&equiangular);
228 let c_j = correlations[j];
229
230 let gamma_plus = (max_corr - c_j) / (normalization - a_j + self.config.eps);
232 let gamma_minus = (max_corr + c_j) / (normalization + a_j + self.config.eps);
233
234 if gamma_plus > 0.0 && gamma_plus < gamma {
235 gamma = gamma_plus;
236 }
237 if gamma_minus > 0.0 && gamma_minus < gamma {
238 gamma = gamma_minus;
239 }
240 }
241 }
242
243 for (i, &j) in active.iter().enumerate() {
245 coef[j] += gamma * direction[i];
246 }
247
248 residual = residual - gamma * equiangular;
250 correlations = x_centered.t().dot(&residual);
251
252 n_iter += 1;
253 }
254
255 if self.config.normalize {
257 for j in 0..n_features {
258 if x_scale[j] > 0.0 {
259 coef[j] /= x_scale[j];
260 }
261 }
262 }
263
264 let intercept = if self.config.fit_intercept {
266 Some(y_mean - x_mean.dot(&coef))
267 } else {
268 None
269 };
270
271 Ok(Lars {
272 config: self.config,
273 state: PhantomData,
274 coef_: Some(coef),
275 intercept_: intercept,
276 n_features_: Some(n_features),
277 active_: Some(active),
278 alphas_: Some(Array1::from(alphas)),
279 n_iter_: Some(n_iter),
280 })
281 }
282}
283
impl Lars<Trained> {
    /// Fitted coefficients on the original feature scale.
    pub fn coef(&self) -> &Array1<Float> {
        self.coef_.as_ref().expect("Model is trained")
    }

    /// Fitted intercept, or `None` when `fit_intercept` was disabled.
    pub fn intercept(&self) -> Option<Float> {
        self.intercept_
    }

    /// Indices of the features selected into the active set, in the order
    /// they were added along the LARS path.
    pub fn active(&self) -> &[usize] {
        self.active_.as_ref().expect("Model is trained")
    }

    /// Maximum absolute correlation recorded at each step of the LARS path.
    pub fn alphas(&self) -> &Array1<Float> {
        self.alphas_.as_ref().expect("Model is trained")
    }

    /// Number of LARS iterations performed during fitting.
    pub fn n_iter(&self) -> usize {
        self.n_iter_.expect("Model is trained")
    }
}
310
311impl Predict<Array2<Float>, Array1<Float>> for Lars<Trained> {
312 fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
313 let n_features = self.n_features_.expect("Model is trained");
314 validate::check_n_features(x, n_features)?;
315
316 let coef = self.coef_.as_ref().expect("Model is trained");
317 let mut predictions = x.dot(coef);
318
319 if let Some(intercept) = self.intercept_ {
320 predictions += intercept;
321 }
322
323 Ok(predictions)
324 }
325}
326
327impl Score<Array2<Float>, Array1<Float>> for Lars<Trained> {
328 type Float = Float;
329
330 fn score(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<f64> {
331 let predictions = self.predict(x)?;
332
333 let ss_res = (&predictions - y).mapv(|x| x * x).sum();
335 let y_mean = y.mean().unwrap_or(0.0);
336 let ss_tot = y.mapv(|yi| (yi - y_mean).powi(2)).sum();
337
338 if ss_tot == 0.0 {
339 return Ok(1.0);
340 }
341
342 Ok(1.0 - (ss_res / ss_tot))
343 }
344}
345
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Collinear design (column 1 is twice column 0) with an exactly linear
    /// target: the fitted model should reproduce `y`.
    #[test]
    fn test_lars_simple() {
        let x = array![[1.0, 2.0], [2.0, 4.0], [3.0, 6.0], [4.0, 8.0]];
        let y = array![3.0, 6.0, 9.0, 12.0];

        let model = Lars::new()
            .fit_intercept(false)
            .normalize(false)
            .fit(&x, &y)
            .expect("operation should succeed");

        // At least one of the two coefficients must be non-zero.
        let coef = model.coef();
        assert!(coef.iter().any(|c| c.abs() > 0.0));

        let predictions = model.predict(&x).expect("prediction should succeed");
        for (pred, target) in predictions.iter().zip(y.iter()) {
            assert_abs_diff_eq!(*pred, *target, epsilon = 1e-5);
        }
    }

    /// Two orthogonal features that jointly explain the target exactly:
    /// the R² score should be essentially perfect.
    #[test]
    fn test_lars_orthogonal_features() {
        let x = array![
            [1.0, 0.0],
            [2.0, 0.0],
            [0.0, 1.0],
            [0.0, 2.0],
            [3.0, 0.0],
            [0.0, 3.0],
        ];
        let y = array![2.0, 4.0, 3.0, 6.0, 6.0, 9.0];

        let model = Lars::new()
            .fit_intercept(false)
            .normalize(false)
            .fit(&x, &y)
            .expect("operation should succeed");

        let _predictions = model.predict(&x).expect("prediction should succeed");

        let r_squared = model.score(&x, &y).expect("scoring should succeed");
        assert!(
            r_squared > 0.99,
            "R² score should be very high for perfect linear relationship"
        );
    }

    /// Capping the active set at one feature must select only the dominant
    /// first column (y = 2 * x_0 exactly).
    #[test]
    fn test_lars_max_features() {
        let x = array![
            [1.0, 0.1, 0.01],
            [2.0, 0.2, 0.02],
            [3.0, 0.3, 0.03],
            [4.0, 0.4, 0.04],
            [5.0, 0.5, 0.05],
            [6.0, 0.6, 0.06],
        ];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0];

        let model = Lars::new()
            .fit_intercept(false)
            .n_nonzero_coefs(1)
            .normalize(false)
            .fit(&x, &y)
            .expect("operation should succeed");

        let coef = model.coef();
        let nonzero_count = coef.iter().filter(|c| c.abs() > 1e-10).count();
        assert_eq!(nonzero_count, 1);

        assert_abs_diff_eq!(coef[0], 2.0, epsilon = 1e-3);

        assert_eq!(model.active().len(), 1);
        assert_eq!(model.active()[0], 0);
    }

    /// y = 2x + 1: both the slope and the intercept should be recovered.
    #[test]
    fn test_lars_with_intercept() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![3.0, 5.0, 7.0, 9.0];

        let model = Lars::new()
            .fit_intercept(true)
            .fit(&x, &y)
            .expect("model fitting should succeed");

        let intercept = model.intercept().expect("intercept should be available");
        assert_abs_diff_eq!(model.coef()[0], 2.0, epsilon = 1e-5);
        assert_abs_diff_eq!(intercept, 1.0, epsilon = 1e-5);
    }
}