1use scirs2_core::ndarray::{concatenate, s, Array1, Array2, ArrayView1, Axis};
84use sklears_core::error::{Result as SklResult, SklearsError};
85use sklears_core::traits::{Estimator, Fit, Transform};
86use std::collections::HashMap;
87use std::marker::PhantomData;
88
/// Module-local shorthand for the `sklears_core` result type returned by the fallible
/// functions in this module.
type Result<T> = SklResult<T>;
/// Scalar type used for feature values, scores and weights throughout this module.
type Float = f64;
91
/// Type-state marker for a [`MultiModalFeatureSelector`] that has not been fitted yet.
#[derive(Debug, Clone)]
pub struct Untrained;
94
/// Type-state marker for a fitted [`MultiModalFeatureSelector`], carrying everything
/// `fit` learned so that `transform` can replay the selection.
#[derive(Debug, Clone)]
pub struct Trained {
    /// Fusion strategy used while fitting (`"early"`, `"late"` or `"hybrid"`).
    fusion_strategy: String,
    /// Selected column indices per modality, local to that modality's matrix.
    selected_features_per_modality: HashMap<String, Vec<usize>>,
    /// `(modality, global column index)` for every selected feature, where the global
    /// index addresses the column-wise concatenation of all modalities.
    combined_selected_features: Vec<(String, usize)>,
    /// Weighted relevance score of every column, grouped by modality.
    feature_scores_per_modality: HashMap<String, Array1<Float>>,
    /// Square matrix of mean absolute correlations between modalities; `None` when
    /// cross-modal analysis was disabled.
    cross_modal_scores: Option<Array2<Float>>,
    /// Weights actually applied while scoring (user supplied, or uniform `1 / n`).
    modality_weights: HashMap<String, Float>,
    /// Global column index -> `(modality, local column index)`.
    feature_mapping: HashMap<usize, (String, usize)>,
    /// Total number of columns across all modalities seen during fitting.
    total_features: usize,
    /// Number of columns each modality had during fitting.
    modality_feature_counts: HashMap<String, usize>,
}
107
/// Feature selector that scores and selects features coming from several named
/// modalities (e.g. text, image) and fuses the result.
///
/// `State` is a type-state marker: [`Untrained`] selectors can be fitted, and fitting
/// yields a `MultiModalFeatureSelector<Trained>` that can transform data.
#[derive(Debug, Clone)]
pub struct MultiModalFeatureSelector<State = Untrained> {
    /// `"early"`, `"late"` or `"hybrid"`; `fit` rejects any other value.
    fusion_strategy: String,
    /// Multiplier applied to each modality's feature scores; when empty, `fit` uses a
    /// uniform weight of `1 / number_of_modalities`.
    modality_weights: HashMap<String, Float>,
    /// Number of features to keep per modality (used by late and hybrid fusion).
    modality_k: HashMap<String, usize>,
    /// Whether to compute the modality-by-modality correlation matrix.
    cross_modal_analysis: bool,
    /// NOTE(review): not read by the fitting code in this section — confirm intended use.
    cross_modal_threshold: Float,
    /// Whether each modality's columns are z-scored before scoring.
    normalize_modalities: bool,
    /// Whether missing modalities are tolerated.
    /// NOTE(review): not read by the fitting code in this section — confirm intended use.
    handle_missing_modalities: bool,
    /// `fit` fails when fewer modalities than this are supplied.
    min_modalities_required: usize,
    /// NOTE(review): not read by the code in this section — confirm intended use.
    missing_strategy: String,
    /// Total number of features to keep; `None` means early fusion falls back to
    /// `score_threshold`.
    k: Option<usize>,
    /// Minimum weighted score a feature needs in early fusion when `k` is `None`.
    score_threshold: Float,
    /// `"pearson"`, `"spearman"` or `"mutual_info"`; anything else falls back to Pearson.
    correlation_method: String,
    /// NOTE(review): not read by the code in this section — confirm intended use.
    interaction_analysis: bool,
    /// NOTE(review): not read by the code in this section — confirm intended use.
    max_interaction_order: usize,
    /// Zero-sized type-state marker.
    state: PhantomData<State>,
    /// Populated by `fit`; `None` on an untrained selector.
    trained_state: Option<Trained>,
}
132
133impl Default for MultiModalFeatureSelector<Untrained> {
134 fn default() -> Self {
135 Self::new()
136 }
137}
138
139impl MultiModalFeatureSelector<Untrained> {
140 pub fn new() -> Self {
142 Self {
143 fusion_strategy: "hybrid".to_string(),
144 modality_weights: HashMap::new(),
145 modality_k: HashMap::new(),
146 cross_modal_analysis: true,
147 cross_modal_threshold: 0.1,
148 normalize_modalities: true,
149 handle_missing_modalities: true,
150 min_modalities_required: 1,
151 missing_strategy: "ignore".to_string(),
152 k: None,
153 score_threshold: 0.1,
154 correlation_method: "pearson".to_string(),
155 interaction_analysis: false,
156 max_interaction_order: 2,
157 state: PhantomData,
158 trained_state: None,
159 }
160 }
161
162 pub fn builder() -> MultiModalFeatureSelectorBuilder {
164 MultiModalFeatureSelectorBuilder::new()
165 }
166}
167
/// Fluent builder for [`MultiModalFeatureSelector`].
///
/// Each field mirrors the selector field of the same name and is copied over verbatim
/// by `build`.
#[derive(Debug)]
pub struct MultiModalFeatureSelectorBuilder {
    fusion_strategy: String,
    modality_weights: HashMap<String, Float>,
    modality_k: HashMap<String, usize>,
    cross_modal_analysis: bool,
    cross_modal_threshold: Float,
    normalize_modalities: bool,
    handle_missing_modalities: bool,
    min_modalities_required: usize,
    missing_strategy: String,
    k: Option<usize>,
    score_threshold: Float,
    correlation_method: String,
    interaction_analysis: bool,
    max_interaction_order: usize,
}
186
187impl Default for MultiModalFeatureSelectorBuilder {
188 fn default() -> Self {
189 Self::new()
190 }
191}
192
193impl MultiModalFeatureSelectorBuilder {
194 pub fn new() -> Self {
195 Self {
196 fusion_strategy: "hybrid".to_string(),
197 modality_weights: HashMap::new(),
198 modality_k: HashMap::new(),
199 cross_modal_analysis: true,
200 cross_modal_threshold: 0.1,
201 normalize_modalities: true,
202 handle_missing_modalities: true,
203 min_modalities_required: 1,
204 missing_strategy: "ignore".to_string(),
205 k: None,
206 score_threshold: 0.1,
207 correlation_method: "pearson".to_string(),
208 interaction_analysis: false,
209 max_interaction_order: 2,
210 }
211 }
212
213 pub fn fusion_strategy(mut self, strategy: &str) -> Self {
215 self.fusion_strategy = strategy.to_string();
216 self
217 }
218
219 pub fn modality_weights<I>(mut self, weights: I) -> Self
221 where
222 I: IntoIterator<Item = (&'static str, f64)>,
223 {
224 self.modality_weights = weights
225 .into_iter()
226 .map(|(k, v)| (k.to_string(), v))
227 .collect();
228 self
229 }
230
231 pub fn modality_k<I>(mut self, k_values: I) -> Self
233 where
234 I: IntoIterator<Item = (&'static str, usize)>,
235 {
236 self.modality_k = k_values
237 .into_iter()
238 .map(|(k, v)| (k.to_string(), v))
239 .collect();
240 self
241 }
242
243 pub fn cross_modal_analysis(mut self, enable: bool) -> Self {
245 self.cross_modal_analysis = enable;
246 self
247 }
248
249 pub fn cross_modal_threshold(mut self, threshold: Float) -> Self {
251 self.cross_modal_threshold = threshold;
252 self
253 }
254
255 pub fn normalize_modalities(mut self, normalize: bool) -> Self {
257 self.normalize_modalities = normalize;
258 self
259 }
260
261 pub fn handle_missing_modalities(mut self, handle: bool) -> Self {
263 self.handle_missing_modalities = handle;
264 self
265 }
266
267 pub fn min_modalities_required(mut self, min: usize) -> Self {
269 self.min_modalities_required = min;
270 self
271 }
272
273 pub fn missing_strategy(mut self, strategy: &str) -> Self {
275 self.missing_strategy = strategy.to_string();
276 self
277 }
278
279 pub fn k(mut self, k: usize) -> Self {
281 self.k = Some(k);
282 self
283 }
284
285 pub fn score_threshold(mut self, threshold: Float) -> Self {
287 self.score_threshold = threshold;
288 self
289 }
290
291 pub fn correlation_method(mut self, method: &str) -> Self {
293 self.correlation_method = method.to_string();
294 self
295 }
296
297 pub fn interaction_analysis(mut self, enable: bool) -> Self {
299 self.interaction_analysis = enable;
300 self
301 }
302
303 pub fn max_interaction_order(mut self, order: usize) -> Self {
305 self.max_interaction_order = order;
306 self
307 }
308
309 pub fn build(self) -> MultiModalFeatureSelector<Untrained> {
311 MultiModalFeatureSelector {
312 fusion_strategy: self.fusion_strategy,
313 modality_weights: self.modality_weights,
314 modality_k: self.modality_k,
315 cross_modal_analysis: self.cross_modal_analysis,
316 cross_modal_threshold: self.cross_modal_threshold,
317 normalize_modalities: self.normalize_modalities,
318 handle_missing_modalities: self.handle_missing_modalities,
319 min_modalities_required: self.min_modalities_required,
320 missing_strategy: self.missing_strategy,
321 k: self.k,
322 score_threshold: self.score_threshold,
323 correlation_method: self.correlation_method,
324 interaction_analysis: self.interaction_analysis,
325 max_interaction_order: self.max_interaction_order,
326 state: PhantomData,
327 trained_state: None,
328 }
329 }
330}
331
332impl Estimator for MultiModalFeatureSelector<Untrained> {
333 type Config = ();
334 type Error = sklears_core::error::SklearsError;
335 type Float = Float;
336
337 fn config(&self) -> &Self::Config {
338 &()
339 }
340}
341
342impl Estimator for MultiModalFeatureSelector<Trained> {
343 type Config = ();
344 type Error = sklears_core::error::SklearsError;
345 type Float = Float;
346
347 fn config(&self) -> &Self::Config {
348 &()
349 }
350}
351
352impl Fit<HashMap<String, Array2<Float>>, Array1<Float>> for MultiModalFeatureSelector<Untrained> {
353 type Fitted = MultiModalFeatureSelector<Trained>;
354
355 fn fit(
356 self,
357 modalities: &HashMap<String, Array2<Float>>,
358 y: &Array1<Float>,
359 ) -> Result<Self::Fitted> {
360 if modalities.is_empty() {
361 return Err(SklearsError::InvalidInput(
362 "At least one modality is required".to_string(),
363 ));
364 }
365
366 if modalities.len() < self.min_modalities_required {
367 return Err(SklearsError::InvalidInput(format!(
368 "At least {} modalities are required",
369 self.min_modalities_required
370 )));
371 }
372
373 let n_samples = y.len();
375 for (modality_name, features) in modalities {
376 if features.nrows() != n_samples {
377 return Err(SklearsError::InvalidInput(format!(
378 "Modality '{}' has {} samples, expected {}",
379 modality_name,
380 features.nrows(),
381 n_samples
382 )));
383 }
384 }
385
386 let normalized_modalities = if self.normalize_modalities {
388 normalize_modalities(modalities)?
389 } else {
390 modalities.clone()
391 };
392
393 let mut modality_weights = self.modality_weights.clone();
395 if modality_weights.is_empty() {
396 let default_weight = 1.0 / modalities.len() as Float;
397 for modality_name in modalities.keys() {
398 modality_weights.insert(modality_name.clone(), default_weight);
399 }
400 }
401
402 let (
404 selected_features_per_modality,
405 combined_selected_features,
406 feature_scores_per_modality,
407 cross_modal_scores,
408 feature_mapping,
409 total_features,
410 modality_feature_counts,
411 ) = match self.fusion_strategy.as_str() {
412 "early" => self.early_fusion_selection(&normalized_modalities, y, &modality_weights)?,
413 "late" => self.late_fusion_selection(&normalized_modalities, y, &modality_weights)?,
414 "hybrid" => {
415 self.hybrid_fusion_selection(&normalized_modalities, y, &modality_weights)?
416 }
417 _ => {
418 return Err(SklearsError::InvalidInput(format!(
419 "Unknown fusion strategy: {}",
420 self.fusion_strategy
421 )))
422 }
423 };
424
425 let trained_state = Trained {
426 fusion_strategy: self.fusion_strategy.clone(),
427 selected_features_per_modality,
428 combined_selected_features,
429 feature_scores_per_modality,
430 cross_modal_scores,
431 modality_weights,
432 feature_mapping,
433 total_features,
434 modality_feature_counts,
435 };
436
437 Ok(MultiModalFeatureSelector {
438 fusion_strategy: self.fusion_strategy,
439 modality_weights: self.modality_weights,
440 modality_k: self.modality_k,
441 cross_modal_analysis: self.cross_modal_analysis,
442 cross_modal_threshold: self.cross_modal_threshold,
443 normalize_modalities: self.normalize_modalities,
444 handle_missing_modalities: self.handle_missing_modalities,
445 min_modalities_required: self.min_modalities_required,
446 missing_strategy: self.missing_strategy,
447 k: self.k,
448 score_threshold: self.score_threshold,
449 correlation_method: self.correlation_method,
450 interaction_analysis: self.interaction_analysis,
451 max_interaction_order: self.max_interaction_order,
452 state: PhantomData,
453 trained_state: Some(trained_state),
454 })
455 }
456}
457
458impl Transform<HashMap<String, Array2<Float>>, Array2<Float>>
459 for MultiModalFeatureSelector<Trained>
460{
461 fn transform(&self, modalities: &HashMap<String, Array2<Float>>) -> Result<Array2<Float>> {
462 let trained = self.trained_state.as_ref().ok_or_else(|| {
463 SklearsError::InvalidState("Selector must be fitted before transforming".to_string())
464 })?;
465
466 match trained.fusion_strategy.as_str() {
467 "early" => self.transform_early_fusion(modalities, trained),
468 "late" => self.transform_late_fusion(modalities, trained),
469 "hybrid" => self.transform_hybrid_fusion(modalities, trained),
470 _ => Err(SklearsError::InvalidState(format!(
471 "Unknown fusion strategy: {}",
472 trained.fusion_strategy
473 ))),
474 }
475 }
476}
477
478impl MultiModalFeatureSelector<Untrained> {
540 fn early_fusion_selection(
541 &self,
542 modalities: &HashMap<String, Array2<Float>>,
543 y: &Array1<Float>,
544 modality_weights: &HashMap<String, Float>,
545 ) -> Result<(
546 HashMap<String, Vec<usize>>,
547 Vec<(String, usize)>,
548 HashMap<String, Array1<Float>>,
549 Option<Array2<Float>>,
550 HashMap<usize, (String, usize)>,
551 usize,
552 HashMap<String, usize>,
553 )> {
554 let (combined_features, feature_mapping, modality_feature_counts) =
556 concatenate_modalities(modalities)?;
557
558 let feature_scores = compute_weighted_feature_scores(
560 &combined_features,
561 y,
562 modalities,
563 modality_weights,
564 &self.correlation_method,
565 )?;
566
567 let cross_modal_scores = if self.cross_modal_analysis {
569 Some(compute_cross_modal_correlations(
570 modalities,
571 &self.correlation_method,
572 )?)
573 } else {
574 None
575 };
576
577 let selected_indices = if let Some(k) = self.k {
579 select_top_k_features(&feature_scores, k)
580 } else {
581 select_features_by_threshold(&feature_scores, self.score_threshold)
582 };
583
584 let mut selected_features_per_modality = HashMap::new();
586 let mut combined_selected_features = Vec::new();
587
588 for &global_idx in &selected_indices {
589 if let Some(&(ref modality, local_idx)) = feature_mapping.get(&global_idx) {
590 selected_features_per_modality
591 .entry(modality.clone())
592 .or_insert_with(Vec::new)
593 .push(local_idx);
594 combined_selected_features.push((modality.clone(), global_idx));
595 }
596 }
597
598 let mut feature_scores_per_modality = HashMap::new();
600 let mut current_idx = 0;
601 for (modality_name, features) in modalities {
602 let n_features = features.ncols();
603 let modality_scores = feature_scores
604 .slice(s![current_idx..current_idx + n_features])
605 .to_owned();
606 feature_scores_per_modality.insert(modality_name.clone(), modality_scores);
607 current_idx += n_features;
608 }
609
610 Ok((
611 selected_features_per_modality,
612 combined_selected_features,
613 feature_scores_per_modality,
614 cross_modal_scores,
615 feature_mapping,
616 combined_features.ncols(),
617 modality_feature_counts,
618 ))
619 }
620
621 fn late_fusion_selection(
622 &self,
623 modalities: &HashMap<String, Array2<Float>>,
624 y: &Array1<Float>,
625 modality_weights: &HashMap<String, Float>,
626 ) -> Result<(
627 HashMap<String, Vec<usize>>,
628 Vec<(String, usize)>,
629 HashMap<String, Array1<Float>>,
630 Option<Array2<Float>>,
631 HashMap<usize, (String, usize)>,
632 usize,
633 HashMap<String, usize>,
634 )> {
635 let mut selected_features_per_modality = HashMap::new();
636 let mut feature_scores_per_modality = HashMap::new();
637 let mut combined_selected_features = Vec::new();
638 let mut feature_mapping = HashMap::new();
639 let mut modality_feature_counts = HashMap::new();
640 let mut total_features = 0;
641 let mut global_idx = 0;
642
643 for (modality_name, features) in modalities {
645 let n_features = features.ncols();
646 modality_feature_counts.insert(modality_name.clone(), n_features);
647
648 for local_idx in 0..n_features {
650 feature_mapping.insert(global_idx, (modality_name.clone(), local_idx));
651 global_idx += 1;
652 }
653
654 let weight = modality_weights.get(modality_name).cloned().unwrap_or(1.0);
655 let scores = compute_univariate_scores(features, y, &self.correlation_method)?;
656 let weighted_scores = scores.mapv(|s| s * weight);
657
658 let k = self
659 .modality_k
660 .get(modality_name)
661 .cloned()
662 .or(self.k.map(|total_k| total_k / modalities.len()))
663 .unwrap_or(n_features / 2);
664
665 let selected_indices = select_top_k_features(&weighted_scores, k.min(n_features));
666
667 for &local_idx in &selected_indices {
668 let global_feature_idx = total_features + local_idx;
669 combined_selected_features.push((modality_name.clone(), global_feature_idx));
670 }
671
672 selected_features_per_modality.insert(modality_name.clone(), selected_indices);
673 feature_scores_per_modality.insert(modality_name.clone(), weighted_scores);
674 total_features += n_features;
675 }
676
677 let cross_modal_scores = if self.cross_modal_analysis {
679 Some(compute_cross_modal_correlations(
680 modalities,
681 &self.correlation_method,
682 )?)
683 } else {
684 None
685 };
686
687 Ok((
688 selected_features_per_modality,
689 combined_selected_features,
690 feature_scores_per_modality,
691 cross_modal_scores,
692 feature_mapping,
693 total_features,
694 modality_feature_counts,
695 ))
696 }
697
698 fn hybrid_fusion_selection(
699 &self,
700 modalities: &HashMap<String, Array2<Float>>,
701 y: &Array1<Float>,
702 modality_weights: &HashMap<String, Float>,
703 ) -> Result<(
704 HashMap<String, Vec<usize>>,
705 Vec<(String, usize)>,
706 HashMap<String, Array1<Float>>,
707 Option<Array2<Float>>,
708 HashMap<usize, (String, usize)>,
709 usize,
710 HashMap<String, usize>,
711 )> {
712 let (
714 early_selected,
715 _early_combined,
716 early_scores,
717 early_cross_modal,
718 early_mapping,
719 early_total,
720 early_counts,
721 ) = self.early_fusion_selection(modalities, y, modality_weights)?;
722
723 let (
724 late_selected,
725 _late_combined,
726 late_scores,
727 late_cross_modal,
728 _late_mapping,
729 _late_total,
730 _late_counts,
731 ) = self.late_fusion_selection(modalities, y, modality_weights)?;
732
733 let mut final_selected_per_modality = HashMap::new();
735 let mut final_combined_selected = Vec::new();
736 let mut final_scores_per_modality = HashMap::new();
737
738 for (modality_name, early_indices) in &early_selected {
739 let late_indices = late_selected
740 .get(modality_name)
741 .cloned()
742 .unwrap_or_default();
743 let early_modality_scores = &early_scores[modality_name];
744 let late_modality_scores = &late_scores[modality_name];
745
746 let combined_scores = (early_modality_scores + late_modality_scores) / 2.0;
748
749 let mut all_candidates: std::collections::HashSet<usize> =
751 early_indices.iter().cloned().collect();
752 all_candidates.extend(late_indices.iter());
753
754 let mut candidate_scores: Vec<(usize, Float)> = all_candidates
755 .iter()
756 .map(|&idx| (idx, combined_scores[idx]))
757 .collect();
758
759 candidate_scores
760 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
761
762 let target_k = self
763 .modality_k
764 .get(modality_name)
765 .cloned()
766 .or(self.k.map(|total_k| total_k / modalities.len()))
767 .unwrap_or(candidate_scores.len() / 2);
768
769 let final_indices: Vec<usize> = candidate_scores
770 .into_iter()
771 .take(target_k.min(all_candidates.len()))
772 .map(|(idx, _)| idx)
773 .collect();
774
775 for &local_idx in &final_indices {
777 if let Some(base_offset) = compute_modality_offset(modality_name, modalities) {
778 final_combined_selected.push((modality_name.clone(), base_offset + local_idx));
779 }
780 }
781
782 final_selected_per_modality.insert(modality_name.clone(), final_indices);
783 final_scores_per_modality.insert(modality_name.clone(), combined_scores);
784 }
785
786 let cross_modal_scores = early_cross_modal.or(late_cross_modal);
788
789 Ok((
790 final_selected_per_modality,
791 final_combined_selected,
792 final_scores_per_modality,
793 cross_modal_scores,
794 early_mapping,
795 early_total,
796 early_counts,
797 ))
798 }
799}
800
801impl MultiModalFeatureSelector<Trained> {
802 fn transform_early_fusion(
803 &self,
804 modalities: &HashMap<String, Array2<Float>>,
805 trained: &Trained,
806 ) -> Result<Array2<Float>> {
807 let (combined_features, _, _) = concatenate_modalities(modalities)?;
808
809 let global_indices: Vec<usize> = trained
810 .combined_selected_features
811 .iter()
812 .map(|(_, global_idx)| *global_idx)
813 .collect();
814
815 if global_indices.is_empty() {
816 return Err(SklearsError::InvalidState(
817 "No features were selected".to_string(),
818 ));
819 }
820
821 let selected_data = combined_features.select(Axis(1), &global_indices);
822 Ok(selected_data)
823 }
824
825 fn transform_late_fusion(
826 &self,
827 modalities: &HashMap<String, Array2<Float>>,
828 trained: &Trained,
829 ) -> Result<Array2<Float>> {
830 let mut selected_features_owned = Vec::new();
831
832 for (modality_name, features) in modalities {
833 if let Some(selected_indices) =
834 trained.selected_features_per_modality.get(modality_name)
835 {
836 if !selected_indices.is_empty() {
837 let selected_modality_features = features.select(Axis(1), selected_indices);
838 selected_features_owned.push(selected_modality_features);
839 }
840 }
841 }
842
843 if selected_features_owned.is_empty() {
844 return Err(SklearsError::InvalidState(
845 "No features were selected from any modality".to_string(),
846 ));
847 }
848
849 let selected_features_list: Vec<_> =
850 selected_features_owned.iter().map(|a| a.view()).collect();
851 let combined = concatenate(Axis(1), &selected_features_list).map_err(|e| {
852 SklearsError::InvalidInput(format!("Failed to concatenate selected features: {}", e))
853 })?;
854
855 Ok(combined)
856 }
857
858 fn transform_hybrid_fusion(
859 &self,
860 modalities: &HashMap<String, Array2<Float>>,
861 trained: &Trained,
862 ) -> Result<Array2<Float>> {
863 self.transform_late_fusion(modalities, trained)
865 }
866}
867
868fn normalize_modalities(
871 modalities: &HashMap<String, Array2<Float>>,
872) -> Result<HashMap<String, Array2<Float>>> {
873 let mut normalized = HashMap::new();
874
875 for (modality_name, features) in modalities {
876 let normalized_features = normalize_features(features)?;
877 normalized.insert(modality_name.clone(), normalized_features);
878 }
879
880 Ok(normalized)
881}
882
883fn normalize_features(features: &Array2<Float>) -> Result<Array2<Float>> {
884 let (n_samples, n_features) = features.dim();
885 let mut normalized = Array2::zeros((n_samples, n_features));
886
887 for j in 0..n_features {
888 let feature = features.column(j);
889 let mean = feature.sum() / n_samples as Float;
890 let variance = feature.mapv(|x| (x - mean).powi(2)).sum() / n_samples as Float;
891 let std_dev = variance.sqrt();
892
893 if std_dev > 1e-8 {
894 for i in 0..n_samples {
895 normalized[[i, j]] = (features[[i, j]] - mean) / std_dev;
896 }
897 } else {
898 normalized.column_mut(j).fill(0.0);
900 }
901 }
902
903 Ok(normalized)
904}
905
906fn concatenate_modalities(
907 modalities: &HashMap<String, Array2<Float>>,
908) -> Result<(
909 Array2<Float>,
910 HashMap<usize, (String, usize)>,
911 HashMap<String, usize>,
912)> {
913 if modalities.is_empty() {
914 return Err(SklearsError::InvalidInput(
915 "No modalities provided".to_string(),
916 ));
917 }
918
919 let n_samples = modalities.values().next().unwrap().nrows();
920 let mut feature_views = Vec::new();
921 let mut feature_mapping = HashMap::new();
922 let mut modality_feature_counts = HashMap::new();
923 let mut global_idx = 0;
924
925 for (modality_name, features) in modalities {
926 if features.nrows() != n_samples {
927 return Err(SklearsError::InvalidInput(
928 format!("All modalities must have the same number of samples. Expected {}, got {} for modality '{}'",
929 n_samples, features.nrows(), modality_name)
930 ));
931 }
932
933 let n_features = features.ncols();
934 modality_feature_counts.insert(modality_name.clone(), n_features);
935
936 for local_idx in 0..n_features {
937 feature_mapping.insert(global_idx, (modality_name.clone(), local_idx));
938 global_idx += 1;
939 }
940
941 feature_views.push(features.view());
942 }
943
944 let combined = concatenate(Axis(1), &feature_views).map_err(|e| {
945 SklearsError::InvalidInput(format!("Failed to concatenate modalities: {}", e))
946 })?;
947
948 Ok((combined, feature_mapping, modality_feature_counts))
949}
950
951fn compute_weighted_feature_scores(
952 features: &Array2<Float>,
953 y: &Array1<Float>,
954 modalities: &HashMap<String, Array2<Float>>,
955 modality_weights: &HashMap<String, Float>,
956 correlation_method: &str,
957) -> Result<Array1<Float>> {
958 let (_, n_features) = features.dim();
959 let mut scores = Array1::zeros(n_features);
960 let mut feature_idx = 0;
961
962 for (modality_name, modality_features) in modalities {
963 let weight = modality_weights.get(modality_name).cloned().unwrap_or(1.0);
964 let modality_scores = compute_univariate_scores(modality_features, y, correlation_method)?;
965 let weighted_scores = modality_scores.mapv(|s| s * weight);
966
967 let end_idx = feature_idx + modality_features.ncols();
968 scores
969 .slice_mut(s![feature_idx..end_idx])
970 .assign(&weighted_scores);
971 feature_idx = end_idx;
972 }
973
974 Ok(scores)
975}
976
977fn compute_univariate_scores(
978 features: &Array2<Float>,
979 y: &Array1<Float>,
980 method: &str,
981) -> Result<Array1<Float>> {
982 let (_, n_features) = features.dim();
983 let mut scores = Array1::zeros(n_features);
984
985 for j in 0..n_features {
986 let feature = features.column(j);
987 let score = match method {
988 "pearson" => compute_pearson_correlation(&feature, &y.view()),
989 "spearman" => compute_spearman_correlation(&feature, &y.view()),
990 "mutual_info" => compute_mutual_information(&feature, &y.view()),
991 _ => compute_pearson_correlation(&feature, &y.view()),
992 };
993 scores[j] = score.abs();
994 }
995
996 Ok(scores)
997}
998
999fn compute_cross_modal_correlations(
1000 modalities: &HashMap<String, Array2<Float>>,
1001 method: &str,
1002) -> Result<Array2<Float>> {
1003 let modality_names: Vec<&String> = modalities.keys().collect();
1004 let n_modalities = modality_names.len();
1005 let mut correlations = Array2::zeros((n_modalities, n_modalities));
1006
1007 for i in 0..n_modalities {
1008 for j in i..n_modalities {
1009 if i == j {
1010 correlations[[i, j]] = 1.0;
1011 } else {
1012 let features_i = &modalities[modality_names[i]];
1013 let features_j = &modalities[modality_names[j]];
1014
1015 let correlation = compute_modality_correlation(features_i, features_j, method)?;
1016 correlations[[i, j]] = correlation;
1017 correlations[[j, i]] = correlation;
1018 }
1019 }
1020 }
1021
1022 Ok(correlations)
1023}
1024
1025fn compute_modality_correlation(
1026 features_a: &Array2<Float>,
1027 features_b: &Array2<Float>,
1028 method: &str,
1029) -> Result<Float> {
1030 let (_, n_features_a) = features_a.dim();
1032 let (_, n_features_b) = features_b.dim();
1033
1034 let mut total_correlation = 0.0;
1035 let mut count = 0;
1036
1037 for i in 0..n_features_a.min(10) {
1038 for j in 0..n_features_b.min(10) {
1040 let feature_a = features_a.column(i);
1041 let feature_b = features_b.column(j);
1042
1043 let correlation = match method {
1044 "pearson" => compute_pearson_correlation(&feature_a, &feature_b),
1045 "spearman" => compute_spearman_correlation(&feature_a, &feature_b),
1046 _ => compute_pearson_correlation(&feature_a, &feature_b),
1047 };
1048
1049 total_correlation += correlation.abs();
1050 count += 1;
1051 }
1052 }
1053
1054 Ok(if count > 0 {
1055 total_correlation / count as Float
1056 } else {
1057 0.0
1058 })
1059}
1060
1061fn compute_pearson_correlation(x: &ArrayView1<Float>, y: &ArrayView1<Float>) -> Float {
1062 let n = x.len();
1063 if n != y.len() || n == 0 {
1064 return 0.0;
1065 }
1066
1067 let mean_x = x.sum() / n as Float;
1068 let mean_y = y.sum() / n as Float;
1069
1070 let mut numerator = 0.0;
1071 let mut sum_sq_x = 0.0;
1072 let mut sum_sq_y = 0.0;
1073
1074 for i in 0..n {
1075 let dx = x[i] - mean_x;
1076 let dy = y[i] - mean_y;
1077 numerator += dx * dy;
1078 sum_sq_x += dx * dx;
1079 sum_sq_y += dy * dy;
1080 }
1081
1082 let denominator = (sum_sq_x * sum_sq_y).sqrt();
1083 if denominator == 0.0 {
1084 0.0
1085 } else {
1086 numerator / denominator
1087 }
1088}
1089
1090fn compute_spearman_correlation(x: &ArrayView1<Float>, y: &ArrayView1<Float>) -> Float {
1091 let ranks_x = compute_ranks(x);
1093 let ranks_y = compute_ranks(y);
1094 compute_pearson_correlation(&ranks_x.view(), &ranks_y.view())
1095}
1096
1097fn compute_mutual_information(x: &ArrayView1<Float>, y: &ArrayView1<Float>) -> Float {
1098 compute_pearson_correlation(x, y).abs()
1100}
1101
1102fn compute_ranks(values: &ArrayView1<Float>) -> Array1<Float> {
1103 let n = values.len();
1104 let mut indexed_values: Vec<(usize, Float)> =
1105 values.iter().enumerate().map(|(i, &v)| (i, v)).collect();
1106
1107 indexed_values.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
1108
1109 let mut ranks = Array1::zeros(n);
1110 for (rank, &(original_idx, _)) in indexed_values.iter().enumerate() {
1111 ranks[original_idx] = rank as Float;
1112 }
1113
1114 ranks
1115}
1116
1117fn compute_modality_offset(
1118 modality_name: &str,
1119 modalities: &HashMap<String, Array2<Float>>,
1120) -> Option<usize> {
1121 let mut offset = 0;
1122 for (name, features) in modalities {
1123 if name == modality_name {
1124 return Some(offset);
1125 }
1126 offset += features.ncols();
1127 }
1128 None
1129}
1130
1131fn select_top_k_features(scores: &Array1<Float>, k: usize) -> Vec<usize> {
1132 let mut indexed_scores: Vec<(usize, Float)> = scores
1133 .iter()
1134 .enumerate()
1135 .map(|(i, &score)| (i, score))
1136 .collect();
1137
1138 indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1139
1140 indexed_scores
1141 .into_iter()
1142 .take(k.min(scores.len()))
1143 .map(|(i, _)| i)
1144 .collect()
1145}
1146
1147fn select_features_by_threshold(scores: &Array1<Float>, threshold: Float) -> Vec<usize> {
1148 scores
1149 .iter()
1150 .enumerate()
1151 .filter(|(_, &score)| score >= threshold)
1152 .map(|(i, _)| i)
1153 .collect()
1154}
1155
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multi_modal_feature_selector_creation() {
        let selector = MultiModalFeatureSelector::new();
        assert_eq!(selector.fusion_strategy, "hybrid");
        assert!(selector.cross_modal_analysis);
        assert!(selector.normalize_modalities);
    }

    #[test]
    fn test_multi_modal_feature_selector_builder() {
        let selector = MultiModalFeatureSelector::builder()
            .fusion_strategy("early")
            .cross_modal_analysis(false)
            .k(10)
            .build();

        assert_eq!(selector.fusion_strategy, "early");
        assert!(!selector.cross_modal_analysis);
        assert_eq!(selector.k, Some(10));
    }

    #[test]
    fn test_concatenate_modalities() {
        let mut modalities = HashMap::new();

        let text_features =
            Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let image_features =
            Array2::from_shape_vec((3, 2), vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).unwrap();

        modalities.insert("text".to_string(), text_features);
        modalities.insert("image".to_string(), image_features);

        let (combined, mapping, counts) = concatenate_modalities(&modalities).unwrap();

        assert_eq!(combined.dim(), (3, 4));
        assert_eq!(mapping.len(), 4);
        assert_eq!(counts.len(), 2);
    }

    #[test]
    fn test_fit_transform_early_fusion() {
        let mut modalities = HashMap::new();

        let text_features =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let image_features =
            Array2::from_shape_vec((4, 2), vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]).unwrap();

        modalities.insert("text".to_string(), text_features);
        modalities.insert("image".to_string(), image_features);

        let target = Array1::from_vec(vec![0.0, 1.0, 1.0, 0.0]);

        let selector = MultiModalFeatureSelector::builder()
            .fusion_strategy("early")
            .k(2)
            .build();

        let trained = selector.fit(&modalities, &target).unwrap();
        let transformed = trained.transform(&modalities).unwrap();

        assert_eq!(transformed.ncols(), 2);
        assert_eq!(transformed.nrows(), 4);
    }

    #[test]
    fn test_fit_transform_late_fusion() {
        let mut modalities = HashMap::new();

        let text_features = Array2::from_shape_vec(
            (4, 3),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .unwrap();
        let image_features =
            Array2::from_shape_vec((4, 2), vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]).unwrap();

        modalities.insert("text".to_string(), text_features);
        modalities.insert("image".to_string(), image_features);

        let target = Array1::from_vec(vec![0.0, 1.0, 1.0, 0.0]);

        let selector = MultiModalFeatureSelector::builder()
            .fusion_strategy("late")
            .modality_k([("text", 2), ("image", 1)])
            .build();

        let trained = selector.fit(&modalities, &target).unwrap();
        let transformed = trained.transform(&modalities).unwrap();

        assert_eq!(transformed.ncols(), 3);
        assert_eq!(transformed.nrows(), 4);
    }

    #[test]
    fn test_fit_rejects_unknown_fusion_strategy() {
        let mut modalities = HashMap::new();
        modalities.insert(
            "text".to_string(),
            Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap(),
        );
        let target = Array1::from_vec(vec![0.0, 1.0, 0.0]);

        let selector = MultiModalFeatureSelector::builder()
            .fusion_strategy("bogus")
            .build();

        assert!(selector.fit(&modalities, &target).is_err());
    }

    #[test]
    fn test_fit_rejects_sample_count_mismatch() {
        let mut modalities = HashMap::new();
        modalities.insert(
            "text".to_string(),
            Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap(),
        );
        let target = Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0]);

        assert!(MultiModalFeatureSelector::new()
            .fit(&modalities, &target)
            .is_err());
    }

    #[test]
    fn test_normalize_features() {
        let features = Array2::from_shape_vec((3, 2), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]).unwrap();
        let normalized = normalize_features(&features).unwrap();

        assert_eq!(normalized.dim(), (3, 2));

        for j in 0..2 {
            let column = normalized.column(j);
            let mean = column.sum() / column.len() as Float;
            assert!((mean).abs() < 1e-10);
        }
    }

    #[test]
    fn test_normalize_features_unit_variance_and_constant_column() {
        let features = Array2::from_shape_vec((3, 2), vec![5.0, 1.0, 5.0, 2.0, 5.0, 3.0]).unwrap();
        let normalized = normalize_features(&features).unwrap();

        // Constant column collapses to zeros.
        assert!(normalized.column(0).iter().all(|&v| v == 0.0));

        // Non-constant column has population variance 1.
        let column = normalized.column(1);
        let variance = column.mapv(|v| v * v).sum() / column.len() as Float;
        assert!((variance - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_pearson_correlation() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let y = Array1::from_vec(vec![2.0, 4.0, 6.0, 8.0, 10.0]);

        let correlation = compute_pearson_correlation(&x.view(), &y.view());

        assert!((correlation - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_spearman_correlation_monotonic() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let y = Array1::from_vec(vec![1.0, 8.0, 27.0, 64.0, 125.0]);

        let correlation = compute_spearman_correlation(&x.view(), &y.view());

        assert!((correlation - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_select_top_k_features() {
        let scores = Array1::from_vec(vec![0.2, 0.9, 0.5, 0.9]);

        // Equal scores keep their original index order.
        assert_eq!(select_top_k_features(&scores, 2), vec![1, 3]);
        // `k` larger than the number of features is capped.
        assert_eq!(select_top_k_features(&scores, 10).len(), 4);
    }

    #[test]
    fn test_select_features_by_threshold() {
        let scores = Array1::from_vec(vec![0.05, 0.1, 0.3]);
        assert_eq!(select_features_by_threshold(&scores, 0.1), vec![1, 2]);
    }

    #[test]
    fn test_compute_modality_offset_single_modality() {
        let mut modalities = HashMap::new();
        modalities.insert("text".to_string(), Array2::<Float>::zeros((2, 3)));

        assert_eq!(compute_modality_offset("text", &modalities), Some(0));
        assert_eq!(compute_modality_offset("audio", &modalities), None);
    }

    #[test]
    fn test_cross_modal_correlations_symmetric() {
        let mut modalities = HashMap::new();
        modalities.insert(
            "a".to_string(),
            Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap(),
        );
        modalities.insert(
            "b".to_string(),
            Array2::from_shape_vec((3, 1), vec![3.0, 1.0, 2.0]).unwrap(),
        );

        let correlations = compute_cross_modal_correlations(&modalities, "pearson").unwrap();

        assert_eq!(correlations.dim(), (2, 2));
        assert_eq!(correlations[[0, 0]], 1.0);
        assert_eq!(correlations[[1, 1]], 1.0);
        assert_eq!(correlations[[0, 1]], correlations[[1, 0]]);
        assert!((correlations[[0, 1]] - 0.5).abs() < 1e-12);
    }
}