sklears_linear/omp.rs

//! Orthogonal Matching Pursuit (OMP) implementation
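//!
//! OMP builds a sparse linear model greedily: each iteration selects the
//! feature most correlated with the current residual, re-solves a least
//! squares problem restricted to the selected (active) features, and updates
//! the residual. The loop stops once `n_nonzero_coefs` features are active or
//! the residual norm falls below `tol`.
//!
//! # Example
//!
//! A minimal usage sketch (the `sklears_linear` crate path and re-export of
//! `OrthogonalMatchingPursuit` are assumed here):
//!
//! ```ignore
//! use scirs2_core::ndarray::array;
//! use sklears_core::traits::{Fit, Predict};
//! use sklears_linear::OrthogonalMatchingPursuit;
//!
//! let x = array![[1.0, 0.0], [0.0, 1.0], [2.0, 0.0], [0.0, 2.0]];
//! let y = array![2.0, 3.0, 4.0, 6.0]; // y = 2*x1 + 3*x2
//!
//! let model = OrthogonalMatchingPursuit::new()
//!     .n_nonzero_coefs(2)
//!     .fit_intercept(false)
//!     .fit(&x, &y)
//!     .unwrap();
//!
//! let predictions = model.predict(&x).unwrap();
//! ```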

use std::marker::PhantomData;

use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_linalg::compat::ArrayLinalgExt;
// ArrayLinalgExt provides the `solve` method used for the active-set least squares below
use sklears_core::{
    error::{validate, Result, SklearsError},
    traits::{Estimator, Fit, Predict, Score, Trained, Untrained},
    types::Float,
};

/// Configuration for OrthogonalMatchingPursuit
#[derive(Debug, Clone)]
pub struct OrthogonalMatchingPursuitConfig {
    /// Maximum number of non-zero coefficients in the solution
    pub n_nonzero_coefs: Option<usize>,
    /// Tolerance for the residual
    pub tol: Option<Float>,
    /// Whether to fit the intercept
    pub fit_intercept: bool,
    /// Whether to normalize/standardize features before fitting
    pub normalize: bool,
}

impl Default for OrthogonalMatchingPursuitConfig {
    fn default() -> Self {
        Self {
            n_nonzero_coefs: None,
            tol: None,
            fit_intercept: true,
            normalize: true,
        }
    }
}

/// Orthogonal Matching Pursuit model
#[derive(Debug, Clone)]
pub struct OrthogonalMatchingPursuit<State = Untrained> {
    config: OrthogonalMatchingPursuitConfig,
    state: PhantomData<State>,
    // Trained state fields
    coef_: Option<Array1<Float>>,
    intercept_: Option<Float>,
    n_features_: Option<usize>,
    n_iter_: Option<usize>,
}

impl OrthogonalMatchingPursuit<Untrained> {
    /// Create a new OMP model
    pub fn new() -> Self {
        Self {
            config: OrthogonalMatchingPursuitConfig::default(),
            state: PhantomData,
            coef_: None,
            intercept_: None,
            n_features_: None,
            n_iter_: None,
        }
    }

    /// Set the maximum number of non-zero coefficients
    pub fn n_nonzero_coefs(mut self, n_nonzero_coefs: usize) -> Self {
        self.config.n_nonzero_coefs = Some(n_nonzero_coefs);
        self
    }

    /// Set the tolerance for the residual
    pub fn tol(mut self, tol: Float) -> Self {
        self.config.tol = Some(tol);
        self
    }

    /// Set whether to fit intercept
    pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
        self.config.fit_intercept = fit_intercept;
        self
    }

    /// Set whether to normalize features
    pub fn normalize(mut self, normalize: bool) -> Self {
        self.config.normalize = normalize;
        self
    }
}

impl Default for OrthogonalMatchingPursuit<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for OrthogonalMatchingPursuit<Untrained> {
    type Float = Float;
    type Config = OrthogonalMatchingPursuitConfig;
    type Error = SklearsError;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<Array2<Float>, Array1<Float>> for OrthogonalMatchingPursuit<Untrained> {
    type Fitted = OrthogonalMatchingPursuit<Trained>;

    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        // Validate inputs
        validate::check_consistent_length(x, y)?;

        let n_samples = x.nrows();
        let n_features = x.ncols();

        // Determine stopping criterion
        let max_features = if let Some(n) = self.config.n_nonzero_coefs {
            n.min(n_features).min(n_samples)
        } else {
            // Default (with or without `tol`): up to min(n_features, n_samples) features
            n_features.min(n_samples)
        };

        let tol = self.config.tol.unwrap_or(1e-3);

        // Center X (and y below) only when fitting an intercept, so that
        // predictions on raw X are consistent with the fitted coefficients
        let x_mean = if self.config.fit_intercept {
            x.mean_axis(Axis(0)).unwrap()
        } else {
            Array1::zeros(n_features)
        };
        let mut x_centered = x - &x_mean;

        let y_mean = if self.config.fit_intercept {
            y.mean().unwrap_or(0.0)
        } else {
            0.0
        };
        let y_centered = y - y_mean;

        // Normalize X if requested
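        // Columns are scaled to unit L2 norm so correlation magnitudes are
        // comparable across features; coefficients are scaled back afterwards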
        let x_scale = if self.config.normalize {
            let mut scale = Array1::zeros(n_features);
            for j in 0..n_features {
                let col = x_centered.column(j);
                scale[j] = col.dot(&col).sqrt();
                if scale[j] > Float::EPSILON {
                    x_centered.column_mut(j).mapv_inplace(|v| v / scale[j]);
                } else {
                    scale[j] = 1.0;
                }
            }
            scale
        } else {
            Array1::ones(n_features)
        };

        // Initialize OMP algorithm
        let mut coef = Array1::zeros(n_features);
        let mut active: Vec<usize> = Vec::new();
        let mut residual = y_centered.clone();
        let mut n_iter = 0;

        // Main OMP loop
        for _ in 0..max_features {
            // Compute correlations with residual
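            // Greedy selection: pick the column with the largest |x_j^T r|;
            // with unit-norm columns this is the column whose projection
            // reduces the residual the most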
            let correlations = x_centered.t().dot(&residual);

            // Find the most correlated feature not yet selected
            let mut max_corr = 0.0;
            let mut best_idx = 0;

            for j in 0..n_features {
                if !active.contains(&j) {
                    let corr = correlations[j].abs();
                    if corr > max_corr {
                        max_corr = corr;
                        best_idx = j;
                    }
                }
            }

            // Check stopping criteria: residual small enough, or no remaining
            // feature correlates with the residual
            let residual_norm = residual.dot(&residual).sqrt();
            if residual_norm < tol || max_corr <= Float::EPSILON {
                break;
            }

            // Add the best feature to active set
            active.push(best_idx);
            n_iter += 1;

            // Solve least squares problem on active set
            let n_active = active.len();
            let mut x_active = Array2::zeros((n_samples, n_active));
            for (i, &j) in active.iter().enumerate() {
                x_active.column_mut(i).assign(&x_centered.column(j));
            }

            // Solve normal equations: (X_active^T X_active) coef_active = X_active^T y
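            // Re-solving on the whole active set makes the new residual
            // orthogonal to every selected column -- the "orthogonal" step
            // that distinguishes OMP from plain matching pursuit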
            let gram = x_active.t().dot(&x_active);
            let x_active_t_y = x_active.t().dot(&y_centered);

            // Add small regularization to avoid singular matrix
            let mut gram_reg = gram.clone();
            for i in 0..n_active {
                gram_reg[[i, i]] += 1e-10;
            }

            let coef_active = gram_reg
                .solve(&x_active_t_y)
                .map_err(|e| SklearsError::NumericalError(format!("Failed to solve: {}", e)))?;

            // Update full coefficient vector
            coef.fill(0.0);
            for (i, &j) in active.iter().enumerate() {
                coef[j] = coef_active[i];
            }

            // Update residual
            residual = &y_centered - &x_centered.dot(&coef);
        }

        // Rescale coefficients if we normalized
        if self.config.normalize {
            for j in 0..n_features {
                if x_scale[j] > 0.0 {
                    coef[j] /= x_scale[j];
                }
            }
        }

        // Compute intercept if needed
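        // With the data centered during fitting, the intercept is b = ȳ - x̄·β̂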
        let intercept = if self.config.fit_intercept {
            Some(y_mean - x_mean.dot(&coef))
        } else {
            None
        };

        Ok(OrthogonalMatchingPursuit {
            config: self.config,
            state: PhantomData,
            coef_: Some(coef),
            intercept_: intercept,
            n_features_: Some(n_features),
            n_iter_: Some(n_iter),
        })
    }
}

impl OrthogonalMatchingPursuit<Trained> {
    /// Get the coefficients
    pub fn coef(&self) -> &Array1<Float> {
        self.coef_.as_ref().expect("Model is trained")
    }

    /// Get the intercept
    pub fn intercept(&self) -> Option<Float> {
        self.intercept_
    }

    /// Get the number of iterations run
    pub fn n_iter(&self) -> usize {
        self.n_iter_.expect("Model is trained")
    }
}

impl Predict<Array2<Float>, Array1<Float>> for OrthogonalMatchingPursuit<Trained> {
    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
        let n_features = self.n_features_.expect("Model is trained");
        validate::check_n_features(x, n_features)?;

        let coef = self.coef_.as_ref().expect("Model is trained");
        let mut predictions = x.dot(coef);

        if let Some(intercept) = self.intercept_ {
            predictions += intercept;
        }

        Ok(predictions)
    }
}

impl Score<Array2<Float>, Array1<Float>> for OrthogonalMatchingPursuit<Trained> {
    type Float = Float;

    fn score(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<f64> {
        let predictions = self.predict(x)?;

        // Calculate R² score
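        // R² = 1 - SS_res / SS_tot, with SS_res = Σ(ŷᵢ - yᵢ)² and SS_tot = Σ(yᵢ - ȳ)²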
        let ss_res = (&predictions - y).mapv(|d| d * d).sum();
        let y_mean = y.mean().unwrap_or(0.0);
        let ss_tot = y.mapv(|yi| (yi - y_mean).powi(2)).sum();

        if ss_tot == 0.0 {
            return Ok(1.0);
        }

        Ok(1.0 - (ss_res / ss_tot))
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_omp_simple() {
        // Simple test with orthogonal features
        let x = array![
            [1.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [2.0, 0.0],
            [0.0, 2.0],
        ];
        let y = array![2.0, 3.0, 2.0, 3.0, 4.0, 6.0]; // y = 2*x1 + 3*x2

        let model = OrthogonalMatchingPursuit::new()
            .fit_intercept(false)
            .normalize(false)
            .fit(&x, &y)
            .unwrap();

        // Should recover the true coefficients
        let coef = model.coef();
        assert_abs_diff_eq!(coef[0], 2.0, epsilon = 1e-5);
        assert_abs_diff_eq!(coef[1], 3.0, epsilon = 1e-5);

        // Predictions should be perfect
        let predictions = model.predict(&x).unwrap();
        for i in 0..y.len() {
            assert_abs_diff_eq!(predictions[i], y[i], epsilon = 1e-5);
        }
    }

    #[test]
    fn test_omp_max_features() {
        // Test limiting number of features
        let x = array![
            [1.0, 0.1, 0.01],
            [2.0, 0.2, 0.02],
            [3.0, 0.3, 0.03],
            [4.0, 0.4, 0.04],
            [5.0, 0.5, 0.05],
            [6.0, 0.6, 0.06],
        ];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0]; // y = 2*x1

        let model = OrthogonalMatchingPursuit::new()
            .n_nonzero_coefs(1)
            .fit_intercept(false)
            .normalize(false)
            .fit(&x, &y)
            .unwrap();

        let coef = model.coef();
        let n_nonzero = coef.iter().filter(|&&c| c.abs() > 1e-10).count();
        assert_eq!(n_nonzero, 1);

        // First coefficient should be selected and close to 2.0
        assert_abs_diff_eq!(coef[0], 2.0, epsilon = 1e-3);

        // Check that we ran exactly 1 iteration
        assert_eq!(model.n_iter(), 1);
    }

    #[test]
    fn test_omp_tolerance() {
        // Test stopping based on tolerance
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let y = array![2.1, 3.9, 6.05, 7.95, 10.1]; // y ≈ 2x with small noise

        let model = OrthogonalMatchingPursuit::new()
            .tol(0.5) // Relatively high tolerance
            .fit_intercept(false)
            .fit(&x, &y)
            .unwrap();

        // Should get a reasonable approximation
        let _predictions = model.predict(&x).unwrap();
        let r2 = model.score(&x, &y).unwrap();
        assert!(r2 > 0.95);
    }

    #[test]
    fn test_omp_with_intercept() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![3.0, 5.0, 7.0, 9.0]; // y = 2x + 1

        let model = OrthogonalMatchingPursuit::new()
            .fit_intercept(true)
            .fit(&x, &y)
            .unwrap();

        assert_abs_diff_eq!(model.coef()[0], 2.0, epsilon = 1e-5);
        assert_abs_diff_eq!(model.intercept().unwrap(), 1.0, epsilon = 1e-5);
    }

    #[test]
    fn test_omp_sparse_recovery() {
        // Create sparse signal recovery problem
        let n_samples = 20;
        let n_features = 10;
        let mut x = Array2::zeros((n_samples, n_features));
        let mut true_coef = Array1::zeros(n_features);

        // Generate random-like data deterministically
        for i in 0..n_samples {
            for j in 0..n_features {
                x[[i, j]] = ((i * 7 + j * 13) % 20) as Float / 10.0 - 1.0;
            }
        }

        // True coefficients are sparse (only 3 non-zero)
        true_coef[1] = 2.0;
        true_coef[4] = -1.5;
        true_coef[7] = 1.0;

        let y = x.dot(&true_coef);

        let model = OrthogonalMatchingPursuit::new()
            .n_nonzero_coefs(3)
            .fit_intercept(false)
            .normalize(true)
            .fit(&x, &y)
            .unwrap();

        let coef = model.coef();

        // Should recover the support (non-zero indices)
        for j in 0..n_features {
            if true_coef[j] != 0.0 {
                assert!(
                    coef[j].abs() > 0.1,
                    "Failed to recover non-zero coefficient at index {}",
                    j
                );
            }
        }

        // Should have exactly 3 non-zero coefficients
        let n_nonzero = coef.iter().filter(|&&c| c.abs() > 1e-10).count();
        assert_eq!(n_nonzero, 3);
    }
}