sklears_feature_selection/domain_specific/
image_features.rs1use crate::base::SelectorMixin;
7use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, Axis};
8use sklears_core::{
9 error::{validate, Result as SklResult, SklearsError},
10 traits::{Estimator, Fit, Trained, Transform, Untrained},
11 types::Float,
12};
13use std::marker::PhantomData;
14
15#[derive(Debug, Clone)]
52pub struct ImageFeatureSelector<State = Untrained> {
53 include_spatial: bool,
55 include_frequency: bool,
57 include_texture: bool,
59 spatial_threshold: f64,
61 k: Option<usize>,
63 state: PhantomData<State>,
64 spatial_scores_: Option<Array1<Float>>,
66 frequency_scores_: Option<Array1<Float>>,
67 texture_scores_: Option<Array1<Float>>,
68 selected_features_: Option<Vec<usize>>,
69}
70
71impl ImageFeatureSelector<Untrained> {
72 pub fn new() -> Self {
81 Self {
82 include_spatial: true,
83 include_frequency: true,
84 include_texture: true,
85 spatial_threshold: 0.1,
86 k: None,
87 state: PhantomData,
88 spatial_scores_: None,
89 frequency_scores_: None,
90 texture_scores_: None,
91 selected_features_: None,
92 }
93 }
94
95 pub fn include_spatial(mut self, include_spatial: bool) -> Self {
100 self.include_spatial = include_spatial;
101 self
102 }
103
104 pub fn include_frequency(mut self, include_frequency: bool) -> Self {
109 self.include_frequency = include_frequency;
110 self
111 }
112
113 pub fn include_texture(mut self, include_texture: bool) -> Self {
118 self.include_texture = include_texture;
119 self
120 }
121
122 pub fn spatial_threshold(mut self, threshold: f64) -> Self {
127 self.spatial_threshold = threshold;
128 self
129 }
130
131 pub fn k(mut self, k: Option<usize>) -> Self {
136 self.k = k;
137 self
138 }
139}
140
141impl Default for ImageFeatureSelector<Untrained> {
142 fn default() -> Self {
143 Self::new()
144 }
145}
146
147impl Estimator for ImageFeatureSelector<Untrained> {
148 type Config = ();
149 type Error = SklearsError;
150 type Float = f64;
151
152 fn config(&self) -> &Self::Config {
153 &()
154 }
155}
156
157impl Fit<Array2<Float>, Array1<Float>> for ImageFeatureSelector<Untrained> {
158 type Fitted = ImageFeatureSelector<Trained>;
159
160 fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> SklResult<Self::Fitted> {
161 validate::check_consistent_length(x, y)?;
162
163 let (_, n_features) = x.dim();
164
165 let spatial_scores = if self.include_spatial {
167 Some(compute_spatial_correlation_scores(x, y))
168 } else {
169 None
170 };
171
172 let frequency_scores = if self.include_frequency {
174 Some(compute_frequency_domain_scores(x, y))
175 } else {
176 None
177 };
178
179 let texture_scores = if self.include_texture {
181 Some(compute_texture_scores(x, y))
182 } else {
183 None
184 };
185
186 let mut combined_scores = Array1::zeros(n_features);
188 let mut weight_sum = 0.0;
189
190 if let Some(ref spatial) = spatial_scores {
191 for i in 0..n_features {
192 combined_scores[i] += 0.4 * spatial[i];
193 }
194 weight_sum += 0.4;
195 }
196
197 if let Some(ref frequency) = frequency_scores {
198 for i in 0..n_features {
199 combined_scores[i] += 0.3 * frequency[i];
200 }
201 weight_sum += 0.3;
202 }
203
204 if let Some(ref texture) = texture_scores {
205 for i in 0..n_features {
206 combined_scores[i] += 0.3 * texture[i];
207 }
208 weight_sum += 0.3;
209 }
210
211 if weight_sum > 0.0 {
212 combined_scores /= weight_sum;
213 }
214
215 let selected_features = self.select_features_from_combined_scores(&combined_scores);
217
218 Ok(ImageFeatureSelector {
219 include_spatial: self.include_spatial,
220 include_frequency: self.include_frequency,
221 include_texture: self.include_texture,
222 spatial_threshold: self.spatial_threshold,
223 k: self.k,
224 state: PhantomData,
225 spatial_scores_: spatial_scores,
226 frequency_scores_: frequency_scores,
227 texture_scores_: texture_scores,
228 selected_features_: Some(selected_features),
229 })
230 }
231}
232
233impl ImageFeatureSelector<Untrained> {
234 fn select_features_from_combined_scores(&self, scores: &Array1<Float>) -> Vec<usize> {
235 let mut feature_indices: Vec<(usize, Float)> = scores
236 .indexed_iter()
237 .map(|(i, &score)| (i, score))
238 .collect();
239
240 feature_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
241
242 let selected: Vec<usize> = if let Some(k) = self.k {
243 feature_indices
244 .iter()
245 .take(k.min(feature_indices.len()))
246 .map(|(i, _)| *i)
247 .collect()
248 } else {
249 feature_indices
250 .iter()
251 .filter(|(_, score)| *score >= self.spatial_threshold)
252 .map(|(i, _)| *i)
253 .collect()
254 };
255
256 let mut selected_sorted = selected;
257 selected_sorted.sort();
258
259 if selected_sorted.is_empty() {
260 if let Some(&(best_idx, _)) = feature_indices.first() {
261 selected_sorted.push(best_idx);
262 selected_sorted.sort();
263 }
264 }
265 selected_sorted
266 }
267}
268
269impl Transform<Array2<Float>> for ImageFeatureSelector<Trained> {
270 fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
271 let selected_features = self.selected_features_.as_ref().unwrap();
272 if selected_features.is_empty() {
273 return Err(SklearsError::InvalidInput(
274 "No features were selected".to_string(),
275 ));
276 }
277
278 let selected_indices: Vec<usize> = selected_features.to_vec();
279 Ok(x.select(Axis(1), &selected_indices))
280 }
281}
282
283impl SelectorMixin for ImageFeatureSelector<Trained> {
284 fn get_support(&self) -> SklResult<Array1<bool>> {
285 let selected_features = self.selected_features_.as_ref().unwrap();
286 let n_features = if let Some(ref scores) = self.spatial_scores_ {
287 scores.len()
288 } else if let Some(ref scores) = self.frequency_scores_ {
289 scores.len()
290 } else if let Some(ref scores) = self.texture_scores_ {
291 scores.len()
292 } else {
293 selected_features.iter().max().unwrap_or(&0) + 1
294 };
295
296 let mut support = Array1::from_elem(n_features, false);
297 for &idx in selected_features {
298 if idx < n_features {
299 support[idx] = true;
300 }
301 }
302 Ok(support)
303 }
304
305 fn transform_features(&self, indices: &[usize]) -> SklResult<Vec<usize>> {
306 let selected_features = self.selected_features_.as_ref().unwrap();
307 Ok(indices
308 .iter()
309 .filter_map(|&idx| selected_features.iter().position(|&f| f == idx))
310 .collect())
311 }
312}
313
314impl ImageFeatureSelector<Trained> {
315 pub fn spatial_scores(&self) -> Option<&Array1<Float>> {
319 self.spatial_scores_.as_ref()
320 }
321
322 pub fn frequency_scores(&self) -> Option<&Array1<Float>> {
326 self.frequency_scores_.as_ref()
327 }
328
329 pub fn texture_scores(&self) -> Option<&Array1<Float>> {
333 self.texture_scores_.as_ref()
334 }
335
336 pub fn selected_features(&self) -> &[usize] {
338 self.selected_features_.as_ref().unwrap()
339 }
340
341 pub fn n_features_selected(&self) -> usize {
343 self.selected_features_.as_ref().unwrap().len()
344 }
345
346 pub fn feature_summary(&self) -> Vec<(usize, Option<Float>, Option<Float>, Option<Float>)> {
351 let indices = self.selected_features();
352 let mut summary = Vec::with_capacity(indices.len());
353
354 for &idx in indices {
355 let spatial_score = self.spatial_scores_.as_ref().map(|scores| scores[idx]);
356 let frequency_score = self.frequency_scores_.as_ref().map(|scores| scores[idx]);
357 let texture_score = self.texture_scores_.as_ref().map(|scores| scores[idx]);
358
359 summary.push((idx, spatial_score, frequency_score, texture_score));
360 }
361
362 summary
363 }
364}
365
366fn compute_spatial_correlation_scores(x: &Array2<Float>, y: &Array1<Float>) -> Array1<Float> {
376 let (_, n_features) = x.dim();
377 let mut scores = Array1::zeros(n_features);
378
379 for j in 0..n_features {
380 let feature = x.column(j);
381 let corr = compute_pearson_correlation(&feature, y);
383 scores[j] = corr.abs();
384 }
385
386 scores
387}
388
389fn compute_frequency_domain_scores(x: &Array2<Float>, y: &Array1<Float>) -> Array1<Float> {
395 let (_, n_features) = x.dim();
396 let mut scores = Array1::zeros(n_features);
397
398 for j in 0..n_features {
400 let feature = x.column(j);
401 let variance = feature.var(0.0);
403 let corr = compute_pearson_correlation(&feature, y);
405 scores[j] = variance * corr.abs();
406 }
407
408 scores
409}
410
411fn compute_texture_scores(x: &Array2<Float>, y: &Array1<Float>) -> Array1<Float> {
417 let (_, n_features) = x.dim();
418 let mut scores = Array1::zeros(n_features);
419
420 for j in 0..n_features {
422 let feature = x.column(j);
423 let local_variance = compute_local_variance(&feature);
424 let corr = compute_pearson_correlation(&feature, y);
425 scores[j] = local_variance * corr.abs();
426 }
427
428 scores
429}
430
431fn compute_pearson_correlation(x: &ArrayView1<Float>, y: &Array1<Float>) -> Float {
436 let n = x.len().min(y.len());
437 if n < 2 {
438 return 0.0;
439 }
440
441 let x_mean = x.iter().take(n).sum::<Float>() / n as Float;
442 let y_mean = y.iter().take(n).sum::<Float>() / n as Float;
443
444 let mut numerator = 0.0;
445 let mut x_var = 0.0;
446 let mut y_var = 0.0;
447
448 for i in 0..n {
449 let x_i = x[i] - x_mean;
450 let y_i = y[i] - y_mean;
451 numerator += x_i * y_i;
452 x_var += x_i * x_i;
453 y_var += y_i * y_i;
454 }
455
456 let denominator = (x_var * y_var).sqrt();
457 if denominator.abs() < 1e-10 {
458 0.0
459 } else {
460 numerator / denominator
461 }
462}
463
464fn compute_local_variance(feature: &ArrayView1<Float>) -> Float {
470 let n = feature.len();
471 if n < 3 {
472 return 0.0;
473 }
474
475 let mut local_var = 0.0;
476 let _window_size = 3; for i in 1..(n - 1) {
479 let window = &feature.slice(s![i - 1..i + 2]);
480 let var = window.var(0.0);
481 local_var += var;
482 }
483
484 local_var / (n - 2) as Float
485}
486
487pub fn create_image_feature_selector() -> ImageFeatureSelector<Untrained> {
489 ImageFeatureSelector::new()
490}
491
492pub fn create_low_resolution_selector() -> ImageFeatureSelector<Untrained> {
497 ImageFeatureSelector::new()
498 .include_spatial(true)
499 .include_frequency(false)
500 .include_texture(true)
501 .spatial_threshold(0.05)
502}
503
504pub fn create_high_resolution_selector() -> ImageFeatureSelector<Untrained> {
509 ImageFeatureSelector::new()
510 .include_spatial(true)
511 .include_frequency(true)
512 .include_texture(true)
513 .spatial_threshold(0.1)
514 .k(Some(500))
515}
516
517pub fn create_texture_focused_selector() -> ImageFeatureSelector<Untrained> {
522 ImageFeatureSelector::new()
523 .include_spatial(false)
524 .include_frequency(false)
525 .include_texture(true)
526 .spatial_threshold(0.2)
527}
528
529pub fn create_spatial_focused_selector() -> ImageFeatureSelector<Untrained> {
534 ImageFeatureSelector::new()
535 .include_spatial(true)
536 .include_frequency(false)
537 .include_texture(false)
538 .spatial_threshold(0.15)
539}
540
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    /// Exactly proportional vectors correlate at 1.
    #[test]
    fn test_pearson_correlation_computation() {
        let x = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let deviation_from_one = compute_pearson_correlation(&x.view(), &y) - 1.0;

        assert!(deviation_from_one.abs() < 1e-10);
    }

    /// A step in the signal yields a strictly positive mean local variance.
    #[test]
    fn test_local_variance_computation() {
        let feature = array![1.0, 1.0, 1.0, 5.0, 5.0, 5.0];

        let local_var = compute_local_variance(&feature.view());

        assert!(local_var > 0.0);
    }

    /// Both columns are multiples of the target, so both score 1.
    #[test]
    fn test_spatial_correlation_scores() {
        let values = vec![1.0, 2.0, 2.0, 4.0, 3.0, 6.0];
        let x = Array2::from_shape_vec((3, 2), values).unwrap();
        let y = array![1.0, 2.0, 3.0];

        let scores = compute_spatial_correlation_scores(&x, &y);

        assert_eq!(scores.len(), 2);
        for column in 0..2 {
            assert!((scores[column] - 1.0).abs() < 1e-10);
        }
    }

    /// With `k = 1` exactly one feature is selected and kept by `transform`.
    #[test]
    fn test_image_feature_selector_basic() {
        let values = vec![1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0];
        let x = Array2::from_shape_vec((4, 3), values).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let selector = ImageFeatureSelector::new()
            .include_spatial(true)
            .include_frequency(false)
            .include_texture(false)
            .k(Some(1));

        let fitted = selector.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_features_selected(), 1);

        let transformed = fitted.transform(&x).unwrap();
        assert_eq!(transformed.ncols(), 1);
    }

    /// Threshold-based selection keeps between one and two of the three columns.
    #[test]
    fn test_feature_selection_with_threshold() {
        let values = vec![1.0, 0.0, 1.0, 2.0, 1.0, 2.0, 3.0, 0.0, 3.0];
        let x = Array2::from_shape_vec((3, 3), values).unwrap();
        let y = array![1.0, 2.0, 3.0];

        let selector = ImageFeatureSelector::new()
            .include_spatial(true)
            .spatial_threshold(0.8);

        let fitted = selector.fit(&x, &y).unwrap();
        let n_selected = fitted.n_features_selected();

        assert!(n_selected >= 1);
        assert!(n_selected <= 2);
    }
}