sklears_feature_selection/domain_specific/
multi_modal.rs

1//! Multi-modal feature selection for heterogeneous data types.
2//!
3//! This module provides specialized feature selection capabilities for multi-modal data,
4//! where features come from different modalities such as text, images, audio, sensors,
5//! or other heterogeneous data sources. It implements various fusion strategies and
6//! cross-modal analysis techniques to identify the most informative features across modalities.
7//!
8//! # Features
9//!
10//! - **Early fusion**: Concatenates features from different modalities before selection
11//! - **Late fusion**: Selects features within each modality separately, then combines results
12//! - **Hybrid fusion**: Combines early and late fusion strategies
13//! - **Cross-modal correlations**: Analyzes relationships between features across modalities
14//! - **Modality weighting**: Assigns different importance weights to different modalities
15//! - **Missing modality handling**: Robust to missing data in some modalities
16//!
17//! # Examples
18//!
19//! ## Basic Multi-Modal Feature Selection
20//!
21//! ```rust,ignore
22//! use sklears_feature_selection::domain_specific::multi_modal::MultiModalFeatureSelector;
23//! use scirs2_core::ndarray::{Array2, Array1};
24//! use std::collections::HashMap;
25//!
26//! // Text features (TF-IDF, word embeddings, etc.)
27//! let text_features = Array2::from_shape_vec((100, 50), (0..5000).map(|x| x as f64).collect()).unwrap();
28//!
29//! // Image features (CNN features, color histograms, etc.)
30//! let image_features = Array2::from_shape_vec((100, 30), (0..3000).map(|x| (x * 2) as f64).collect()).unwrap();
31//!
32//! // Audio features (MFCCs, spectrograms, etc.)
33//! let audio_features = Array2::from_shape_vec((100, 20), (0..2000).map(|x| (x * 3) as f64).collect()).unwrap();
34//!
35//! let mut modalities = HashMap::new();
36//! modalities.insert("text".to_string(), text_features);
37//! modalities.insert("image".to_string(), image_features);
38//! modalities.insert("audio".to_string(), audio_features);
39//!
40//! let target = Array1::from_iter((0..100).map(|i| (i % 2) as f64));
41//!
42//! let selector = MultiModalFeatureSelector::builder()
43//!     .fusion_strategy("hybrid")
44//!     .modality_weights([("text", 0.4), ("image", 0.4), ("audio", 0.2)])
45//!     .cross_modal_analysis(true)
46//!     .k(20)
47//!     .build();
48//!
49//! let trained = selector.fit(&modalities, &target)?;
50//! let selected_features = trained.transform(&modalities)?;
51//! ```
52//!
53//! ## Early Fusion Strategy
54//!
55//! ```rust,ignore
56//! let selector = MultiModalFeatureSelector::builder()
57//!     .fusion_strategy("early")
58//!     .normalize_modalities(true)
59//!     .k(15)
60//!     .build();
61//! ```
62//!
63//! ## Late Fusion with Modality-Specific Selection
64//!
65//! ```rust,ignore
66//! let selector = MultiModalFeatureSelector::builder()
67//!     .fusion_strategy("late")
68//!     .modality_k([("text", 10), ("image", 8), ("audio", 5)])
69//!     .cross_modal_threshold(0.3)
70//!     .build();
71//! ```
72//!
73//! ## Handling Missing Modalities
74//!
75//! ```rust,ignore
76//! let selector = MultiModalFeatureSelector::builder()
77//!     .handle_missing_modalities(true)
78//!     .min_modalities_required(2)
79//!     .missing_strategy("impute")
80//!     .build();
81//! ```
82
83use scirs2_core::ndarray::{concatenate, s, Array1, Array2, ArrayView1, Axis};
84use sklears_core::error::{Result as SklResult, SklearsError};
85use sklears_core::traits::{Estimator, Fit, Transform};
86use std::collections::HashMap;
87use std::marker::PhantomData;
88
89type Result<T> = SklResult<T>;
90type Float = f64;
91
92#[derive(Debug, Clone)]
93pub struct Untrained;
94
95#[derive(Debug, Clone)]
96pub struct Trained {
97    fusion_strategy: String,
98    selected_features_per_modality: HashMap<String, Vec<usize>>,
99    combined_selected_features: Vec<(String, usize)>, // (modality, feature_index)
100    feature_scores_per_modality: HashMap<String, Array1<Float>>,
101    cross_modal_scores: Option<Array2<Float>>,
102    modality_weights: HashMap<String, Float>,
103    feature_mapping: HashMap<usize, (String, usize)>, // global_index -> (modality, local_index)
104    total_features: usize,
105    modality_feature_counts: HashMap<String, usize>,
106}
107
108/// Multi-modal feature selector for heterogeneous data types.
109///
110/// This selector handles feature selection across multiple data modalities using various
111/// fusion strategies. It can analyze cross-modal relationships, handle missing modalities,
112/// and apply different selection criteria for each modality type.
113#[derive(Debug, Clone)]
114pub struct MultiModalFeatureSelector<State = Untrained> {
115    fusion_strategy: String,
116    modality_weights: HashMap<String, Float>,
117    modality_k: HashMap<String, usize>,
118    cross_modal_analysis: bool,
119    cross_modal_threshold: Float,
120    normalize_modalities: bool,
121    handle_missing_modalities: bool,
122    min_modalities_required: usize,
123    missing_strategy: String,
124    k: Option<usize>,
125    score_threshold: Float,
126    correlation_method: String,
127    interaction_analysis: bool,
128    max_interaction_order: usize,
129    state: PhantomData<State>,
130    trained_state: Option<Trained>,
131}
132
133impl Default for MultiModalFeatureSelector<Untrained> {
134    fn default() -> Self {
135        Self::new()
136    }
137}
138
139impl MultiModalFeatureSelector<Untrained> {
140    /// Creates a new MultiModalFeatureSelector with default parameters.
141    pub fn new() -> Self {
142        Self {
143            fusion_strategy: "hybrid".to_string(),
144            modality_weights: HashMap::new(),
145            modality_k: HashMap::new(),
146            cross_modal_analysis: true,
147            cross_modal_threshold: 0.1,
148            normalize_modalities: true,
149            handle_missing_modalities: true,
150            min_modalities_required: 1,
151            missing_strategy: "ignore".to_string(),
152            k: None,
153            score_threshold: 0.1,
154            correlation_method: "pearson".to_string(),
155            interaction_analysis: false,
156            max_interaction_order: 2,
157            state: PhantomData,
158            trained_state: None,
159        }
160    }
161
162    /// Creates a builder for configuring the MultiModalFeatureSelector.
163    pub fn builder() -> MultiModalFeatureSelectorBuilder {
164        MultiModalFeatureSelectorBuilder::new()
165    }
166}
167
168/// Builder for MultiModalFeatureSelector configuration.
169#[derive(Debug)]
170pub struct MultiModalFeatureSelectorBuilder {
171    fusion_strategy: String,
172    modality_weights: HashMap<String, Float>,
173    modality_k: HashMap<String, usize>,
174    cross_modal_analysis: bool,
175    cross_modal_threshold: Float,
176    normalize_modalities: bool,
177    handle_missing_modalities: bool,
178    min_modalities_required: usize,
179    missing_strategy: String,
180    k: Option<usize>,
181    score_threshold: Float,
182    correlation_method: String,
183    interaction_analysis: bool,
184    max_interaction_order: usize,
185}
186
187impl Default for MultiModalFeatureSelectorBuilder {
188    fn default() -> Self {
189        Self::new()
190    }
191}
192
193impl MultiModalFeatureSelectorBuilder {
194    pub fn new() -> Self {
195        Self {
196            fusion_strategy: "hybrid".to_string(),
197            modality_weights: HashMap::new(),
198            modality_k: HashMap::new(),
199            cross_modal_analysis: true,
200            cross_modal_threshold: 0.1,
201            normalize_modalities: true,
202            handle_missing_modalities: true,
203            min_modalities_required: 1,
204            missing_strategy: "ignore".to_string(),
205            k: None,
206            score_threshold: 0.1,
207            correlation_method: "pearson".to_string(),
208            interaction_analysis: false,
209            max_interaction_order: 2,
210        }
211    }
212
213    /// Fusion strategy: "early", "late", or "hybrid".
214    pub fn fusion_strategy(mut self, strategy: &str) -> Self {
215        self.fusion_strategy = strategy.to_string();
216        self
217    }
218
219    /// Set weights for different modalities.
220    pub fn modality_weights<I>(mut self, weights: I) -> Self
221    where
222        I: IntoIterator<Item = (&'static str, f64)>,
223    {
224        self.modality_weights = weights
225            .into_iter()
226            .map(|(k, v)| (k.to_string(), v))
227            .collect();
228        self
229    }
230
231    /// Set number of features to select per modality (for late fusion).
232    pub fn modality_k<I>(mut self, k_values: I) -> Self
233    where
234        I: IntoIterator<Item = (&'static str, usize)>,
235    {
236        self.modality_k = k_values
237            .into_iter()
238            .map(|(k, v)| (k.to_string(), v))
239            .collect();
240        self
241    }
242
243    /// Whether to perform cross-modal correlation analysis.
244    pub fn cross_modal_analysis(mut self, enable: bool) -> Self {
245        self.cross_modal_analysis = enable;
246        self
247    }
248
249    /// Threshold for cross-modal correlation significance.
250    pub fn cross_modal_threshold(mut self, threshold: Float) -> Self {
251        self.cross_modal_threshold = threshold;
252        self
253    }
254
255    /// Whether to normalize features within each modality.
256    pub fn normalize_modalities(mut self, normalize: bool) -> Self {
257        self.normalize_modalities = normalize;
258        self
259    }
260
261    /// Whether to handle missing modalities gracefully.
262    pub fn handle_missing_modalities(mut self, handle: bool) -> Self {
263        self.handle_missing_modalities = handle;
264        self
265    }
266
267    /// Minimum number of modalities required for processing.
268    pub fn min_modalities_required(mut self, min: usize) -> Self {
269        self.min_modalities_required = min;
270        self
271    }
272
273    /// Strategy for handling missing modalities: "ignore", "impute", or "error".
274    pub fn missing_strategy(mut self, strategy: &str) -> Self {
275        self.missing_strategy = strategy.to_string();
276        self
277    }
278
279    /// Total number of features to select (for early/hybrid fusion).
280    pub fn k(mut self, k: usize) -> Self {
281        self.k = Some(k);
282        self
283    }
284
285    /// Minimum score threshold for feature selection.
286    pub fn score_threshold(mut self, threshold: Float) -> Self {
287        self.score_threshold = threshold;
288        self
289    }
290
291    /// Correlation method: "pearson", "spearman", or "mutual_info".
292    pub fn correlation_method(mut self, method: &str) -> Self {
293        self.correlation_method = method.to_string();
294        self
295    }
296
297    /// Whether to analyze feature interactions across modalities.
298    pub fn interaction_analysis(mut self, enable: bool) -> Self {
299        self.interaction_analysis = enable;
300        self
301    }
302
303    /// Maximum order of interactions to consider (2 = pairwise, 3 = three-way, etc.).
304    pub fn max_interaction_order(mut self, order: usize) -> Self {
305        self.max_interaction_order = order;
306        self
307    }
308
309    /// Builds the MultiModalFeatureSelector.
310    pub fn build(self) -> MultiModalFeatureSelector<Untrained> {
311        MultiModalFeatureSelector {
312            fusion_strategy: self.fusion_strategy,
313            modality_weights: self.modality_weights,
314            modality_k: self.modality_k,
315            cross_modal_analysis: self.cross_modal_analysis,
316            cross_modal_threshold: self.cross_modal_threshold,
317            normalize_modalities: self.normalize_modalities,
318            handle_missing_modalities: self.handle_missing_modalities,
319            min_modalities_required: self.min_modalities_required,
320            missing_strategy: self.missing_strategy,
321            k: self.k,
322            score_threshold: self.score_threshold,
323            correlation_method: self.correlation_method,
324            interaction_analysis: self.interaction_analysis,
325            max_interaction_order: self.max_interaction_order,
326            state: PhantomData,
327            trained_state: None,
328        }
329    }
330}
331
332impl Estimator for MultiModalFeatureSelector<Untrained> {
333    type Config = ();
334    type Error = sklears_core::error::SklearsError;
335    type Float = Float;
336
337    fn config(&self) -> &Self::Config {
338        &()
339    }
340}
341
342impl Estimator for MultiModalFeatureSelector<Trained> {
343    type Config = ();
344    type Error = sklears_core::error::SklearsError;
345    type Float = Float;
346
347    fn config(&self) -> &Self::Config {
348        &()
349    }
350}
351
352impl Fit<HashMap<String, Array2<Float>>, Array1<Float>> for MultiModalFeatureSelector<Untrained> {
353    type Fitted = MultiModalFeatureSelector<Trained>;
354
355    fn fit(
356        self,
357        modalities: &HashMap<String, Array2<Float>>,
358        y: &Array1<Float>,
359    ) -> Result<Self::Fitted> {
360        if modalities.is_empty() {
361            return Err(SklearsError::InvalidInput(
362                "At least one modality is required".to_string(),
363            ));
364        }
365
366        if modalities.len() < self.min_modalities_required {
367            return Err(SklearsError::InvalidInput(format!(
368                "At least {} modalities are required",
369                self.min_modalities_required
370            )));
371        }
372
373        // Validate that all modalities have the same number of samples
374        let n_samples = y.len();
375        for (modality_name, features) in modalities {
376            if features.nrows() != n_samples {
377                return Err(SklearsError::InvalidInput(format!(
378                    "Modality '{}' has {} samples, expected {}",
379                    modality_name,
380                    features.nrows(),
381                    n_samples
382                )));
383            }
384        }
385
386        // Normalize modalities if requested
387        let normalized_modalities = if self.normalize_modalities {
388            normalize_modalities(modalities)?
389        } else {
390            modalities.clone()
391        };
392
393        // Set default weights if not provided
394        let mut modality_weights = self.modality_weights.clone();
395        if modality_weights.is_empty() {
396            let default_weight = 1.0 / modalities.len() as Float;
397            for modality_name in modalities.keys() {
398                modality_weights.insert(modality_name.clone(), default_weight);
399            }
400        }
401
402        // Perform feature selection based on fusion strategy
403        let (
404            selected_features_per_modality,
405            combined_selected_features,
406            feature_scores_per_modality,
407            cross_modal_scores,
408            feature_mapping,
409            total_features,
410            modality_feature_counts,
411        ) = match self.fusion_strategy.as_str() {
412            "early" => self.early_fusion_selection(&normalized_modalities, y, &modality_weights)?,
413            "late" => self.late_fusion_selection(&normalized_modalities, y, &modality_weights)?,
414            "hybrid" => {
415                self.hybrid_fusion_selection(&normalized_modalities, y, &modality_weights)?
416            }
417            _ => {
418                return Err(SklearsError::InvalidInput(format!(
419                    "Unknown fusion strategy: {}",
420                    self.fusion_strategy
421                )))
422            }
423        };
424
425        let trained_state = Trained {
426            fusion_strategy: self.fusion_strategy.clone(),
427            selected_features_per_modality,
428            combined_selected_features,
429            feature_scores_per_modality,
430            cross_modal_scores,
431            modality_weights,
432            feature_mapping,
433            total_features,
434            modality_feature_counts,
435        };
436
437        Ok(MultiModalFeatureSelector {
438            fusion_strategy: self.fusion_strategy,
439            modality_weights: self.modality_weights,
440            modality_k: self.modality_k,
441            cross_modal_analysis: self.cross_modal_analysis,
442            cross_modal_threshold: self.cross_modal_threshold,
443            normalize_modalities: self.normalize_modalities,
444            handle_missing_modalities: self.handle_missing_modalities,
445            min_modalities_required: self.min_modalities_required,
446            missing_strategy: self.missing_strategy,
447            k: self.k,
448            score_threshold: self.score_threshold,
449            correlation_method: self.correlation_method,
450            interaction_analysis: self.interaction_analysis,
451            max_interaction_order: self.max_interaction_order,
452            state: PhantomData,
453            trained_state: Some(trained_state),
454        })
455    }
456}
457
458impl Transform<HashMap<String, Array2<Float>>, Array2<Float>>
459    for MultiModalFeatureSelector<Trained>
460{
461    fn transform(&self, modalities: &HashMap<String, Array2<Float>>) -> Result<Array2<Float>> {
462        let trained = self.trained_state.as_ref().ok_or_else(|| {
463            SklearsError::InvalidState("Selector must be fitted before transforming".to_string())
464        })?;
465
466        match trained.fusion_strategy.as_str() {
467            "early" => self.transform_early_fusion(modalities, trained),
468            "late" => self.transform_late_fusion(modalities, trained),
469            "hybrid" => self.transform_hybrid_fusion(modalities, trained),
470            _ => Err(SklearsError::InvalidState(format!(
471                "Unknown fusion strategy: {}",
472                trained.fusion_strategy
473            ))),
474        }
475    }
476}
477
478// MultiModalFeatureSelector uses HashMap<String, Array2> as input, not Array2,
479// so it cannot implement SelectorMixin which requires Transform<Array2>
480/* impl SelectorMixin for MultiModalFeatureSelector<Trained> {
481    fn get_support(&self) -> Result<Array1<bool>> {
482        let trained = self.trained_state.as_ref().ok_or_else(|| {
483            SklearsError::InvalidState("Selector must be fitted before getting support".to_string())
484        })?;
485
486        let mut support = Array1::from_elem(trained.total_features, false);
487
488        match trained.fusion_strategy.as_str() {
489            "early" | "hybrid" => {
490                for &(_, global_idx) in &trained.combined_selected_features {
491                    if let Some(&(_, _)) = trained.feature_mapping.get(&global_idx) {
492                        support[global_idx] = true;
493                    }
494                }
495            }
496            "late" => {
497                for (modality, selected_indices) in &trained.selected_features_per_modality {
498                    if let Some(&base_idx) = trained
499                        .modality_feature_counts
500                        .keys()
501                        .take_while(|&k| k != modality)
502                        .map(|k| trained.modality_feature_counts[k])
503                        .reduce(|acc, x| acc + x)
504                        .as_ref()
505                    {
506                        for &local_idx in selected_indices {
507                            support[base_idx + local_idx] = true;
508                        }
509                    }
510                }
511            }
512            _ => {
513                return Err(SklearsError::InvalidState(
514                    "Unknown fusion strategy".to_string(),
515                ))
516            }
517        }
518
519        Ok(support)
520    }
521
522    fn transform_features(&self, indices: &[usize]) -> SklResult<Vec<usize>> {
523        let trained = self.trained_state.as_ref().ok_or_else(|| {
524            SklearsError::InvalidState(
525                "Selector must be fitted before transforming features".to_string(),
526            )
527        })?;
528
529        let selected: Vec<usize> = indices
530            .iter()
531            .filter(|&&idx| trained.selected_features.contains(&idx))
532            .cloned()
533            .collect();
534        Ok(selected)
535    }
536} */
537
538// Implementation methods for MultiModalFeatureSelector
539impl MultiModalFeatureSelector<Untrained> {
540    fn early_fusion_selection(
541        &self,
542        modalities: &HashMap<String, Array2<Float>>,
543        y: &Array1<Float>,
544        modality_weights: &HashMap<String, Float>,
545    ) -> Result<(
546        HashMap<String, Vec<usize>>,
547        Vec<(String, usize)>,
548        HashMap<String, Array1<Float>>,
549        Option<Array2<Float>>,
550        HashMap<usize, (String, usize)>,
551        usize,
552        HashMap<String, usize>,
553    )> {
554        // Concatenate all modalities into a single feature matrix
555        let (combined_features, feature_mapping, modality_feature_counts) =
556            concatenate_modalities(modalities)?;
557
558        // Compute feature scores using weighted univariate analysis
559        let feature_scores = compute_weighted_feature_scores(
560            &combined_features,
561            y,
562            modalities,
563            modality_weights,
564            &self.correlation_method,
565        )?;
566
567        // Cross-modal analysis if enabled
568        let cross_modal_scores = if self.cross_modal_analysis {
569            Some(compute_cross_modal_correlations(
570                modalities,
571                &self.correlation_method,
572            )?)
573        } else {
574            None
575        };
576
577        // Select features based on scores
578        let selected_indices = if let Some(k) = self.k {
579            select_top_k_features(&feature_scores, k)
580        } else {
581            select_features_by_threshold(&feature_scores, self.score_threshold)
582        };
583
584        // Map back to per-modality selections
585        let mut selected_features_per_modality = HashMap::new();
586        let mut combined_selected_features = Vec::new();
587
588        for &global_idx in &selected_indices {
589            if let Some(&(ref modality, local_idx)) = feature_mapping.get(&global_idx) {
590                selected_features_per_modality
591                    .entry(modality.clone())
592                    .or_insert_with(Vec::new)
593                    .push(local_idx);
594                combined_selected_features.push((modality.clone(), global_idx));
595            }
596        }
597
598        // Create per-modality feature scores
599        let mut feature_scores_per_modality = HashMap::new();
600        let mut current_idx = 0;
601        for (modality_name, features) in modalities {
602            let n_features = features.ncols();
603            let modality_scores = feature_scores
604                .slice(s![current_idx..current_idx + n_features])
605                .to_owned();
606            feature_scores_per_modality.insert(modality_name.clone(), modality_scores);
607            current_idx += n_features;
608        }
609
610        Ok((
611            selected_features_per_modality,
612            combined_selected_features,
613            feature_scores_per_modality,
614            cross_modal_scores,
615            feature_mapping,
616            combined_features.ncols(),
617            modality_feature_counts,
618        ))
619    }
620
621    fn late_fusion_selection(
622        &self,
623        modalities: &HashMap<String, Array2<Float>>,
624        y: &Array1<Float>,
625        modality_weights: &HashMap<String, Float>,
626    ) -> Result<(
627        HashMap<String, Vec<usize>>,
628        Vec<(String, usize)>,
629        HashMap<String, Array1<Float>>,
630        Option<Array2<Float>>,
631        HashMap<usize, (String, usize)>,
632        usize,
633        HashMap<String, usize>,
634    )> {
635        let mut selected_features_per_modality = HashMap::new();
636        let mut feature_scores_per_modality = HashMap::new();
637        let mut combined_selected_features = Vec::new();
638        let mut feature_mapping = HashMap::new();
639        let mut modality_feature_counts = HashMap::new();
640        let mut total_features = 0;
641        let mut global_idx = 0;
642
643        // Select features within each modality separately
644        for (modality_name, features) in modalities {
645            let n_features = features.ncols();
646            modality_feature_counts.insert(modality_name.clone(), n_features);
647
648            // Create feature mapping for this modality
649            for local_idx in 0..n_features {
650                feature_mapping.insert(global_idx, (modality_name.clone(), local_idx));
651                global_idx += 1;
652            }
653
654            let weight = modality_weights.get(modality_name).cloned().unwrap_or(1.0);
655            let scores = compute_univariate_scores(features, y, &self.correlation_method)?;
656            let weighted_scores = scores.mapv(|s| s * weight);
657
658            let k = self
659                .modality_k
660                .get(modality_name)
661                .cloned()
662                .or(self.k.map(|total_k| total_k / modalities.len()))
663                .unwrap_or(n_features / 2);
664
665            let selected_indices = select_top_k_features(&weighted_scores, k.min(n_features));
666
667            for &local_idx in &selected_indices {
668                let global_feature_idx = total_features + local_idx;
669                combined_selected_features.push((modality_name.clone(), global_feature_idx));
670            }
671
672            selected_features_per_modality.insert(modality_name.clone(), selected_indices);
673            feature_scores_per_modality.insert(modality_name.clone(), weighted_scores);
674            total_features += n_features;
675        }
676
677        // Cross-modal analysis if enabled
678        let cross_modal_scores = if self.cross_modal_analysis {
679            Some(compute_cross_modal_correlations(
680                modalities,
681                &self.correlation_method,
682            )?)
683        } else {
684            None
685        };
686
687        Ok((
688            selected_features_per_modality,
689            combined_selected_features,
690            feature_scores_per_modality,
691            cross_modal_scores,
692            feature_mapping,
693            total_features,
694            modality_feature_counts,
695        ))
696    }
697
698    fn hybrid_fusion_selection(
699        &self,
700        modalities: &HashMap<String, Array2<Float>>,
701        y: &Array1<Float>,
702        modality_weights: &HashMap<String, Float>,
703    ) -> Result<(
704        HashMap<String, Vec<usize>>,
705        Vec<(String, usize)>,
706        HashMap<String, Array1<Float>>,
707        Option<Array2<Float>>,
708        HashMap<usize, (String, usize)>,
709        usize,
710        HashMap<String, usize>,
711    )> {
712        // Perform both early and late fusion, then combine results
713        let (
714            early_selected,
715            _early_combined,
716            early_scores,
717            early_cross_modal,
718            early_mapping,
719            early_total,
720            early_counts,
721        ) = self.early_fusion_selection(modalities, y, modality_weights)?;
722
723        let (
724            late_selected,
725            _late_combined,
726            late_scores,
727            late_cross_modal,
728            _late_mapping,
729            _late_total,
730            _late_counts,
731        ) = self.late_fusion_selection(modalities, y, modality_weights)?;
732
733        // Combine results with hybrid strategy
734        let mut final_selected_per_modality = HashMap::new();
735        let mut final_combined_selected = Vec::new();
736        let mut final_scores_per_modality = HashMap::new();
737
738        for (modality_name, early_indices) in &early_selected {
739            let late_indices = late_selected
740                .get(modality_name)
741                .cloned()
742                .unwrap_or_default();
743            let early_modality_scores = &early_scores[modality_name];
744            let late_modality_scores = &late_scores[modality_name];
745
746            // Combine scores from early and late fusion
747            let combined_scores = (early_modality_scores + late_modality_scores) / 2.0;
748
749            // Merge selected indices (take union and re-rank)
750            let mut all_candidates: std::collections::HashSet<usize> =
751                early_indices.iter().cloned().collect();
752            all_candidates.extend(late_indices.iter());
753
754            let mut candidate_scores: Vec<(usize, Float)> = all_candidates
755                .iter()
756                .map(|&idx| (idx, combined_scores[idx]))
757                .collect();
758
759            candidate_scores
760                .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
761
762            let target_k = self
763                .modality_k
764                .get(modality_name)
765                .cloned()
766                .or(self.k.map(|total_k| total_k / modalities.len()))
767                .unwrap_or(candidate_scores.len() / 2);
768
769            let final_indices: Vec<usize> = candidate_scores
770                .into_iter()
771                .take(target_k.min(all_candidates.len()))
772                .map(|(idx, _)| idx)
773                .collect();
774
775            // Update global combined selection
776            for &local_idx in &final_indices {
777                if let Some(base_offset) = compute_modality_offset(modality_name, modalities) {
778                    final_combined_selected.push((modality_name.clone(), base_offset + local_idx));
779                }
780            }
781
782            final_selected_per_modality.insert(modality_name.clone(), final_indices);
783            final_scores_per_modality.insert(modality_name.clone(), combined_scores);
784        }
785
786        // Use early fusion mappings and totals as base
787        let cross_modal_scores = early_cross_modal.or(late_cross_modal);
788
789        Ok((
790            final_selected_per_modality,
791            final_combined_selected,
792            final_scores_per_modality,
793            cross_modal_scores,
794            early_mapping,
795            early_total,
796            early_counts,
797        ))
798    }
799}
800
801impl MultiModalFeatureSelector<Trained> {
802    fn transform_early_fusion(
803        &self,
804        modalities: &HashMap<String, Array2<Float>>,
805        trained: &Trained,
806    ) -> Result<Array2<Float>> {
807        let (combined_features, _, _) = concatenate_modalities(modalities)?;
808
809        let global_indices: Vec<usize> = trained
810            .combined_selected_features
811            .iter()
812            .map(|(_, global_idx)| *global_idx)
813            .collect();
814
815        if global_indices.is_empty() {
816            return Err(SklearsError::InvalidState(
817                "No features were selected".to_string(),
818            ));
819        }
820
821        let selected_data = combined_features.select(Axis(1), &global_indices);
822        Ok(selected_data)
823    }
824
825    fn transform_late_fusion(
826        &self,
827        modalities: &HashMap<String, Array2<Float>>,
828        trained: &Trained,
829    ) -> Result<Array2<Float>> {
830        let mut selected_features_owned = Vec::new();
831
832        for (modality_name, features) in modalities {
833            if let Some(selected_indices) =
834                trained.selected_features_per_modality.get(modality_name)
835            {
836                if !selected_indices.is_empty() {
837                    let selected_modality_features = features.select(Axis(1), selected_indices);
838                    selected_features_owned.push(selected_modality_features);
839                }
840            }
841        }
842
843        if selected_features_owned.is_empty() {
844            return Err(SklearsError::InvalidState(
845                "No features were selected from any modality".to_string(),
846            ));
847        }
848
849        let selected_features_list: Vec<_> =
850            selected_features_owned.iter().map(|a| a.view()).collect();
851        let combined = concatenate(Axis(1), &selected_features_list).map_err(|e| {
852            SklearsError::InvalidInput(format!("Failed to concatenate selected features: {}", e))
853        })?;
854
855        Ok(combined)
856    }
857
858    fn transform_hybrid_fusion(
859        &self,
860        modalities: &HashMap<String, Array2<Float>>,
861        trained: &Trained,
862    ) -> Result<Array2<Float>> {
863        // For hybrid fusion, use the same approach as late fusion since features were selected per modality
864        self.transform_late_fusion(modalities, trained)
865    }
866}
867
868// Utility functions
869
870fn normalize_modalities(
871    modalities: &HashMap<String, Array2<Float>>,
872) -> Result<HashMap<String, Array2<Float>>> {
873    let mut normalized = HashMap::new();
874
875    for (modality_name, features) in modalities {
876        let normalized_features = normalize_features(features)?;
877        normalized.insert(modality_name.clone(), normalized_features);
878    }
879
880    Ok(normalized)
881}
882
883fn normalize_features(features: &Array2<Float>) -> Result<Array2<Float>> {
884    let (n_samples, n_features) = features.dim();
885    let mut normalized = Array2::zeros((n_samples, n_features));
886
887    for j in 0..n_features {
888        let feature = features.column(j);
889        let mean = feature.sum() / n_samples as Float;
890        let variance = feature.mapv(|x| (x - mean).powi(2)).sum() / n_samples as Float;
891        let std_dev = variance.sqrt();
892
893        if std_dev > 1e-8 {
894            for i in 0..n_samples {
895                normalized[[i, j]] = (features[[i, j]] - mean) / std_dev;
896            }
897        } else {
898            // Handle constant features
899            normalized.column_mut(j).fill(0.0);
900        }
901    }
902
903    Ok(normalized)
904}
905
906fn concatenate_modalities(
907    modalities: &HashMap<String, Array2<Float>>,
908) -> Result<(
909    Array2<Float>,
910    HashMap<usize, (String, usize)>,
911    HashMap<String, usize>,
912)> {
913    if modalities.is_empty() {
914        return Err(SklearsError::InvalidInput(
915            "No modalities provided".to_string(),
916        ));
917    }
918
919    let n_samples = modalities.values().next().unwrap().nrows();
920    let mut feature_views = Vec::new();
921    let mut feature_mapping = HashMap::new();
922    let mut modality_feature_counts = HashMap::new();
923    let mut global_idx = 0;
924
925    for (modality_name, features) in modalities {
926        if features.nrows() != n_samples {
927            return Err(SklearsError::InvalidInput(
928                format!("All modalities must have the same number of samples. Expected {}, got {} for modality '{}'",
929                       n_samples, features.nrows(), modality_name)
930            ));
931        }
932
933        let n_features = features.ncols();
934        modality_feature_counts.insert(modality_name.clone(), n_features);
935
936        for local_idx in 0..n_features {
937            feature_mapping.insert(global_idx, (modality_name.clone(), local_idx));
938            global_idx += 1;
939        }
940
941        feature_views.push(features.view());
942    }
943
944    let combined = concatenate(Axis(1), &feature_views).map_err(|e| {
945        SklearsError::InvalidInput(format!("Failed to concatenate modalities: {}", e))
946    })?;
947
948    Ok((combined, feature_mapping, modality_feature_counts))
949}
950
951fn compute_weighted_feature_scores(
952    features: &Array2<Float>,
953    y: &Array1<Float>,
954    modalities: &HashMap<String, Array2<Float>>,
955    modality_weights: &HashMap<String, Float>,
956    correlation_method: &str,
957) -> Result<Array1<Float>> {
958    let (_, n_features) = features.dim();
959    let mut scores = Array1::zeros(n_features);
960    let mut feature_idx = 0;
961
962    for (modality_name, modality_features) in modalities {
963        let weight = modality_weights.get(modality_name).cloned().unwrap_or(1.0);
964        let modality_scores = compute_univariate_scores(modality_features, y, correlation_method)?;
965        let weighted_scores = modality_scores.mapv(|s| s * weight);
966
967        let end_idx = feature_idx + modality_features.ncols();
968        scores
969            .slice_mut(s![feature_idx..end_idx])
970            .assign(&weighted_scores);
971        feature_idx = end_idx;
972    }
973
974    Ok(scores)
975}
976
977fn compute_univariate_scores(
978    features: &Array2<Float>,
979    y: &Array1<Float>,
980    method: &str,
981) -> Result<Array1<Float>> {
982    let (_, n_features) = features.dim();
983    let mut scores = Array1::zeros(n_features);
984
985    for j in 0..n_features {
986        let feature = features.column(j);
987        let score = match method {
988            "pearson" => compute_pearson_correlation(&feature, &y.view()),
989            "spearman" => compute_spearman_correlation(&feature, &y.view()),
990            "mutual_info" => compute_mutual_information(&feature, &y.view()),
991            _ => compute_pearson_correlation(&feature, &y.view()),
992        };
993        scores[j] = score.abs();
994    }
995
996    Ok(scores)
997}
998
999fn compute_cross_modal_correlations(
1000    modalities: &HashMap<String, Array2<Float>>,
1001    method: &str,
1002) -> Result<Array2<Float>> {
1003    let modality_names: Vec<&String> = modalities.keys().collect();
1004    let n_modalities = modality_names.len();
1005    let mut correlations = Array2::zeros((n_modalities, n_modalities));
1006
1007    for i in 0..n_modalities {
1008        for j in i..n_modalities {
1009            if i == j {
1010                correlations[[i, j]] = 1.0;
1011            } else {
1012                let features_i = &modalities[modality_names[i]];
1013                let features_j = &modalities[modality_names[j]];
1014
1015                let correlation = compute_modality_correlation(features_i, features_j, method)?;
1016                correlations[[i, j]] = correlation;
1017                correlations[[j, i]] = correlation;
1018            }
1019        }
1020    }
1021
1022    Ok(correlations)
1023}
1024
1025fn compute_modality_correlation(
1026    features_a: &Array2<Float>,
1027    features_b: &Array2<Float>,
1028    method: &str,
1029) -> Result<Float> {
1030    // Compute average correlation between modalities by taking mean of all pairwise feature correlations
1031    let (_, n_features_a) = features_a.dim();
1032    let (_, n_features_b) = features_b.dim();
1033
1034    let mut total_correlation = 0.0;
1035    let mut count = 0;
1036
1037    for i in 0..n_features_a.min(10) {
1038        // Limit to avoid computational explosion
1039        for j in 0..n_features_b.min(10) {
1040            let feature_a = features_a.column(i);
1041            let feature_b = features_b.column(j);
1042
1043            let correlation = match method {
1044                "pearson" => compute_pearson_correlation(&feature_a, &feature_b),
1045                "spearman" => compute_spearman_correlation(&feature_a, &feature_b),
1046                _ => compute_pearson_correlation(&feature_a, &feature_b),
1047            };
1048
1049            total_correlation += correlation.abs();
1050            count += 1;
1051        }
1052    }
1053
1054    Ok(if count > 0 {
1055        total_correlation / count as Float
1056    } else {
1057        0.0
1058    })
1059}
1060
1061fn compute_pearson_correlation(x: &ArrayView1<Float>, y: &ArrayView1<Float>) -> Float {
1062    let n = x.len();
1063    if n != y.len() || n == 0 {
1064        return 0.0;
1065    }
1066
1067    let mean_x = x.sum() / n as Float;
1068    let mean_y = y.sum() / n as Float;
1069
1070    let mut numerator = 0.0;
1071    let mut sum_sq_x = 0.0;
1072    let mut sum_sq_y = 0.0;
1073
1074    for i in 0..n {
1075        let dx = x[i] - mean_x;
1076        let dy = y[i] - mean_y;
1077        numerator += dx * dy;
1078        sum_sq_x += dx * dx;
1079        sum_sq_y += dy * dy;
1080    }
1081
1082    let denominator = (sum_sq_x * sum_sq_y).sqrt();
1083    if denominator == 0.0 {
1084        0.0
1085    } else {
1086        numerator / denominator
1087    }
1088}
1089
1090fn compute_spearman_correlation(x: &ArrayView1<Float>, y: &ArrayView1<Float>) -> Float {
1091    // Simplified Spearman correlation (rank-based)
1092    let ranks_x = compute_ranks(x);
1093    let ranks_y = compute_ranks(y);
1094    compute_pearson_correlation(&ranks_x.view(), &ranks_y.view())
1095}
1096
1097fn compute_mutual_information(x: &ArrayView1<Float>, y: &ArrayView1<Float>) -> Float {
1098    // Simplified mutual information estimation using correlation as proxy
1099    compute_pearson_correlation(x, y).abs()
1100}
1101
1102fn compute_ranks(values: &ArrayView1<Float>) -> Array1<Float> {
1103    let n = values.len();
1104    let mut indexed_values: Vec<(usize, Float)> =
1105        values.iter().enumerate().map(|(i, &v)| (i, v)).collect();
1106
1107    indexed_values.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1108
1109    let mut ranks = Array1::zeros(n);
1110    for (rank, &(original_idx, _)) in indexed_values.iter().enumerate() {
1111        ranks[original_idx] = rank as Float;
1112    }
1113
1114    ranks
1115}
1116
1117fn compute_modality_offset(
1118    modality_name: &str,
1119    modalities: &HashMap<String, Array2<Float>>,
1120) -> Option<usize> {
1121    let mut offset = 0;
1122    for (name, features) in modalities {
1123        if name == modality_name {
1124            return Some(offset);
1125        }
1126        offset += features.ncols();
1127    }
1128    None
1129}
1130
1131fn select_top_k_features(scores: &Array1<Float>, k: usize) -> Vec<usize> {
1132    let mut indexed_scores: Vec<(usize, Float)> = scores
1133        .iter()
1134        .enumerate()
1135        .map(|(i, &score)| (i, score))
1136        .collect();
1137
1138    indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1139
1140    indexed_scores
1141        .into_iter()
1142        .take(k.min(scores.len()))
1143        .map(|(i, _)| i)
1144        .collect()
1145}
1146
1147fn select_features_by_threshold(scores: &Array1<Float>, threshold: Float) -> Vec<usize> {
1148    scores
1149        .iter()
1150        .enumerate()
1151        .filter(|(_, &score)| score >= threshold)
1152        .map(|(i, _)| i)
1153        .collect()
1154}
1155
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multi_modal_feature_selector_creation() {
        let selector = MultiModalFeatureSelector::new();
        assert_eq!(selector.fusion_strategy, "hybrid");
        assert!(selector.cross_modal_analysis);
        assert!(selector.normalize_modalities);
    }

    #[test]
    fn test_multi_modal_feature_selector_builder() {
        let selector = MultiModalFeatureSelector::builder()
            .fusion_strategy("early")
            .cross_modal_analysis(false)
            .k(10)
            .build();

        assert_eq!(selector.fusion_strategy, "early");
        assert!(!selector.cross_modal_analysis);
        assert_eq!(selector.k, Some(10));
    }

    #[test]
    fn test_concatenate_modalities() {
        let mut modalities = HashMap::new();

        let text_features =
            Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let image_features =
            Array2::from_shape_vec((3, 2), vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).unwrap();

        modalities.insert("text".to_string(), text_features);
        modalities.insert("image".to_string(), image_features);

        let (combined, mapping, counts) = concatenate_modalities(&modalities).unwrap();

        assert_eq!(combined.dim(), (3, 4));
        assert_eq!(mapping.len(), 4);
        assert_eq!(counts.len(), 2);
    }

    #[test]
    fn test_fit_transform_early_fusion() {
        let mut modalities = HashMap::new();

        let text_features =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let image_features =
            Array2::from_shape_vec((4, 2), vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]).unwrap();

        modalities.insert("text".to_string(), text_features);
        modalities.insert("image".to_string(), image_features);

        let target = Array1::from_vec(vec![0.0, 1.0, 1.0, 0.0]);

        let selector = MultiModalFeatureSelector::builder()
            .fusion_strategy("early")
            .k(2)
            .build();

        let trained = selector.fit(&modalities, &target).unwrap();
        let transformed = trained.transform(&modalities).unwrap();

        assert_eq!(transformed.ncols(), 2);
        assert_eq!(transformed.nrows(), 4);
    }

    #[test]
    fn test_fit_transform_late_fusion() {
        let mut modalities = HashMap::new();

        let text_features = Array2::from_shape_vec(
            (4, 3),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .unwrap();
        let image_features =
            Array2::from_shape_vec((4, 2), vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]).unwrap();

        modalities.insert("text".to_string(), text_features);
        modalities.insert("image".to_string(), image_features);

        let target = Array1::from_vec(vec![0.0, 1.0, 1.0, 0.0]);

        let selector = MultiModalFeatureSelector::builder()
            .fusion_strategy("late")
            .modality_k([("text", 2), ("image", 1)])
            .build();

        let trained = selector.fit(&modalities, &target).unwrap();
        let transformed = trained.transform(&modalities).unwrap();

        assert_eq!(transformed.ncols(), 3); // 2 from text + 1 from image
        assert_eq!(transformed.nrows(), 4);
    }

    #[test]
    fn test_normalize_features() {
        let features = Array2::from_shape_vec((3, 2), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]).unwrap();
        let normalized = normalize_features(&features).unwrap();

        assert_eq!(normalized.dim(), (3, 2));

        // Each column should have approximately zero mean and unit (population) variance.
        for j in 0..2 {
            let column = normalized.column(j);
            let n = column.len() as Float;
            let mean = column.sum() / n;
            assert!((mean).abs() < 1e-10);

            let variance = column.mapv(|v| (v - mean).powi(2)).sum() / n;
            assert!((variance - 1.0).abs() < 1e-10);
        }
    }

    // Note: get_support() is not implemented for MultiModalFeatureSelector
    // (SelectorMixin is commented out because it requires Transform<Array2> but this
    // uses HashMap<String, Array2>), so there is no get_support test.

    #[test]
    fn test_pearson_correlation() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let y = Array1::from_vec(vec![2.0, 4.0, 6.0, 8.0, 10.0]);

        let correlation = compute_pearson_correlation(&x.view(), &y.view());

        // Perfect positive correlation
        assert!((correlation - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_spearman_correlation_monotonic() {
        // A strictly increasing non-linear relationship has rank correlation 1.
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let y = Array1::from_vec(vec![1.0, 8.0, 27.0, 64.0]);

        let rho = compute_spearman_correlation(&x.view(), &y.view());

        assert!((rho - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_select_top_k_features() {
        let scores = Array1::from_vec(vec![0.1, 0.5, 0.3, 0.2]);

        assert_eq!(select_top_k_features(&scores, 2), vec![1, 2]);
        assert_eq!(select_top_k_features(&scores, 10), vec![1, 2, 3, 0]);
        assert!(select_top_k_features(&scores, 0).is_empty());
    }

    #[test]
    fn test_select_features_by_threshold() {
        let scores = Array1::from_vec(vec![0.1, 0.5, 0.3, 0.2]);

        assert_eq!(select_features_by_threshold(&scores, 0.3), vec![1, 2]);
        assert!(select_features_by_threshold(&scores, 0.9).is_empty());
    }

    #[test]
    fn test_compute_modality_offset() {
        let mut modalities = HashMap::new();
        modalities.insert("text".to_string(), Array2::<Float>::zeros((2, 3)));

        assert_eq!(compute_modality_offset("text", &modalities), Some(0));
        assert_eq!(compute_modality_offset("audio", &modalities), None);
    }
}