1use std::marker::PhantomData;
4
5use scirs2_core::ndarray::{Array1, Array2, Axis};
6use scirs2_linalg::compat::ArrayLinalgExt;
7use sklears_core::{
9 error::{validate, Result, SklearsError},
10 traits::{Estimator, Fit, Predict, Score, Trained, Untrained},
11 types::Float,
12};
13
14#[derive(Debug, Clone)]
16pub struct OrthogonalMatchingPursuitConfig {
17 pub n_nonzero_coefs: Option<usize>,
19 pub tol: Option<Float>,
21 pub fit_intercept: bool,
23 pub normalize: bool,
25}
26
27impl Default for OrthogonalMatchingPursuitConfig {
28 fn default() -> Self {
29 Self {
30 n_nonzero_coefs: None,
31 tol: None,
32 fit_intercept: true,
33 normalize: true,
34 }
35 }
36}
37
38#[derive(Debug, Clone)]
40pub struct OrthogonalMatchingPursuit<State = Untrained> {
41 config: OrthogonalMatchingPursuitConfig,
42 state: PhantomData<State>,
43 coef_: Option<Array1<Float>>,
45 intercept_: Option<Float>,
46 n_features_: Option<usize>,
47 n_iter_: Option<usize>,
48}
49
50impl OrthogonalMatchingPursuit<Untrained> {
51 pub fn new() -> Self {
53 Self {
54 config: OrthogonalMatchingPursuitConfig::default(),
55 state: PhantomData,
56 coef_: None,
57 intercept_: None,
58 n_features_: None,
59 n_iter_: None,
60 }
61 }
62
63 pub fn n_nonzero_coefs(mut self, n_nonzero_coefs: usize) -> Self {
65 self.config.n_nonzero_coefs = Some(n_nonzero_coefs);
66 self
67 }
68
69 pub fn tol(mut self, tol: Float) -> Self {
71 self.config.tol = Some(tol);
72 self
73 }
74
75 pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
77 self.config.fit_intercept = fit_intercept;
78 self
79 }
80
81 pub fn normalize(mut self, normalize: bool) -> Self {
83 self.config.normalize = normalize;
84 self
85 }
86}
87
88impl Default for OrthogonalMatchingPursuit<Untrained> {
89 fn default() -> Self {
90 Self::new()
91 }
92}
93
94impl Estimator for OrthogonalMatchingPursuit<Untrained> {
95 type Float = Float;
96 type Config = OrthogonalMatchingPursuitConfig;
97 type Error = SklearsError;
98
99 fn config(&self) -> &Self::Config {
100 &self.config
101 }
102}
103
104impl Fit<Array2<Float>, Array1<Float>> for OrthogonalMatchingPursuit<Untrained> {
105 type Fitted = OrthogonalMatchingPursuit<Trained>;
106
107 fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
108 validate::check_consistent_length(x, y)?;
110
111 let n_samples = x.nrows();
112 let n_features = x.ncols();
113
114 let max_features = if let Some(n) = self.config.n_nonzero_coefs {
116 n.min(n_features).min(n_samples)
117 } else if self.config.tol.is_some() {
118 n_features.min(n_samples)
119 } else {
120 n_features.min(n_samples)
122 };
123
124 let tol = self.config.tol.unwrap_or(1e-3);
125
126 let x_mean = x.mean_axis(Axis(0)).unwrap();
128 let mut x_centered = x - &x_mean;
129
130 let y_mean = if self.config.fit_intercept {
131 y.mean().unwrap_or(0.0)
132 } else {
133 0.0
134 };
135 let y_centered = y - y_mean;
136
137 let x_scale = if self.config.normalize {
139 let mut scale = Array1::zeros(n_features);
140 for j in 0..n_features {
141 let col = x_centered.column(j);
142 scale[j] = col.dot(&col).sqrt();
143 if scale[j] > Float::EPSILON {
144 x_centered.column_mut(j).mapv_inplace(|x| x / scale[j]);
145 } else {
146 scale[j] = 1.0;
147 }
148 }
149 scale
150 } else {
151 Array1::ones(n_features)
152 };
153
154 let mut coef = Array1::zeros(n_features);
156 let mut active: Vec<usize> = Vec::new();
157 let mut residual = y_centered.clone();
158 let mut n_iter = 0;
159
160 for _ in 0..max_features {
162 let correlations = x_centered.t().dot(&residual);
164
165 let mut max_corr = 0.0;
167 let mut best_idx = 0;
168
169 for j in 0..n_features {
170 if !active.contains(&j) {
171 let corr = correlations[j].abs();
172 if corr > max_corr {
173 max_corr = corr;
174 best_idx = j;
175 }
176 }
177 }
178
179 let residual_norm = residual.dot(&residual).sqrt();
181 if residual_norm < tol {
182 break;
183 }
184
185 active.push(best_idx);
187 n_iter += 1;
188
189 let n_active = active.len();
191 let mut x_active = Array2::zeros((n_samples, n_active));
192 for (i, &j) in active.iter().enumerate() {
193 x_active.column_mut(i).assign(&x_centered.column(j));
194 }
195
196 let gram = x_active.t().dot(&x_active);
198 let x_active_t_y = x_active.t().dot(&y_centered);
199
200 let mut gram_reg = gram.clone();
202 for i in 0..n_active {
203 gram_reg[[i, i]] += 1e-10;
204 }
205
206 let coef_active = &gram_reg
207 .solve(&x_active_t_y)
208 .map_err(|e| SklearsError::NumericalError(format!("Failed to solve: {}", e)))?;
209
210 coef.fill(0.0);
212 for (i, &j) in active.iter().enumerate() {
213 coef[j] = coef_active[i];
214 }
215
216 residual = &y_centered - &x_centered.dot(&coef);
218 }
219
220 if self.config.normalize {
222 for j in 0..n_features {
223 if x_scale[j] > 0.0 {
224 coef[j] /= x_scale[j];
225 }
226 }
227 }
228
229 let intercept = if self.config.fit_intercept {
231 Some(y_mean - x_mean.dot(&coef))
232 } else {
233 None
234 };
235
236 Ok(OrthogonalMatchingPursuit {
237 config: self.config,
238 state: PhantomData,
239 coef_: Some(coef),
240 intercept_: intercept,
241 n_features_: Some(n_features),
242 n_iter_: Some(n_iter),
243 })
244 }
245}
246
247impl OrthogonalMatchingPursuit<Trained> {
248 pub fn coef(&self) -> &Array1<Float> {
250 self.coef_.as_ref().expect("Model is trained")
251 }
252
253 pub fn intercept(&self) -> Option<Float> {
255 self.intercept_
256 }
257
258 pub fn n_iter(&self) -> usize {
260 self.n_iter_.expect("Model is trained")
261 }
262}
263
264impl Predict<Array2<Float>, Array1<Float>> for OrthogonalMatchingPursuit<Trained> {
265 fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
266 let n_features = self.n_features_.expect("Model is trained");
267 validate::check_n_features(x, n_features)?;
268
269 let coef = self.coef_.as_ref().expect("Model is trained");
270 let mut predictions = x.dot(coef);
271
272 if let Some(intercept) = self.intercept_ {
273 predictions += intercept;
274 }
275
276 Ok(predictions)
277 }
278}
279
280impl Score<Array2<Float>, Array1<Float>> for OrthogonalMatchingPursuit<Trained> {
281 type Float = Float;
282
283 fn score(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<f64> {
284 let predictions = self.predict(x)?;
285
286 let ss_res = (&predictions - y).mapv(|x| x * x).sum();
288 let y_mean = y.mean().unwrap_or(0.0);
289 let ss_tot = y.mapv(|yi| (yi - y_mean).powi(2)).sum();
290
291 if ss_tot == 0.0 {
292 return Ok(1.0);
293 }
294
295 Ok(1.0 - (ss_res / ss_tot))
296 }
297}
298
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Two orthogonal, noise-free features and no intercept: the exact coefficients must be
    /// recovered and the predictions must reproduce the targets.
    #[test]
    fn test_omp_simple() {
        let features = array![
            [1.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [2.0, 0.0],
            [0.0, 2.0],
        ];
        let targets = array![2.0, 3.0, 2.0, 3.0, 4.0, 6.0];

        let fitted = OrthogonalMatchingPursuit::new()
            .fit_intercept(false)
            .normalize(false)
            .fit(&features, &targets)
            .unwrap();

        let weights = fitted.coef();
        assert_abs_diff_eq!(weights[0], 2.0, epsilon = 1e-5);
        assert_abs_diff_eq!(weights[1], 3.0, epsilon = 1e-5);

        let predicted = fitted.predict(&features).unwrap();
        for (&got, &want) in predicted.iter().zip(targets.iter()) {
            assert_abs_diff_eq!(got, want, epsilon = 1e-5);
        }
    }

    /// With `n_nonzero_coefs(1)` and perfectly collinear columns, exactly one feature (the
    /// strongest) is selected in a single iteration.
    #[test]
    fn test_omp_max_features() {
        let features = array![
            [1.0, 0.1, 0.01],
            [2.0, 0.2, 0.02],
            [3.0, 0.3, 0.03],
            [4.0, 0.4, 0.04],
            [5.0, 0.5, 0.05],
            [6.0, 0.6, 0.06],
        ];
        let targets = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0];

        let fitted = OrthogonalMatchingPursuit::new()
            .n_nonzero_coefs(1)
            .fit_intercept(false)
            .normalize(false)
            .fit(&features, &targets)
            .unwrap();

        let weights = fitted.coef();
        let nonzero = weights.iter().filter(|c| c.abs() > 1e-10).count();
        assert_eq!(nonzero, 1);

        assert_abs_diff_eq!(weights[0], 2.0, epsilon = 1e-3);

        assert_eq!(fitted.n_iter(), 1);
    }

    /// A loose tolerance on nearly linear data still yields a high R².
    #[test]
    fn test_omp_tolerance() {
        let features = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let targets = array![2.1, 3.9, 6.05, 7.95, 10.1];

        let fitted = OrthogonalMatchingPursuit::new()
            .tol(0.5)
            .fit_intercept(false)
            .fit(&features, &targets)
            .unwrap();

        // Prediction on the training data must succeed.
        let _ = fitted.predict(&features).unwrap();
        let r_squared = fitted.score(&features, &targets).unwrap();
        assert!(r_squared > 0.95);
    }

    /// An affine target `2x + 1` yields slope 2 and intercept 1.
    #[test]
    fn test_omp_with_intercept() {
        let features = array![[1.0], [2.0], [3.0], [4.0]];
        let targets = array![3.0, 5.0, 7.0, 9.0];

        let fitted = OrthogonalMatchingPursuit::new()
            .fit_intercept(true)
            .fit(&features, &targets)
            .unwrap();

        assert_abs_diff_eq!(fitted.coef()[0], 2.0, epsilon = 1e-5);
        assert_abs_diff_eq!(fitted.intercept().unwrap(), 1.0, epsilon = 1e-5);
    }

    /// A 3-sparse noise-free signal over a deterministic 20x10 design is recovered with
    /// exactly three non-zero coefficients on the true support.
    #[test]
    fn test_omp_sparse_recovery() {
        let n_samples = 20;
        let n_features = 10;

        let mut design: Array2<Float> = Array2::zeros((n_samples, n_features));
        for ((i, j), cell) in design.indexed_iter_mut() {
            *cell = ((i * 7 + j * 13) % 20) as Float / 10.0 - 1.0;
        }

        let mut true_coef: Array1<Float> = Array1::zeros(n_features);
        true_coef[1] = 2.0;
        true_coef[4] = -1.5;
        true_coef[7] = 1.0;

        let targets = design.dot(&true_coef);

        let fitted = OrthogonalMatchingPursuit::new()
            .n_nonzero_coefs(3)
            .fit_intercept(false)
            .normalize(true)
            .fit(&design, &targets)
            .unwrap();

        let weights = fitted.coef();

        for (j, (&expected, &got)) in true_coef.iter().zip(weights.iter()).enumerate() {
            if expected != 0.0 {
                assert!(
                    got.abs() > 0.1,
                    "Failed to recover non-zero coefficient at index {}",
                    j
                );
            }
        }

        let nonzero = weights.iter().filter(|c| c.abs() > 1e-10).count();
        assert_eq!(nonzero, 3);
    }
}