1use crate::base::{FeatureSelector, SelectorMixin};
7use scirs2_core::ndarray::{Array1, Array2, Axis};
8use sklears_core::{
9 error::{validate, Result as SklResult, SklearsError},
10 traits::{Estimator, Fit, Trained, Transform, Untrained},
11 types::Float,
12};
13use std::marker::PhantomData;
14
/// Feature selector that ranks features by the weights of a non-negative,
/// L1-penalized least-squares fit solved with projected subgradient descent.
///
/// The `State` parameter is a typestate marker: `Untrained` before `fit`,
/// `Trained` afterwards. The trailing-underscore fields are `None` until `fit`
/// fills them in on the returned `Trained` value.
#[derive(Debug, Clone)]
pub struct ConvexFeatureSelector<State = Untrained> {
    // Number of features to keep.
    k: usize,
    // Strength of the L1 penalty on the weights.
    regularization: Float,
    // Upper bound on solver iterations.
    max_iter: usize,
    // The solver stops once the L1 change between consecutive weight vectors is below this.
    tolerance: Float,
    // Zero-sized typestate marker.
    state: PhantomData<State>,
    // Per-feature weights learned by `fit`.
    weights_: Option<Array1<Float>>,
    // Indices of the kept features, ordered from largest to smallest weight.
    selected_features_: Option<Vec<usize>>,
    // Number of input columns seen during `fit`.
    n_features_: Option<usize>,
    // Objective value recorded at each solver iteration.
    objective_values_: Option<Vec<Float>>,
}
29
30impl Default for ConvexFeatureSelector<Untrained> {
31 fn default() -> Self {
32 Self::new()
33 }
34}
35
36impl ConvexFeatureSelector<Untrained> {
37 pub fn new() -> Self {
39 Self {
40 k: 10,
41 regularization: 1.0,
42 max_iter: 1000,
43 tolerance: 1e-6,
44 state: PhantomData,
45 weights_: None,
46 selected_features_: None,
47 n_features_: None,
48 objective_values_: None,
49 }
50 }
51
52 pub fn k(mut self, k: usize) -> Self {
54 self.k = k;
55 self
56 }
57
58 pub fn regularization(mut self, regularization: Float) -> Self {
60 self.regularization = regularization;
61 self
62 }
63
64 pub fn max_iter(mut self, max_iter: usize) -> Self {
66 self.max_iter = max_iter;
67 self
68 }
69
70 pub fn tolerance(mut self, tolerance: Float) -> Self {
72 self.tolerance = tolerance;
73 self
74 }
75
76 fn solve_convex_optimization(
78 &self,
79 features: &Array2<Float>,
80 target: &Array1<Float>,
81 ) -> SklResult<(Array1<Float>, Vec<Float>)> {
82 let n_features = features.ncols();
83 let n_samples = features.nrows();
84
85 let mut weights = Array1::from_elem(n_features, 1.0 / n_features as Float);
87 let mut objective_values = Vec::new();
88
89 for iter in 0..self.max_iter {
91 let predictions = features.dot(&weights);
93
94 let residuals = &predictions - target;
96
97 let data_gradient = features.t().dot(&residuals) / n_samples as Float;
99
100 let reg_gradient = weights.mapv(|w| {
102 if w > 0.0 {
103 self.regularization
104 } else if w < 0.0 {
105 -self.regularization
106 } else {
107 0.0 }
109 });
110
111 let gradient = data_gradient + reg_gradient;
112
113 let step_size = 0.01 / (iter + 1) as Float;
115
116 let new_weights = &weights - step_size * &gradient;
118
119 let new_weights = new_weights.mapv(|w| w.max(0.0));
121
122 let data_term = residuals.mapv(|r| r * r).sum() / (2.0 * n_samples as Float);
124 let reg_term = self.regularization * weights.mapv(|w| w.abs()).sum();
125 let objective = data_term + reg_term;
126 objective_values.push(objective);
127
128 let weight_diff = (&new_weights - &weights).mapv(|d| d.abs()).sum();
130 if weight_diff < self.tolerance {
131 break;
132 }
133
134 weights = new_weights;
135 }
136
137 Ok((weights, objective_values))
138 }
139}
140
/// `Estimator` metadata for the convex selector. It has no separate configuration
/// object, so `Config` is the unit type.
impl Estimator for ConvexFeatureSelector<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Returns a reference to the unit configuration.
    fn config(&self) -> &Self::Config {
        &()
    }
}
150
151impl Fit<Array2<Float>, Array1<Float>> for ConvexFeatureSelector<Untrained> {
152 type Fitted = ConvexFeatureSelector<Trained>;
153
154 fn fit(self, features: &Array2<Float>, target: &Array1<Float>) -> SklResult<Self::Fitted> {
155 let n_features = features.ncols();
156 if n_features == 0 {
157 return Err(SklearsError::InvalidInput(
158 "No features provided".to_string(),
159 ));
160 }
161
162 if self.k > n_features {
163 return Err(SklearsError::InvalidInput(format!(
164 "k ({}) cannot be greater than number of features ({})",
165 self.k, n_features
166 )));
167 }
168
169 let (weights, objective_values) = self.solve_convex_optimization(features, target)?;
171
172 let mut feature_indices: Vec<usize> = (0..n_features).collect();
174 feature_indices.sort_by(|&a, &b| {
175 weights[b]
176 .partial_cmp(&weights[a])
177 .unwrap_or(std::cmp::Ordering::Equal)
178 });
179
180 let selected_features = feature_indices.into_iter().take(self.k).collect();
181
182 Ok(ConvexFeatureSelector {
183 k: self.k,
184 regularization: self.regularization,
185 max_iter: self.max_iter,
186 tolerance: self.tolerance,
187 state: PhantomData,
188 weights_: Some(weights),
189 selected_features_: Some(selected_features),
190 n_features_: Some(n_features),
191 objective_values_: Some(objective_values),
192 })
193 }
194}
195
196impl Transform<Array2<Float>> for ConvexFeatureSelector<Trained> {
197 fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
198 validate::check_n_features(x, self.n_features_.unwrap())?;
199
200 let selected_features = self.selected_features_.as_ref().unwrap();
201 let n_samples = x.nrows();
202 let n_selected = selected_features.len();
203 let mut x_new = Array2::zeros((n_samples, n_selected));
204
205 for (new_idx, &old_idx) in selected_features.iter().enumerate() {
206 x_new.column_mut(new_idx).assign(&x.column(old_idx));
207 }
208
209 Ok(x_new)
210 }
211}
212
213impl SelectorMixin for ConvexFeatureSelector<Trained> {
214 fn get_support(&self) -> SklResult<Array1<bool>> {
215 let n_features = self.n_features_.unwrap();
216 let selected_features = self.selected_features_.as_ref().unwrap();
217 let mut support = Array1::from_elem(n_features, false);
218
219 for &idx in selected_features {
220 support[idx] = true;
221 }
222
223 Ok(support)
224 }
225
226 fn transform_features(&self, indices: &[usize]) -> SklResult<Vec<usize>> {
227 let selected_features = self.selected_features_.as_ref().unwrap();
228 Ok(indices
229 .iter()
230 .filter_map(|&idx| selected_features.iter().position(|&f| f == idx))
231 .collect())
232 }
233}
234
/// Exposes the indices chosen by `fit` through the crate's `FeatureSelector` trait.
impl FeatureSelector for ConvexFeatureSelector<Trained> {
    /// Indices of the selected input columns, in the order `transform` emits them.
    fn selected_features(&self) -> &Vec<usize> {
        self.selected_features_.as_ref().unwrap()
    }
}
240
241impl ConvexFeatureSelector<Trained> {
242 pub fn weights(&self) -> &Array1<Float> {
244 self.weights_.as_ref().unwrap()
245 }
246
247 pub fn objective_values(&self) -> &[Float] {
249 self.objective_values_.as_ref().unwrap()
250 }
251
252 pub fn n_features_out(&self) -> usize {
254 self.selected_features_.as_ref().unwrap().len()
255 }
256}
257
/// Feature selector that ranks features by the magnitude of lasso weights
/// obtained with proximal gradient descent (gradient step + soft-thresholding).
///
/// The `State` parameter is a typestate marker: `Untrained` before `fit`,
/// `Trained` afterwards. The trailing-underscore fields are `None` until `fit`
/// fills them in on the returned `Trained` value.
#[derive(Debug, Clone)]
pub struct ProximalGradientSelector<State = Untrained> {
    // Number of features to keep.
    k: usize,
    // Strength of the L1 penalty.
    regularization: Float,
    // Upper bound on solver iterations.
    max_iter: usize,
    // The solver stops once the L1 change between consecutive weight vectors is below this.
    tolerance: Float,
    // Fixed gradient step; the soft-threshold level is `step_size * regularization`.
    step_size: Float,
    // Zero-sized typestate marker.
    state: PhantomData<State>,
    // Per-feature weights learned by `fit` (may be negative).
    weights_: Option<Array1<Float>>,
    // Indices of the kept features, ordered from largest to smallest |weight|.
    selected_features_: Option<Vec<usize>>,
    // Number of input columns seen during `fit`.
    n_features_: Option<usize>,
    // Objective value recorded at each solver iteration.
    objective_values_: Option<Vec<Float>>,
}
273
274impl Default for ProximalGradientSelector<Untrained> {
275 fn default() -> Self {
276 Self::new()
277 }
278}
279
280impl ProximalGradientSelector<Untrained> {
281 pub fn new() -> Self {
283 Self {
284 k: 10,
285 regularization: 1.0,
286 max_iter: 1000,
287 tolerance: 1e-6,
288 step_size: 0.01,
289 state: PhantomData,
290 weights_: None,
291 selected_features_: None,
292 n_features_: None,
293 objective_values_: None,
294 }
295 }
296
297 pub fn k(mut self, k: usize) -> Self {
299 self.k = k;
300 self
301 }
302
303 pub fn regularization(mut self, regularization: Float) -> Self {
305 self.regularization = regularization;
306 self
307 }
308
309 pub fn max_iter(mut self, max_iter: usize) -> Self {
311 self.max_iter = max_iter;
312 self
313 }
314
315 pub fn tolerance(mut self, tolerance: Float) -> Self {
317 self.tolerance = tolerance;
318 self
319 }
320
321 pub fn step_size(mut self, step_size: Float) -> Self {
323 self.step_size = step_size;
324 self
325 }
326
327 fn soft_threshold(&self, x: Float, threshold: Float) -> Float {
329 if x > threshold {
330 x - threshold
331 } else if x < -threshold {
332 x + threshold
333 } else {
334 0.0
335 }
336 }
337
338 fn solve_proximal_gradient(
340 &self,
341 features: &Array2<Float>,
342 target: &Array1<Float>,
343 ) -> SklResult<(Array1<Float>, Vec<Float>)> {
344 let n_features = features.ncols();
345 let n_samples = features.nrows();
346
347 let mut weights = Array1::zeros(n_features);
349 let mut objective_values = Vec::new();
350
351 for _iter in 0..self.max_iter {
353 let predictions = features.dot(&weights);
355
356 let residuals = &predictions - target;
358
359 let gradient = features.t().dot(&residuals) / n_samples as Float;
361
362 let temp_weights = &weights - self.step_size * &gradient;
364
365 let threshold = self.step_size * self.regularization;
367 let new_weights = temp_weights.mapv(|w| self.soft_threshold(w, threshold));
368
369 let data_term = residuals.mapv(|r| r * r).sum() / (2.0 * n_samples as Float);
371 let reg_term = self.regularization * weights.mapv(|w| w.abs()).sum();
372 let objective = data_term + reg_term;
373 objective_values.push(objective);
374
375 let weight_diff = (&new_weights - &weights).mapv(|d| d.abs()).sum();
377 if weight_diff < self.tolerance {
378 break;
379 }
380
381 weights = new_weights;
382 }
383
384 Ok((weights, objective_values))
385 }
386}
387
/// `Estimator` metadata for the proximal-gradient selector. It has no separate
/// configuration object, so `Config` is the unit type.
impl Estimator for ProximalGradientSelector<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Returns a reference to the unit configuration.
    fn config(&self) -> &Self::Config {
        &()
    }
}
397
398impl Fit<Array2<Float>, Array1<Float>> for ProximalGradientSelector<Untrained> {
399 type Fitted = ProximalGradientSelector<Trained>;
400
401 fn fit(self, features: &Array2<Float>, target: &Array1<Float>) -> SklResult<Self::Fitted> {
402 let n_features = features.ncols();
403 if n_features == 0 {
404 return Err(SklearsError::InvalidInput(
405 "No features provided".to_string(),
406 ));
407 }
408
409 if self.k > n_features {
410 return Err(SklearsError::InvalidInput(format!(
411 "k ({}) cannot be greater than number of features ({})",
412 self.k, n_features
413 )));
414 }
415
416 let (weights, objective_values) = self.solve_proximal_gradient(features, target)?;
418
419 let mut feature_indices: Vec<usize> = (0..n_features).collect();
421 feature_indices.sort_by(|&a, &b| {
422 weights[b]
423 .abs()
424 .partial_cmp(&weights[a].abs())
425 .unwrap_or(std::cmp::Ordering::Equal)
426 });
427
428 let selected_features = feature_indices.into_iter().take(self.k).collect();
429
430 Ok(ProximalGradientSelector {
431 k: self.k,
432 regularization: self.regularization,
433 max_iter: self.max_iter,
434 tolerance: self.tolerance,
435 step_size: self.step_size,
436 state: PhantomData,
437 weights_: Some(weights),
438 selected_features_: Some(selected_features),
439 n_features_: Some(n_features),
440 objective_values_: Some(objective_values),
441 })
442 }
443}
444
445impl Transform<Array2<Float>> for ProximalGradientSelector<Trained> {
446 fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
447 validate::check_n_features(x, self.n_features_.unwrap())?;
448
449 let selected_features = self.selected_features_.as_ref().unwrap();
450 let n_samples = x.nrows();
451 let n_selected = selected_features.len();
452 let mut x_new = Array2::zeros((n_samples, n_selected));
453
454 for (new_idx, &old_idx) in selected_features.iter().enumerate() {
455 x_new.column_mut(new_idx).assign(&x.column(old_idx));
456 }
457
458 Ok(x_new)
459 }
460}
461
462impl SelectorMixin for ProximalGradientSelector<Trained> {
463 fn get_support(&self) -> SklResult<Array1<bool>> {
464 let n_features = self.n_features_.unwrap();
465 let selected_features = self.selected_features_.as_ref().unwrap();
466 let mut support = Array1::from_elem(n_features, false);
467
468 for &idx in selected_features {
469 support[idx] = true;
470 }
471
472 Ok(support)
473 }
474
475 fn transform_features(&self, indices: &[usize]) -> SklResult<Vec<usize>> {
476 let selected_features = self.selected_features_.as_ref().unwrap();
477 Ok(indices
478 .iter()
479 .filter_map(|&idx| selected_features.iter().position(|&f| f == idx))
480 .collect())
481 }
482}
483
/// Exposes the indices chosen by `fit` through the crate's `FeatureSelector` trait.
impl FeatureSelector for ProximalGradientSelector<Trained> {
    /// Indices of the selected input columns, in the order `transform` emits them.
    fn selected_features(&self) -> &Vec<usize> {
        self.selected_features_.as_ref().unwrap()
    }
}
489
490impl ProximalGradientSelector<Trained> {
491 pub fn weights(&self) -> &Array1<Float> {
493 self.weights_.as_ref().unwrap()
494 }
495
496 pub fn objective_values(&self) -> &[Float] {
498 self.objective_values_.as_ref().unwrap()
499 }
500
501 pub fn n_features_out(&self) -> usize {
503 self.selected_features_.as_ref().unwrap().len()
504 }
505}
506
/// Feature selector that ranks features by the magnitude of lasso weights
/// computed with ADMM (alternating direction method of multipliers).
///
/// The `State` parameter is a typestate marker: `Untrained` before `fit`,
/// `Trained` afterwards. The trailing-underscore fields are `None` until `fit`
/// fills them in on the returned `Trained` value.
#[derive(Debug, Clone)]
pub struct ADMMFeatureSelector<State = Untrained> {
    // Number of features to keep.
    k: usize,
    // Strength of the L1 penalty.
    regularization: Float,
    // Upper bound on ADMM iterations.
    max_iter: usize,
    // Both the primal and dual residuals must fall below this to stop early.
    tolerance: Float,
    // ADMM augmented-Lagrangian penalty; the L1 threshold is `regularization / rho`.
    rho: Float,
    // Zero-sized typestate marker.
    state: PhantomData<State>,
    // The sparse `z` iterate returned by the solver.
    weights_: Option<Array1<Float>>,
    // Indices of the kept features, ordered from largest to smallest |weight|.
    selected_features_: Option<Vec<usize>>,
    // Number of input columns seen during `fit`.
    n_features_: Option<usize>,
    // Objective value recorded at each ADMM iteration.
    objective_values_: Option<Vec<Float>>,
}
522
523impl Default for ADMMFeatureSelector<Untrained> {
524 fn default() -> Self {
525 Self::new()
526 }
527}
528
529impl ADMMFeatureSelector<Untrained> {
530 pub fn new() -> Self {
532 Self {
533 k: 10,
534 regularization: 1.0,
535 max_iter: 1000,
536 tolerance: 1e-6,
537 rho: 1.0,
538 state: PhantomData,
539 weights_: None,
540 selected_features_: None,
541 n_features_: None,
542 objective_values_: None,
543 }
544 }
545
546 pub fn k(mut self, k: usize) -> Self {
548 self.k = k;
549 self
550 }
551
552 pub fn regularization(mut self, regularization: Float) -> Self {
554 self.regularization = regularization;
555 self
556 }
557
558 pub fn max_iter(mut self, max_iter: usize) -> Self {
560 self.max_iter = max_iter;
561 self
562 }
563
564 pub fn tolerance(mut self, tolerance: Float) -> Self {
566 self.tolerance = tolerance;
567 self
568 }
569
570 pub fn rho(mut self, rho: Float) -> Self {
572 self.rho = rho;
573 self
574 }
575
576 fn soft_threshold(&self, x: Float, threshold: Float) -> Float {
578 if x > threshold {
579 x - threshold
580 } else if x < -threshold {
581 x + threshold
582 } else {
583 0.0
584 }
585 }
586
587 fn solve_admm(
589 &self,
590 features: &Array2<Float>,
591 target: &Array1<Float>,
592 ) -> SklResult<(Array1<Float>, Vec<Float>)> {
593 let n_features = features.ncols();
594 let n_samples = features.nrows();
595
596 let mut x = Array1::<Float>::zeros(n_features); let mut z = Array1::<Float>::zeros(n_features); let mut u = Array1::<Float>::zeros(n_features); let mut objective_values = Vec::new();
602
603 let xtx = features.t().dot(features);
605 let xty = features.t().dot(target);
606
607 for _iter in 0..self.max_iter {
609 let _x_old = x.clone();
610 let z_old = z.clone();
611
612 let rhs = &xty + self.rho * (&z - &u);
615
616 for i in 0..n_features {
618 let diag_elem = xtx[[i, i]] + self.rho;
619 if diag_elem > 1e-12 {
620 let off_diag = (0..n_features)
621 .filter(|&j| j != i)
622 .map(|j| xtx[[i, j]] * x[j])
623 .sum::<Float>();
624 x[i] = (rhs[i] - off_diag) / diag_elem;
625 }
626 }
627
628 let threshold = self.regularization / self.rho;
630 for i in 0..n_features {
631 z[i] = self.soft_threshold(x[i] + u[i], threshold);
632 }
633
634 u = &u + &x - &z;
636
637 let predictions = features.dot(&x);
639 let residuals = &predictions - target;
640 let data_term = residuals.mapv(|r| r * r).sum() / (2.0 * n_samples as Float);
641 let reg_term = self.regularization * z.mapv(|z_i| z_i.abs()).sum();
642 let objective = data_term + reg_term;
643 objective_values.push(objective);
644
645 let primal_residual = (&x - &z).mapv(|r| r.abs()).sum();
647 let dual_residual = self.rho * (&z - &z_old).mapv(|r| r.abs()).sum();
648
649 if primal_residual < self.tolerance && dual_residual < self.tolerance {
650 break;
651 }
652 }
653
654 Ok((z, objective_values))
655 }
656}
657
/// `Estimator` metadata for the ADMM selector. It has no separate configuration
/// object, so `Config` is the unit type.
impl Estimator for ADMMFeatureSelector<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Returns a reference to the unit configuration.
    fn config(&self) -> &Self::Config {
        &()
    }
}
667
668impl Fit<Array2<Float>, Array1<Float>> for ADMMFeatureSelector<Untrained> {
669 type Fitted = ADMMFeatureSelector<Trained>;
670
671 fn fit(self, features: &Array2<Float>, target: &Array1<Float>) -> SklResult<Self::Fitted> {
672 let n_features = features.ncols();
673 if n_features == 0 {
674 return Err(SklearsError::InvalidInput(
675 "No features provided".to_string(),
676 ));
677 }
678
679 if self.k > n_features {
680 return Err(SklearsError::InvalidInput(format!(
681 "k ({}) cannot be greater than number of features ({})",
682 self.k, n_features
683 )));
684 }
685
686 let (weights, objective_values) = self.solve_admm(features, target)?;
688
689 let mut feature_indices: Vec<usize> = (0..n_features).collect();
691 feature_indices.sort_by(|&a, &b| {
692 weights[b]
693 .abs()
694 .partial_cmp(&weights[a].abs())
695 .unwrap_or(std::cmp::Ordering::Equal)
696 });
697
698 let selected_features = feature_indices.into_iter().take(self.k).collect();
699
700 Ok(ADMMFeatureSelector {
701 k: self.k,
702 regularization: self.regularization,
703 max_iter: self.max_iter,
704 tolerance: self.tolerance,
705 rho: self.rho,
706 state: PhantomData,
707 weights_: Some(weights),
708 selected_features_: Some(selected_features),
709 n_features_: Some(n_features),
710 objective_values_: Some(objective_values),
711 })
712 }
713}
714
715impl Transform<Array2<Float>> for ADMMFeatureSelector<Trained> {
716 fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
717 validate::check_n_features(x, self.n_features_.unwrap())?;
718
719 let selected_features = self.selected_features_.as_ref().unwrap();
720 let n_samples = x.nrows();
721 let n_selected = selected_features.len();
722 let mut x_new = Array2::zeros((n_samples, n_selected));
723
724 for (new_idx, &old_idx) in selected_features.iter().enumerate() {
725 x_new.column_mut(new_idx).assign(&x.column(old_idx));
726 }
727
728 Ok(x_new)
729 }
730}
731
732impl SelectorMixin for ADMMFeatureSelector<Trained> {
733 fn get_support(&self) -> SklResult<Array1<bool>> {
734 let n_features = self.n_features_.unwrap();
735 let selected_features = self.selected_features_.as_ref().unwrap();
736 let mut support = Array1::from_elem(n_features, false);
737
738 for &idx in selected_features {
739 support[idx] = true;
740 }
741
742 Ok(support)
743 }
744
745 fn transform_features(&self, indices: &[usize]) -> SklResult<Vec<usize>> {
746 let selected_features = self.selected_features_.as_ref().unwrap();
747 Ok(indices
748 .iter()
749 .filter_map(|&idx| selected_features.iter().position(|&f| f == idx))
750 .collect())
751 }
752}
753
/// Exposes the indices chosen by `fit` through the crate's `FeatureSelector` trait.
impl FeatureSelector for ADMMFeatureSelector<Trained> {
    /// Indices of the selected input columns, in the order `transform` emits them.
    fn selected_features(&self) -> &Vec<usize> {
        self.selected_features_.as_ref().unwrap()
    }
}
759
760impl ADMMFeatureSelector<Trained> {
761 pub fn weights(&self) -> &Array1<Float> {
763 self.weights_.as_ref().unwrap()
764 }
765
766 pub fn objective_values(&self) -> &[Float] {
768 self.objective_values_.as_ref().unwrap()
769 }
770
771 pub fn n_features_out(&self) -> usize {
773 self.selected_features_.as_ref().unwrap().len()
774 }
775}
776
/// Feature selector based on a relaxed matrix program. It iterates a symmetric
/// `n_features x n_features` matrix and ranks features by the per-feature scores
/// extracted from the final matrix.
///
/// The `State` parameter is a typestate marker: `Untrained` before `fit`,
/// `Trained` afterwards. The trailing-underscore fields are `None` until `fit`
/// fills them in on the returned `Trained` value.
#[derive(Debug, Clone)]
pub struct SemidefiniteFeatureSelector<State = Untrained> {
    // Number of features to keep.
    k: usize,
    // Upper bound on solver iterations.
    max_iter: usize,
    // The solver stops once the elementwise L1 change in the matrix is below this.
    tolerance: Float,
    // Weight of the covariance (redundancy) term in the relaxed objective.
    regularization: Float,
    // Zero-sized typestate marker.
    state: PhantomData<State>,
    // Final relaxed matrix produced by the solver.
    feature_matrix_: Option<Array2<Float>>,
    // Indices of the kept features, ordered from highest to lowest score.
    selected_features_: Option<Vec<usize>>,
    // Number of input columns seen during `fit`.
    n_features_: Option<usize>,
    // Per-feature scores from `extract_diagonal` on the final matrix (presumably its
    // diagonal rather than true eigenvalues — TODO confirm against that helper).
    eigenvalues_: Option<Array1<Float>>,
    // Objective value recorded at each solver iteration.
    objective_values_: Option<Vec<Float>>,
}
792
793impl Default for SemidefiniteFeatureSelector<Untrained> {
794 fn default() -> Self {
795 Self::new()
796 }
797}
798
799impl SemidefiniteFeatureSelector<Untrained> {
800 pub fn new() -> Self {
802 Self {
803 k: 10,
804 max_iter: 100,
805 tolerance: 1e-6,
806 regularization: 1.0,
807 state: PhantomData,
808 feature_matrix_: None,
809 selected_features_: None,
810 n_features_: None,
811 eigenvalues_: None,
812 objective_values_: None,
813 }
814 }
815
816 pub fn k(mut self, k: usize) -> Self {
818 self.k = k;
819 self
820 }
821
822 pub fn max_iter(mut self, max_iter: usize) -> Self {
824 self.max_iter = max_iter;
825 self
826 }
827
828 pub fn tolerance(mut self, tolerance: Float) -> Self {
830 self.tolerance = tolerance;
831 self
832 }
833
834 pub fn regularization(mut self, regularization: Float) -> Self {
836 self.regularization = regularization;
837 self
838 }
839
840 fn project_psd(&self, matrix: &Array2<Float>) -> SklResult<Array2<Float>> {
842 let n = matrix.nrows();
843
844 let mut projected = Array2::zeros((n, n));
846 for i in 0..n {
847 for j in 0..n {
848 projected[[i, j]] = (matrix[[i, j]] + matrix[[j, i]]) / 2.0;
849 }
850 }
851
852 for i in 0..n {
854 if projected[[i, i]] < 0.0 {
855 projected[[i, i]] = 0.0;
856 }
857 }
858
859 Ok(projected)
860 }
861
862 fn solve_sdp_relaxation(
864 &self,
865 features: &Array2<Float>,
866 target: &Array1<Float>,
867 ) -> SklResult<(Array2<Float>, Array1<Float>, Vec<Float>)> {
868 let n_features = features.ncols();
869
870 let centered_features = features - &features.mean_axis(Axis(0)).unwrap();
872 let cov_matrix =
873 centered_features.t().dot(¢ered_features) / (features.nrows() - 1) as Float;
874
875 let target_centered = target - target.mean().unwrap();
877 let correlations =
878 centered_features.t().dot(&target_centered) / (features.nrows() - 1) as Float;
879
880 let mut x_matrix = Array2::eye(n_features) * 0.5; let mut objective_values = Vec::new();
883
884 for _iter in 0..self.max_iter {
886 let _x_old = x_matrix.clone();
887
888 let outer_corr = outer_product(&correlations, &correlations);
891 let grad = &outer_corr - self.regularization * &cov_matrix;
892
893 let step_size = 0.01;
895 let x_new = &x_matrix + step_size * &grad;
896
897 let mut x_projected = self.project_psd(&x_new)?;
899
900 for i in 0..n_features {
902 x_projected[[i, i]] = x_projected[[i, i]].clamp(0.0, 1.0);
903 }
904
905 let obj = correlations.dot(&x_projected.dot(&correlations))
907 - self.regularization * trace(&x_projected.dot(&cov_matrix));
908 objective_values.push(obj);
909
910 let diff = (&x_projected - &x_matrix).mapv(|x| x.abs()).sum();
912 if diff < self.tolerance {
913 break;
914 }
915
916 x_matrix = x_projected;
917 }
918
919 let eigenvalues = extract_diagonal(&x_matrix);
921
922 Ok((x_matrix, eigenvalues, objective_values))
923 }
924}
925
/// `Estimator` metadata for the semidefinite-relaxation selector. It has no
/// separate configuration object, so `Config` is the unit type.
impl Estimator for SemidefiniteFeatureSelector<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Returns a reference to the unit configuration.
    fn config(&self) -> &Self::Config {
        &()
    }
}
935
936impl Fit<Array2<Float>, Array1<Float>> for SemidefiniteFeatureSelector<Untrained> {
937 type Fitted = SemidefiniteFeatureSelector<Trained>;
938
939 fn fit(self, features: &Array2<Float>, target: &Array1<Float>) -> SklResult<Self::Fitted> {
940 let n_features = features.ncols();
941 if n_features == 0 {
942 return Err(SklearsError::InvalidInput(
943 "No features provided".to_string(),
944 ));
945 }
946
947 if self.k > n_features {
948 return Err(SklearsError::InvalidInput(format!(
949 "k ({}) cannot be greater than number of features ({})",
950 self.k, n_features
951 )));
952 }
953
954 let (feature_matrix, eigenvalues, objective_values) =
956 self.solve_sdp_relaxation(features, target)?;
957
958 let mut feature_indices: Vec<usize> = (0..n_features).collect();
960 feature_indices.sort_by(|&a, &b| {
961 eigenvalues[b]
962 .partial_cmp(&eigenvalues[a])
963 .unwrap_or(std::cmp::Ordering::Equal)
964 });
965
966 let selected_features = feature_indices.into_iter().take(self.k).collect();
967
968 Ok(SemidefiniteFeatureSelector {
969 k: self.k,
970 max_iter: self.max_iter,
971 tolerance: self.tolerance,
972 regularization: self.regularization,
973 state: PhantomData,
974 feature_matrix_: Some(feature_matrix),
975 selected_features_: Some(selected_features),
976 n_features_: Some(n_features),
977 eigenvalues_: Some(eigenvalues),
978 objective_values_: Some(objective_values),
979 })
980 }
981}
982
983impl Transform<Array2<Float>> for SemidefiniteFeatureSelector<Trained> {
984 fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
985 validate::check_n_features(x, self.n_features_.unwrap())?;
986
987 let selected_features = self.selected_features_.as_ref().unwrap();
988 let n_samples = x.nrows();
989 let n_selected = selected_features.len();
990 let mut x_new = Array2::zeros((n_samples, n_selected));
991
992 for (new_idx, &old_idx) in selected_features.iter().enumerate() {
993 x_new.column_mut(new_idx).assign(&x.column(old_idx));
994 }
995
996 Ok(x_new)
997 }
998}
999
1000impl SelectorMixin for SemidefiniteFeatureSelector<Trained> {
1001 fn get_support(&self) -> SklResult<Array1<bool>> {
1002 let n_features = self.n_features_.unwrap();
1003 let selected_features = self.selected_features_.as_ref().unwrap();
1004 let mut support = Array1::from_elem(n_features, false);
1005
1006 for &idx in selected_features {
1007 support[idx] = true;
1008 }
1009
1010 Ok(support)
1011 }
1012
1013 fn transform_features(&self, indices: &[usize]) -> SklResult<Vec<usize>> {
1014 let selected_features = self.selected_features_.as_ref().unwrap();
1015 Ok(indices
1016 .iter()
1017 .filter_map(|&idx| selected_features.iter().position(|&f| f == idx))
1018 .collect())
1019 }
1020}
1021
/// Exposes the indices chosen by `fit` through the crate's `FeatureSelector` trait.
impl FeatureSelector for SemidefiniteFeatureSelector<Trained> {
    /// Indices of the selected input columns, in the order `transform` emits them.
    fn selected_features(&self) -> &Vec<usize> {
        self.selected_features_.as_ref().unwrap()
    }
}
1027
1028impl SemidefiniteFeatureSelector<Trained> {
1029 pub fn feature_matrix(&self) -> &Array2<Float> {
1031 self.feature_matrix_.as_ref().unwrap()
1032 }
1033
1034 pub fn eigenvalues(&self) -> &Array1<Float> {
1036 self.eigenvalues_.as_ref().unwrap()
1037 }
1038
1039 pub fn objective_values(&self) -> &[Float] {
1041 self.objective_values_.as_ref().unwrap()
1042 }
1043
1044 pub fn n_features_out(&self) -> usize {
1046 self.selected_features_.as_ref().unwrap().len()
1047 }
1048}
1049
/// Feature selector that treats subset selection as a binary optimization problem.
/// It maximizes the sum of per-feature scores under a soft cardinality-`k`
/// penalty, using greedy initialization followed by local search.
///
/// The `State` parameter is a typestate marker: `Untrained` before `fit`,
/// `Trained` afterwards. The trailing-underscore fields are `None` until `fit`
/// fills them in on the returned `Trained` value.
#[derive(Debug, Clone)]
pub struct IntegerProgrammingFeatureSelector<State = Untrained> {
    // Target number of selected features.
    k: usize,
    // Upper bound on local-search passes.
    max_iter: usize,
    // Minimum objective gain for a move to count as an improvement.
    tolerance: Float,
    // Start from the top-`k` scores (true) or from the first `k` indices (false).
    greedy_init: bool,
    // Enables the pairwise-swap phase of local search.
    local_search: bool,
    // Zero-sized typestate marker.
    state: PhantomData<State>,
    // Boolean inclusion mask found by the search.
    binary_solution_: Option<Array1<bool>>,
    // Indices of the `true` entries of the mask, in ascending order.
    selected_features_: Option<Vec<usize>>,
    // Number of input columns seen during `fit`.
    n_features_: Option<usize>,
    // Objective value of the final mask.
    objective_value_: Option<Float>,
    // Best objective after initialization and after each local-search pass.
    improvement_history_: Option<Vec<Float>>,
}
1066
1067impl Default for IntegerProgrammingFeatureSelector<Untrained> {
1068 fn default() -> Self {
1069 Self::new()
1070 }
1071}
1072
1073impl IntegerProgrammingFeatureSelector<Untrained> {
1074 pub fn new() -> Self {
1076 Self {
1077 k: 10,
1078 max_iter: 1000,
1079 tolerance: 1e-6,
1080 greedy_init: true,
1081 local_search: true,
1082 state: PhantomData,
1083 binary_solution_: None,
1084 selected_features_: None,
1085 n_features_: None,
1086 objective_value_: None,
1087 improvement_history_: None,
1088 }
1089 }
1090
1091 pub fn k(mut self, k: usize) -> Self {
1093 self.k = k;
1094 self
1095 }
1096
1097 pub fn max_iter(mut self, max_iter: usize) -> Self {
1099 self.max_iter = max_iter;
1100 self
1101 }
1102
1103 pub fn tolerance(mut self, tolerance: Float) -> Self {
1105 self.tolerance = tolerance;
1106 self
1107 }
1108
1109 pub fn greedy_init(mut self, greedy_init: bool) -> Self {
1111 self.greedy_init = greedy_init;
1112 self
1113 }
1114
1115 pub fn local_search(mut self, local_search: bool) -> Self {
1117 self.local_search = local_search;
1118 self
1119 }
1120
1121 fn compute_feature_scores(
1123 &self,
1124 features: &Array2<Float>,
1125 target: &Array1<Float>,
1126 ) -> SklResult<Array1<Float>> {
1127 let n_features = features.ncols();
1128 let mut scores = Array1::zeros(n_features);
1129
1130 for i in 0..n_features {
1132 let feature_col = features.column(i);
1133 let correlation = correlation_coefficient(&feature_col.to_owned(), target)?;
1134 scores[i] = correlation.abs();
1135 }
1136
1137 Ok(scores)
1138 }
1139
1140 fn greedy_initialization(&self, scores: &Array1<Float>) -> Array1<bool> {
1142 let n_features = scores.len();
1143 let mut solution = Array1::from_elem(n_features, false);
1144
1145 let mut indices: Vec<usize> = (0..n_features).collect();
1147 indices.sort_by(|&a, &b| {
1148 scores[b]
1149 .partial_cmp(&scores[a])
1150 .unwrap_or(std::cmp::Ordering::Equal)
1151 });
1152
1153 for &idx in indices.iter().take(self.k) {
1154 solution[idx] = true;
1155 }
1156
1157 solution
1158 }
1159
1160 fn evaluate_objective(&self, solution: &Array1<bool>, scores: &Array1<Float>) -> Float {
1162 let mut objective = 0.0;
1163 let mut selected_count = 0;
1164
1165 for i in 0..solution.len() {
1166 if solution[i] {
1167 objective += scores[i];
1168 selected_count += 1;
1169 }
1170 }
1171
1172 if selected_count != self.k {
1174 objective -= 1000.0 * (selected_count as Float - self.k as Float).abs();
1175 }
1176
1177 objective
1178 }
1179
1180 fn local_search_improvement(
1182 &self,
1183 solution: &mut Array1<bool>,
1184 scores: &Array1<Float>,
1185 best_obj: &mut Float,
1186 ) -> bool {
1187 let n_features = solution.len();
1188 let mut improved = false;
1189
1190 for i in 0..n_features {
1192 let original = solution[i];
1193 solution[i] = !solution[i];
1194
1195 let new_obj = self.evaluate_objective(solution, scores);
1196 if new_obj > *best_obj + self.tolerance {
1197 *best_obj = new_obj;
1198 improved = true;
1199 } else {
1200 solution[i] = original; }
1202 }
1203
1204 if self.local_search {
1206 for i in 0..n_features {
1207 for j in (i + 1)..n_features {
1208 if solution[i] == solution[j] {
1209 continue; }
1211
1212 let temp = solution[i];
1214 solution[i] = solution[j];
1215 solution[j] = temp;
1216
1217 let new_obj = self.evaluate_objective(solution, scores);
1218 if new_obj > *best_obj + self.tolerance {
1219 *best_obj = new_obj;
1220 improved = true;
1221 } else {
1222 let temp = solution[i];
1224 solution[i] = solution[j];
1225 solution[j] = temp;
1226 }
1227 }
1228 }
1229 }
1230
1231 improved
1232 }
1233
1234 fn solve_integer_programming(
1236 &self,
1237 features: &Array2<Float>,
1238 target: &Array1<Float>,
1239 ) -> SklResult<(Array1<bool>, Float, Vec<Float>)> {
1240 let scores = self.compute_feature_scores(features, target)?;
1241 let mut improvement_history = Vec::new();
1242
1243 let mut solution = if self.greedy_init {
1245 self.greedy_initialization(&scores)
1246 } else {
1247 let mut random_solution = Array1::from_elem(scores.len(), false);
1249 let indices: Vec<usize> = (0..scores.len()).collect();
1250 for &idx in indices.iter().take(self.k) {
1251 random_solution[idx] = true;
1252 }
1253 random_solution
1254 };
1255
1256 let mut best_objective = self.evaluate_objective(&solution, &scores);
1257 improvement_history.push(best_objective);
1258
1259 for _iter in 0..self.max_iter {
1261 let prev_objective = best_objective;
1262
1263 let improved =
1264 self.local_search_improvement(&mut solution, &scores, &mut best_objective);
1265 improvement_history.push(best_objective);
1266
1267 if !improved || (best_objective - prev_objective).abs() < self.tolerance {
1268 break;
1269 }
1270 }
1271
1272 Ok((solution, best_objective, improvement_history))
1273 }
1274}
1275
/// `Estimator` metadata for the integer-programming selector. It has no separate
/// configuration object, so `Config` is the unit type.
impl Estimator for IntegerProgrammingFeatureSelector<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Returns a reference to the unit configuration.
    fn config(&self) -> &Self::Config {
        &()
    }
}
1285
1286impl Fit<Array2<Float>, Array1<Float>> for IntegerProgrammingFeatureSelector<Untrained> {
1287 type Fitted = IntegerProgrammingFeatureSelector<Trained>;
1288
1289 fn fit(self, features: &Array2<Float>, target: &Array1<Float>) -> SklResult<Self::Fitted> {
1290 let n_features = features.ncols();
1291 if n_features == 0 {
1292 return Err(SklearsError::InvalidInput(
1293 "No features provided".to_string(),
1294 ));
1295 }
1296
1297 if self.k > n_features {
1298 return Err(SklearsError::InvalidInput(format!(
1299 "k ({}) cannot be greater than number of features ({})",
1300 self.k, n_features
1301 )));
1302 }
1303
1304 let (binary_solution, objective_value, improvement_history) =
1306 self.solve_integer_programming(features, target)?;
1307
1308 let selected_features: Vec<usize> = binary_solution
1310 .iter()
1311 .enumerate()
1312 .filter_map(|(i, &selected)| if selected { Some(i) } else { None })
1313 .collect();
1314
1315 Ok(IntegerProgrammingFeatureSelector {
1316 k: self.k,
1317 max_iter: self.max_iter,
1318 tolerance: self.tolerance,
1319 greedy_init: self.greedy_init,
1320 local_search: self.local_search,
1321 state: PhantomData,
1322 binary_solution_: Some(binary_solution),
1323 selected_features_: Some(selected_features),
1324 n_features_: Some(n_features),
1325 objective_value_: Some(objective_value),
1326 improvement_history_: Some(improvement_history),
1327 })
1328 }
1329}
1330
1331impl Transform<Array2<Float>> for IntegerProgrammingFeatureSelector<Trained> {
1332 fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
1333 validate::check_n_features(x, self.n_features_.unwrap())?;
1334
1335 let selected_features = self.selected_features_.as_ref().unwrap();
1336 let n_samples = x.nrows();
1337 let n_selected = selected_features.len();
1338 let mut x_new = Array2::zeros((n_samples, n_selected));
1339
1340 for (new_idx, &old_idx) in selected_features.iter().enumerate() {
1341 x_new.column_mut(new_idx).assign(&x.column(old_idx));
1342 }
1343
1344 Ok(x_new)
1345 }
1346}
1347
1348impl SelectorMixin for IntegerProgrammingFeatureSelector<Trained> {
1349 fn get_support(&self) -> SklResult<Array1<bool>> {
1350 Ok(self.binary_solution_.as_ref().unwrap().clone())
1351 }
1352
1353 fn transform_features(&self, indices: &[usize]) -> SklResult<Vec<usize>> {
1354 let selected_features = self.selected_features_.as_ref().unwrap();
1355 Ok(indices
1356 .iter()
1357 .filter_map(|&idx| selected_features.iter().position(|&f| f == idx))
1358 .collect())
1359 }
1360}
1361
1362impl FeatureSelector for IntegerProgrammingFeatureSelector<Trained> {
1363 fn selected_features(&self) -> &Vec<usize> {
1364 self.selected_features_.as_ref().unwrap()
1365 }
1366}
1367
1368impl IntegerProgrammingFeatureSelector<Trained> {
1369 pub fn binary_solution(&self) -> &Array1<bool> {
1371 self.binary_solution_.as_ref().unwrap()
1372 }
1373
1374 pub fn objective_value(&self) -> Float {
1376 self.objective_value_.unwrap()
1377 }
1378
1379 pub fn improvement_history(&self) -> &[Float] {
1381 self.improvement_history_.as_ref().unwrap()
1382 }
1383
1384 pub fn n_features_out(&self) -> usize {
1386 self.selected_features_.as_ref().unwrap().len()
1387 }
1388}
1389
1390fn outer_product(a: &Array1<Float>, b: &Array1<Float>) -> Array2<Float> {
1392 let mut result = Array2::zeros((a.len(), b.len()));
1393 for i in 0..a.len() {
1394 for j in 0..b.len() {
1395 result[[i, j]] = a[i] * b[j];
1396 }
1397 }
1398 result
1399}
1400
1401fn trace(matrix: &Array2<Float>) -> Float {
1402 let n = matrix.nrows().min(matrix.ncols());
1403 (0..n).map(|i| matrix[[i, i]]).sum()
1404}
1405
1406fn extract_diagonal(matrix: &Array2<Float>) -> Array1<Float> {
1407 let n = matrix.nrows().min(matrix.ncols());
1408 let mut diag = Array1::zeros(n);
1409 for i in 0..n {
1410 diag[i] = matrix[[i, i]];
1411 }
1412 diag
1413}
1414
1415fn correlation_coefficient(x: &Array1<Float>, y: &Array1<Float>) -> SklResult<Float> {
1416 if x.len() != y.len() {
1417 return Err(SklearsError::InvalidInput(
1418 "Arrays must have the same length".to_string(),
1419 ));
1420 }
1421
1422 let _n = x.len() as Float;
1423 let mean_x = x.mean().unwrap();
1424 let mean_y = y.mean().unwrap();
1425
1426 let mut num = 0.0;
1427 let mut den_x = 0.0;
1428 let mut den_y = 0.0;
1429
1430 for i in 0..x.len() {
1431 let diff_x = x[i] - mean_x;
1432 let diff_y = y[i] - mean_y;
1433 num += diff_x * diff_y;
1434 den_x += diff_x * diff_x;
1435 den_y += diff_y * diff_y;
1436 }
1437
1438 if den_x.abs() < 1e-10 || den_y.abs() < 1e-10 {
1439 return Ok(0.0);
1440 }
1441
1442 Ok(num / (den_x * den_y).sqrt())
1443}
1444
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Deterministic fixture: 50 samples × 10 smooth, column-shifted features,
    /// with a target that is a linear combination of the first three columns.
    fn create_test_data() -> (Array2<Float>, Array1<Float>) {
        let n_samples = 50;
        let n_features = 10;
        let mut features = Array2::zeros((n_samples, n_features));
        let mut target = Array1::zeros(n_samples);

        for i in 0..n_samples {
            for j in 0..n_features {
                features[[i, j]] = (i as Float * 0.1 + j as Float * 0.01).sin() + 0.1 * j as Float;
            }
            target[i] = features[[i, 0]] + 0.5 * features[[i, 1]] + 0.1 * features[[i, 2]];
        }

        (features, target)
    }

    #[test]
    fn test_convex_feature_selector() {
        let (features, target) = create_test_data();

        let selector = ConvexFeatureSelector::new()
            .k(5)
            .regularization(0.1)
            .max_iter(100);

        let trained = selector.fit(&features, &target).unwrap();
        assert_eq!(trained.n_features_out(), 5);

        let transformed = trained.transform(&features).unwrap();
        assert_eq!(transformed.ncols(), 5);
        assert_eq!(transformed.nrows(), features.nrows());

        let weights = trained.weights();
        assert_eq!(weights.len(), features.ncols());
        assert!(weights.iter().all(|&x| x.is_finite()));

        let obj_vals = trained.objective_values();
        assert!(!obj_vals.is_empty());
        assert!(obj_vals.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_proximal_gradient_selector() {
        let (features, target) = create_test_data();

        let selector = ProximalGradientSelector::new()
            .k(4)
            .regularization(0.1)
            .step_size(0.01)
            .max_iter(100);

        let trained = selector.fit(&features, &target).unwrap();
        assert_eq!(trained.n_features_out(), 4);

        let transformed = trained.transform(&features).unwrap();
        assert_eq!(transformed.ncols(), 4);
        assert_eq!(transformed.nrows(), features.nrows());

        let weights = trained.weights();
        assert_eq!(weights.len(), features.ncols());
        assert!(weights.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_admm_feature_selector() {
        let (features, target) = create_test_data();

        let selector = ADMMFeatureSelector::new()
            .k(3)
            .regularization(0.1)
            .rho(1.0)
            .max_iter(50);

        let trained = selector.fit(&features, &target).unwrap();
        assert_eq!(trained.n_features_out(), 3);

        let transformed = trained.transform(&features).unwrap();
        assert_eq!(transformed.ncols(), 3);
        assert_eq!(transformed.nrows(), features.nrows());

        let weights = trained.weights();
        assert_eq!(weights.len(), features.ncols());
        assert!(weights.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_convex_selector_invalid_k() {
        let (features, target) = create_test_data();

        let selector = ConvexFeatureSelector::new().k(features.ncols() + 1);
        assert!(selector.fit(&features, &target).is_err());
    }

    #[test]
    fn test_proximal_selector_invalid_k() {
        let (features, target) = create_test_data();

        let selector = ProximalGradientSelector::new().k(features.ncols() + 1);
        assert!(selector.fit(&features, &target).is_err());
    }

    #[test]
    fn test_admm_selector_invalid_k() {
        let (features, target) = create_test_data();

        let selector = ADMMFeatureSelector::new().k(features.ncols() + 1);
        assert!(selector.fit(&features, &target).is_err());
    }

    #[test]
    fn test_semidefinite_feature_selector() {
        let (features, target) = create_test_data();

        let selector = SemidefiniteFeatureSelector::new()
            .k(4)
            .regularization(0.1)
            .max_iter(50);

        let trained = selector.fit(&features, &target).unwrap();
        assert_eq!(trained.n_features_out(), 4);

        let transformed = trained.transform(&features).unwrap();
        assert_eq!(transformed.ncols(), 4);
        assert_eq!(transformed.nrows(), features.nrows());

        let feature_matrix = trained.feature_matrix();
        assert_eq!(feature_matrix.nrows(), features.ncols());
        assert_eq!(feature_matrix.ncols(), features.ncols());

        let eigenvalues = trained.eigenvalues();
        assert_eq!(eigenvalues.len(), features.ncols());
        assert!(eigenvalues.iter().all(|&x| x.is_finite()));

        let obj_vals = trained.objective_values();
        assert!(!obj_vals.is_empty());
        assert!(obj_vals.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_integer_programming_feature_selector() {
        let (features, target) = create_test_data();

        let selector = IntegerProgrammingFeatureSelector::new()
            .k(3)
            .greedy_init(true)
            .local_search(true)
            .max_iter(100);

        let trained = selector.fit(&features, &target).unwrap();
        assert_eq!(trained.n_features_out(), 3);

        let transformed = trained.transform(&features).unwrap();
        assert_eq!(transformed.ncols(), 3);
        assert_eq!(transformed.nrows(), features.nrows());

        let binary_solution = trained.binary_solution();
        assert_eq!(binary_solution.len(), features.ncols());
        let selected_count = binary_solution.iter().filter(|&&x| x).count();
        assert_eq!(selected_count, 3);

        let obj_value = trained.objective_value();
        assert!(obj_value.is_finite());

        let improvement_history = trained.improvement_history();
        assert!(!improvement_history.is_empty());
        assert!(improvement_history.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_integer_programming_support_matches_selection() {
        let (features, target) = create_test_data();

        let trained = IntegerProgrammingFeatureSelector::new()
            .k(3)
            .greedy_init(true)
            .local_search(true)
            .max_iter(100)
            .fit(&features, &target)
            .unwrap();

        // The support mask and the index list must describe the same selection.
        let support = trained.get_support().unwrap();
        let from_mask: Vec<usize> = support
            .iter()
            .enumerate()
            .filter_map(|(i, &s)| s.then_some(i))
            .collect();
        assert_eq!(&from_mask, trained.selected_features());

        // Selected indices map to consecutive output positions.
        let mapped = trained
            .transform_features(trained.selected_features())
            .unwrap();
        assert_eq!(mapped, (0..3).collect::<Vec<usize>>());
    }

    #[test]
    fn test_semidefinite_selector_invalid_k() {
        let (features, target) = create_test_data();

        let selector = SemidefiniteFeatureSelector::new().k(features.ncols() + 1);
        assert!(selector.fit(&features, &target).is_err());
    }

    #[test]
    fn test_integer_programming_selector_invalid_k() {
        let (features, target) = create_test_data();

        let selector = IntegerProgrammingFeatureSelector::new().k(features.ncols() + 1);
        assert!(selector.fit(&features, &target).is_err());
    }

    #[test]
    fn test_correlation_coefficient() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let y = Array1::from_vec(vec![2.0, 4.0, 6.0, 8.0, 10.0]);

        // Perfect positive linear relationship.
        let corr = correlation_coefficient(&x, &y).unwrap();
        assert!((corr - 1.0).abs() < 1e-10);

        // Perfect negative linear relationship.
        let z = Array1::from_vec(vec![5.0, 4.0, 3.0, 2.0, 1.0]);
        let corr2 = correlation_coefficient(&x, &z).unwrap();
        assert!((corr2 + 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_correlation_coefficient_degenerate_inputs() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        // A constant series has zero variance; the helper reports 0.0.
        let constant = Array1::from_vec(vec![7.0, 7.0, 7.0]);
        assert_eq!(correlation_coefficient(&x, &constant).unwrap(), 0.0);

        // Mismatched lengths are rejected.
        let short = Array1::from_vec(vec![1.0, 2.0]);
        assert!(correlation_coefficient(&x, &short).is_err());
    }

    #[test]
    fn test_helper_functions() {
        let a = Array1::from_vec(vec![1.0, 2.0]);
        let b = Array1::from_vec(vec![3.0, 4.0]);

        let outer = outer_product(&a, &b);
        assert_eq!(outer[[0, 0]], 3.0);
        assert_eq!(outer[[0, 1]], 4.0);
        assert_eq!(outer[[1, 0]], 6.0);
        assert_eq!(outer[[1, 1]], 8.0);

        let matrix = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let tr = trace(&matrix);
        assert_eq!(tr, 5.0);

        let diag = extract_diagonal(&matrix);
        assert_eq!(diag[0], 1.0);
        assert_eq!(diag[1], 4.0);
    }

    #[test]
    fn test_diagonal_helpers_non_square() {
        // 2 × 3 matrix [[1, 2, 3], [4, 5, 6]]: only min(2, 3) = 2 diagonal entries.
        let matrix = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert_eq!(trace(&matrix), 6.0);

        let diag = extract_diagonal(&matrix);
        assert_eq!(diag.len(), 2);
        assert_eq!(diag[0], 1.0);
        assert_eq!(diag[1], 5.0);
    }
}