sklears_feature_selection/domain_specific/
image_features.rs

//! Image feature selection module
//!
//! This module provides specialized feature selection algorithms for image data,
//! including spatial correlation analysis, frequency domain features, and texture analysis.
5
6use crate::base::SelectorMixin;
7use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, Axis};
8use sklears_core::{
9    error::{validate, Result as SklResult, SklearsError},
10    traits::{Estimator, Fit, Trained, Transform, Untrained},
11    types::Float,
12};
13use std::marker::PhantomData;
14
15/// Image feature selection using spatial correlation and frequency analysis
16///
17/// This selector analyzes image features represented as flattened pixel matrices
18/// or extracted feature vectors and applies image-specific selection criteria:
19/// - Spatial correlation analysis for neighboring pixels
20/// - Frequency domain analysis using variance as a proxy
21/// - Texture analysis using local variance measurements
22/// - Combined scoring with target correlation
23///
24/// # Input Format
25///
26/// The input matrix `X` should be structured as:
27/// - Rows: Images
28/// - Columns: Features (pixels, extracted features, etc.)
29/// - Values: Pixel intensities, feature values, or derived measurements
30///
31/// # Examples
32///
33/// ```rust,ignore
34/// use sklears_feature_selection::domain_specific::image_features::ImageFeatureSelector;
35/// use sklears_core::traits::{Fit, Transform};
36/// use scirs2_core::ndarray::{Array1, Array2};
37///
38/// let selector = ImageFeatureSelector::new()
39///     .include_spatial(true)     // Enable spatial correlation analysis
40///     .include_frequency(true)   // Enable frequency domain analysis
41///     .include_texture(true)     // Enable texture analysis
42///     .spatial_threshold(0.15)   // Threshold for spatial features
43///     .k(Some(100));             // Select top 100 features
44///
45/// let x = Array2::zeros((50, 784)); // 50 images, 784 pixels (28x28)
46/// let y = Array1::zeros(50);         // Image labels
47///
48/// let fitted_selector = selector.fit(&x, &y)?;
49/// let transformed_x = fitted_selector.transform(&x)?;
50/// ```
51#[derive(Debug, Clone)]
52pub struct ImageFeatureSelector<State = Untrained> {
53    /// Whether to include spatial correlation features
54    include_spatial: bool,
55    /// Whether to include frequency domain features
56    include_frequency: bool,
57    /// Whether to include texture features
58    include_texture: bool,
59    /// Threshold for spatial correlation
60    spatial_threshold: f64,
61    /// Number of top features to select
62    k: Option<usize>,
63    state: PhantomData<State>,
64    // Trained state
65    spatial_scores_: Option<Array1<Float>>,
66    frequency_scores_: Option<Array1<Float>>,
67    texture_scores_: Option<Array1<Float>>,
68    selected_features_: Option<Vec<usize>>,
69}
70
71impl ImageFeatureSelector<Untrained> {
72    /// Create a new image feature selector with default parameters
73    ///
74    /// Default configuration:
75    /// - `include_spatial`: true
76    /// - `include_frequency`: true
77    /// - `include_texture`: true
78    /// - `spatial_threshold`: 0.1
79    /// - `k`: None (use threshold-based selection)
80    pub fn new() -> Self {
81        Self {
82            include_spatial: true,
83            include_frequency: true,
84            include_texture: true,
85            spatial_threshold: 0.1,
86            k: None,
87            state: PhantomData,
88            spatial_scores_: None,
89            frequency_scores_: None,
90            texture_scores_: None,
91            selected_features_: None,
92        }
93    }
94
95    /// Enable or disable spatial correlation analysis
96    ///
97    /// When enabled, the selector computes correlations between features
98    /// and the target variable, emphasizing spatial relationships.
99    pub fn include_spatial(mut self, include_spatial: bool) -> Self {
100        self.include_spatial = include_spatial;
101        self
102    }
103
104    /// Enable or disable frequency domain analysis
105    ///
106    /// When enabled, the selector analyzes frequency content using
107    /// variance as a proxy for spectral energy, combined with target correlation.
108    pub fn include_frequency(mut self, include_frequency: bool) -> Self {
109        self.include_frequency = include_frequency;
110        self
111    }
112
113    /// Enable or disable texture analysis
114    ///
115    /// When enabled, the selector computes local variance measurements
116    /// to identify texture-rich regions that correlate with the target.
117    pub fn include_texture(mut self, include_texture: bool) -> Self {
118        self.include_texture = include_texture;
119        self
120    }
121
122    /// Set the threshold for spatial correlation selection
123    ///
124    /// Features with combined scores below this threshold will be filtered out
125    /// (when not using k-based selection).
126    pub fn spatial_threshold(mut self, threshold: f64) -> Self {
127        self.spatial_threshold = threshold;
128        self
129    }
130
131    /// Set the number of top features to select
132    ///
133    /// When set to `Some(k)`, selects the top k features by combined score.
134    /// When set to `None`, uses threshold-based selection.
135    pub fn k(mut self, k: Option<usize>) -> Self {
136        self.k = k;
137        self
138    }
139}
140
141impl Default for ImageFeatureSelector<Untrained> {
142    fn default() -> Self {
143        Self::new()
144    }
145}
146
147impl Estimator for ImageFeatureSelector<Untrained> {
148    type Config = ();
149    type Error = SklearsError;
150    type Float = f64;
151
152    fn config(&self) -> &Self::Config {
153        &()
154    }
155}
156
157impl Fit<Array2<Float>, Array1<Float>> for ImageFeatureSelector<Untrained> {
158    type Fitted = ImageFeatureSelector<Trained>;
159
160    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> SklResult<Self::Fitted> {
161        validate::check_consistent_length(x, y)?;
162
163        let (_, n_features) = x.dim();
164
165        // Compute spatial correlation scores
166        let spatial_scores = if self.include_spatial {
167            Some(compute_spatial_correlation_scores(x, y))
168        } else {
169            None
170        };
171
172        // Compute frequency domain scores
173        let frequency_scores = if self.include_frequency {
174            Some(compute_frequency_domain_scores(x, y))
175        } else {
176            None
177        };
178
179        // Compute texture scores
180        let texture_scores = if self.include_texture {
181            Some(compute_texture_scores(x, y))
182        } else {
183            None
184        };
185
186        // Combine scores and select features
187        let mut combined_scores = Array1::zeros(n_features);
188        let mut weight_sum = 0.0;
189
190        if let Some(ref spatial) = spatial_scores {
191            for i in 0..n_features {
192                combined_scores[i] += 0.4 * spatial[i];
193            }
194            weight_sum += 0.4;
195        }
196
197        if let Some(ref frequency) = frequency_scores {
198            for i in 0..n_features {
199                combined_scores[i] += 0.3 * frequency[i];
200            }
201            weight_sum += 0.3;
202        }
203
204        if let Some(ref texture) = texture_scores {
205            for i in 0..n_features {
206                combined_scores[i] += 0.3 * texture[i];
207            }
208            weight_sum += 0.3;
209        }
210
211        if weight_sum > 0.0 {
212            combined_scores /= weight_sum;
213        }
214
215        // Select features based on combined scores
216        let selected_features = self.select_features_from_combined_scores(&combined_scores);
217
218        Ok(ImageFeatureSelector {
219            include_spatial: self.include_spatial,
220            include_frequency: self.include_frequency,
221            include_texture: self.include_texture,
222            spatial_threshold: self.spatial_threshold,
223            k: self.k,
224            state: PhantomData,
225            spatial_scores_: spatial_scores,
226            frequency_scores_: frequency_scores,
227            texture_scores_: texture_scores,
228            selected_features_: Some(selected_features),
229        })
230    }
231}
232
233impl ImageFeatureSelector<Untrained> {
234    fn select_features_from_combined_scores(&self, scores: &Array1<Float>) -> Vec<usize> {
235        let mut feature_indices: Vec<(usize, Float)> = scores
236            .indexed_iter()
237            .map(|(i, &score)| (i, score))
238            .collect();
239
240        feature_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
241
242        let selected: Vec<usize> = if let Some(k) = self.k {
243            feature_indices
244                .iter()
245                .take(k.min(feature_indices.len()))
246                .map(|(i, _)| *i)
247                .collect()
248        } else {
249            feature_indices
250                .iter()
251                .filter(|(_, score)| *score >= self.spatial_threshold)
252                .map(|(i, _)| *i)
253                .collect()
254        };
255
256        let mut selected_sorted = selected;
257        selected_sorted.sort();
258
259        if selected_sorted.is_empty() {
260            if let Some(&(best_idx, _)) = feature_indices.first() {
261                selected_sorted.push(best_idx);
262                selected_sorted.sort();
263            }
264        }
265        selected_sorted
266    }
267}
268
269impl Transform<Array2<Float>> for ImageFeatureSelector<Trained> {
270    fn transform(&self, x: &Array2<Float>) -> SklResult<Array2<Float>> {
271        let selected_features = self.selected_features_.as_ref().unwrap();
272        if selected_features.is_empty() {
273            return Err(SklearsError::InvalidInput(
274                "No features were selected".to_string(),
275            ));
276        }
277
278        let selected_indices: Vec<usize> = selected_features.to_vec();
279        Ok(x.select(Axis(1), &selected_indices))
280    }
281}
282
283impl SelectorMixin for ImageFeatureSelector<Trained> {
284    fn get_support(&self) -> SklResult<Array1<bool>> {
285        let selected_features = self.selected_features_.as_ref().unwrap();
286        let n_features = if let Some(ref scores) = self.spatial_scores_ {
287            scores.len()
288        } else if let Some(ref scores) = self.frequency_scores_ {
289            scores.len()
290        } else if let Some(ref scores) = self.texture_scores_ {
291            scores.len()
292        } else {
293            selected_features.iter().max().unwrap_or(&0) + 1
294        };
295
296        let mut support = Array1::from_elem(n_features, false);
297        for &idx in selected_features {
298            if idx < n_features {
299                support[idx] = true;
300            }
301        }
302        Ok(support)
303    }
304
305    fn transform_features(&self, indices: &[usize]) -> SklResult<Vec<usize>> {
306        let selected_features = self.selected_features_.as_ref().unwrap();
307        Ok(indices
308            .iter()
309            .filter_map(|&idx| selected_features.iter().position(|&f| f == idx))
310            .collect())
311    }
312}
313
314impl ImageFeatureSelector<Trained> {
315    /// Get the spatial correlation scores (if spatial analysis was enabled)
316    ///
317    /// Returns `None` if spatial analysis was not enabled during fitting.
318    pub fn spatial_scores(&self) -> Option<&Array1<Float>> {
319        self.spatial_scores_.as_ref()
320    }
321
322    /// Get the frequency domain scores (if frequency analysis was enabled)
323    ///
324    /// Returns `None` if frequency analysis was not enabled during fitting.
325    pub fn frequency_scores(&self) -> Option<&Array1<Float>> {
326        self.frequency_scores_.as_ref()
327    }
328
329    /// Get the texture scores (if texture analysis was enabled)
330    ///
331    /// Returns `None` if texture analysis was not enabled during fitting.
332    pub fn texture_scores(&self) -> Option<&Array1<Float>> {
333        self.texture_scores_.as_ref()
334    }
335
336    /// Get the indices of selected features
337    pub fn selected_features(&self) -> &[usize] {
338        self.selected_features_.as_ref().unwrap()
339    }
340
341    /// Get the number of selected features
342    pub fn n_features_selected(&self) -> usize {
343        self.selected_features_.as_ref().unwrap().len()
344    }
345
346    /// Get a summary of feature scores across all analysis types
347    ///
348    /// Returns a vector of tuples containing (feature_index, spatial_score, frequency_score, texture_score)
349    /// for all selected features. Scores are `None` if the corresponding analysis was not enabled.
350    pub fn feature_summary(&self) -> Vec<(usize, Option<Float>, Option<Float>, Option<Float>)> {
351        let indices = self.selected_features();
352        let mut summary = Vec::with_capacity(indices.len());
353
354        for &idx in indices {
355            let spatial_score = self.spatial_scores_.as_ref().map(|scores| scores[idx]);
356            let frequency_score = self.frequency_scores_.as_ref().map(|scores| scores[idx]);
357            let texture_score = self.texture_scores_.as_ref().map(|scores| scores[idx]);
358
359            summary.push((idx, spatial_score, frequency_score, texture_score));
360        }
361
362        summary
363    }
364}
365
366// ================================================================================================
367// Helper Functions
368// ================================================================================================
369
370/// Compute spatial correlation scores between features and target
371///
372/// This function calculates the Pearson correlation coefficient between
373/// each feature and the target variable. Higher absolute correlations
374/// indicate stronger spatial relationships with the prediction target.
375fn compute_spatial_correlation_scores(x: &Array2<Float>, y: &Array1<Float>) -> Array1<Float> {
376    let (_, n_features) = x.dim();
377    let mut scores = Array1::zeros(n_features);
378
379    for j in 0..n_features {
380        let feature = x.column(j);
381        // Compute correlation with target
382        let corr = compute_pearson_correlation(&feature, y);
383        scores[j] = corr.abs();
384    }
385
386    scores
387}
388
389/// Compute frequency domain scores for image features
390///
391/// This function uses variance as a proxy for frequency content,
392/// combined with correlation to the target variable. Higher variance
393/// indicates more spectral energy, which may be important for classification.
394fn compute_frequency_domain_scores(x: &Array2<Float>, y: &Array1<Float>) -> Array1<Float> {
395    let (_, n_features) = x.dim();
396    let mut scores = Array1::zeros(n_features);
397
398    // Simplified frequency domain analysis
399    for j in 0..n_features {
400        let feature = x.column(j);
401        // Compute variance as a proxy for frequency content
402        let variance = feature.var(0.0);
403        // Combine with correlation to target
404        let corr = compute_pearson_correlation(&feature, y);
405        scores[j] = variance * corr.abs();
406    }
407
408    scores
409}
410
411/// Compute texture scores using local variance analysis
412///
413/// This function analyzes texture by computing local variance measurements
414/// within small windows. High texture regions often contain important
415/// discriminative information for image classification tasks.
416fn compute_texture_scores(x: &Array2<Float>, y: &Array1<Float>) -> Array1<Float> {
417    let (_, n_features) = x.dim();
418    let mut scores = Array1::zeros(n_features);
419
420    // Simplified texture analysis using local variance
421    for j in 0..n_features {
422        let feature = x.column(j);
423        let local_variance = compute_local_variance(&feature);
424        let corr = compute_pearson_correlation(&feature, y);
425        scores[j] = local_variance * corr.abs();
426    }
427
428    scores
429}
430
431/// Compute Pearson correlation coefficient between two variables
432///
433/// Returns the linear correlation coefficient between x and y,
434/// ranging from -1 (perfect negative correlation) to +1 (perfect positive correlation).
435fn compute_pearson_correlation(x: &ArrayView1<Float>, y: &Array1<Float>) -> Float {
436    let n = x.len().min(y.len());
437    if n < 2 {
438        return 0.0;
439    }
440
441    let x_mean = x.iter().take(n).sum::<Float>() / n as Float;
442    let y_mean = y.iter().take(n).sum::<Float>() / n as Float;
443
444    let mut numerator = 0.0;
445    let mut x_var = 0.0;
446    let mut y_var = 0.0;
447
448    for i in 0..n {
449        let x_i = x[i] - x_mean;
450        let y_i = y[i] - y_mean;
451        numerator += x_i * y_i;
452        x_var += x_i * x_i;
453        y_var += y_i * y_i;
454    }
455
456    let denominator = (x_var * y_var).sqrt();
457    if denominator.abs() < 1e-10 {
458        0.0
459    } else {
460        numerator / denominator
461    }
462}
463
464/// Compute local variance for texture analysis
465///
466/// Uses a sliding window approach to compute variance within small neighborhoods.
467/// This helps identify regions with high texture content that may be important
468/// for image classification or object detection tasks.
469fn compute_local_variance(feature: &ArrayView1<Float>) -> Float {
470    let n = feature.len();
471    if n < 3 {
472        return 0.0;
473    }
474
475    let mut local_var = 0.0;
476    let _window_size = 3; // Simple 3-point window
477
478    for i in 1..(n - 1) {
479        let window = &feature.slice(s![i - 1..i + 2]);
480        let var = window.var(0.0);
481        local_var += var;
482    }
483
484    local_var / (n - 2) as Float
485}
486
487/// Create a new image feature selector
488pub fn create_image_feature_selector() -> ImageFeatureSelector<Untrained> {
489    ImageFeatureSelector::new()
490}
491
492/// Create an image feature selector optimized for low-resolution images
493///
494/// Suitable for small images (e.g., 28x28, 32x32) where spatial relationships
495/// are more important than high-frequency details.
496pub fn create_low_resolution_selector() -> ImageFeatureSelector<Untrained> {
497    ImageFeatureSelector::new()
498        .include_spatial(true)
499        .include_frequency(false)
500        .include_texture(true)
501        .spatial_threshold(0.05)
502}
503
504/// Create an image feature selector optimized for high-resolution images
505///
506/// Suitable for large images where frequency domain and texture analysis
507/// can capture fine-grained details that are important for classification.
508pub fn create_high_resolution_selector() -> ImageFeatureSelector<Untrained> {
509    ImageFeatureSelector::new()
510        .include_spatial(true)
511        .include_frequency(true)
512        .include_texture(true)
513        .spatial_threshold(0.1)
514        .k(Some(500))
515}
516
517/// Create an image feature selector focused on texture analysis
518///
519/// Suitable for applications where texture is the primary discriminative
520/// feature, such as material classification or medical image analysis.
521pub fn create_texture_focused_selector() -> ImageFeatureSelector<Untrained> {
522    ImageFeatureSelector::new()
523        .include_spatial(false)
524        .include_frequency(false)
525        .include_texture(true)
526        .spatial_threshold(0.2)
527}
528
529/// Create an image feature selector focused on spatial relationships
530///
531/// Suitable for applications where spatial structure is most important,
532/// such as object detection or shape classification.
533pub fn create_spatial_focused_selector() -> ImageFeatureSelector<Untrained> {
534    ImageFeatureSelector::new()
535        .include_spatial(true)
536        .include_frequency(false)
537        .include_texture(false)
538        .spatial_threshold(0.15)
539}
540
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    /// `y = 2x` must give a correlation of exactly +1 (up to rounding).
    #[test]
    fn test_pearson_correlation_computation() {
        let x = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let corr = compute_pearson_correlation(&x.view(), &y);

        assert!((corr - 1.0).abs() < 1e-10);
    }

    /// A step between two flat regions must produce non-zero local variance.
    #[test]
    fn test_local_variance_computation() {
        let feature = array![1.0, 1.0, 1.0, 5.0, 5.0, 5.0];

        let local_var = compute_local_variance(&feature.view());

        // Only the windows straddling the step contribute.
        assert!(local_var > 0.0);
    }

    /// Columns that are exact multiples of the target score 1.0.
    #[test]
    fn test_spatial_correlation_scores() {
        // Column 0 equals y, column 1 equals 2y.
        let x: Array2<Float> = array![[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]];
        let y = array![1.0, 2.0, 3.0];

        let scores = compute_spatial_correlation_scores(&x, &y);

        assert_eq!(scores.len(), 2);
        assert!((scores[0] - 1.0).abs() < 1e-10);
        assert!((scores[1] - 1.0).abs() < 1e-10);
    }

    /// Top-k selection with k = 1 keeps exactly one column.
    #[test]
    fn test_image_feature_selector_basic() {
        let selector = ImageFeatureSelector::new()
            .include_spatial(true)
            .include_frequency(false)
            .include_texture(false)
            .k(Some(1));

        let x: Array2<Float> = array![
            [1.0, 2.0, 3.0],
            [2.0, 4.0, 6.0],
            [3.0, 6.0, 9.0],
            [4.0, 8.0, 12.0],
        ];
        let y = array![1.0, 2.0, 3.0, 4.0];

        let fitted = selector.fit(&x, &y).unwrap();
        assert_eq!(fitted.n_features_selected(), 1);

        let transformed = fitted.transform(&x).unwrap();
        assert_eq!(transformed.ncols(), 1);
    }

    /// Threshold selection keeps at most the two target-aligned columns.
    #[test]
    fn test_feature_selection_with_threshold() {
        let selector = ImageFeatureSelector::new()
            .include_spatial(true)
            .spatial_threshold(0.8); // High threshold

        // Features are the COLUMNS: columns 0 and 2 equal the target
        // (perfect correlation); column 1 is [0, 1, 0] (no correlation).
        let x: Array2<Float> = array![
            [1.0, 0.0, 1.0],
            [2.0, 1.0, 2.0],
            [3.0, 0.0, 3.0],
        ];
        let y = array![1.0, 2.0, 3.0];

        let fitted = selector.fit(&x, &y).unwrap();

        // Either both correlated columns pass, or the best-feature fallback
        // keeps one of them.
        let kept = fitted.n_features_selected();
        assert!(kept >= 1);
        assert!(kept <= 2);
    }
}