// sklears_mixture/multi_modal.rs

//! Multi-Modal Data Mixture Models
//!
//! This module provides mixture models that can handle multi-modal and heterogeneous data types.
//! It includes implementations for multi-view mixture models, coupled mixture models,
//! shared latent variable models, and cross-modal alignment techniques.
//! All implementations follow SciRS2 Policy for numerical computing and random number generation.

8use crate::common::{CovarianceType, ModelSelection};
9use scirs2_core::ndarray::{s, Array1, Array2, Array3, Axis};
10use scirs2_core::random::{thread_rng, RandNormal, RandUniform};
11use sklears_core::{
12    error::{Result as SklResult, SklearsError},
13    traits::{Estimator, Fit, Predict, Trained, Untrained},
14};
15use std::collections::HashMap;
16
/// Type of multi-modal data fusion strategy
///
/// Determines how information from the individual modalities is combined
/// during model fitting.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FusionStrategy {
    /// Early fusion: concatenate features from all modalities into one joint
    /// feature space and fit a single mixture model.
    EarlyFusion,
    /// Late fusion: train separate models per modality, then combine
    /// predictions.
    LateFusion,
    /// Intermediate fusion: learn a shared latent representation that every
    /// modality projects into (requires a shared latent dimension).
    IntermediateFusion,
    /// Coupled fusion: joint optimization across modalities with explicit
    /// coupling parameters.
    CoupledFusion,
}
29
/// Multi-view data modality specification
///
/// Describes one input view/data source for a multi-modal mixture model.
/// The `name` is used as the lookup key for both the input data map and the
/// fitted per-modality parameters.
#[derive(Debug, Clone)]
pub struct ModalitySpec {
    /// Name of the modality (e.g., "visual", "textual", "audio"); key for
    /// data and fitted-parameter maps.
    pub name: String,
    /// Number of feature columns expected in this modality's data matrix.
    pub n_features: usize,
    /// Covariance parameterization used for this modality's components.
    pub covariance_type: CovarianceType,
    /// Relative weight for this modality in the fusion process.
    /// NOTE(review): not consumed by the fitting routines visible in this
    /// module — confirm intended semantics before relying on it.
    pub modality_weight: f64,
}
42
/// Configuration for multi-modal mixture models
#[derive(Debug, Clone)]
pub struct MultiModalConfig {
    /// Number of mixture components shared across all modalities.
    pub n_components: usize,
    /// Specifications of the data modalities the model expects.
    pub modalities: Vec<ModalitySpec>,
    /// Strategy used to combine information across modalities.
    pub fusion_strategy: FusionStrategy,
    /// Dimension of the shared latent space; must be `Some` when
    /// `fusion_strategy` is `IntermediateFusion` (enforced by the builder).
    pub shared_latent_dim: Option<usize>,
    /// Cross-modality coupling strength; the builder clamps it to [0, 1].
    pub coupling_strength: f64,
    /// Maximum number of EM iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the change in log-likelihood.
    pub tol: f64,
    /// Non-negative value added to variances for numerical stability.
    pub regularization_strength: f64,
    /// Seed intended for reproducible initialization.
    /// NOTE(review): not currently consumed by the fitting code in this
    /// module (`thread_rng` is used unconditionally) — confirm.
    pub random_state: Option<u64>,
}
65
/// Multi-Modal Gaussian Mixture Model
///
/// A mixture model that can handle multiple data modalities simultaneously,
/// learning both modality-specific patterns and cross-modal relationships.
/// This is useful for datasets with multiple types of features (e.g., visual + textual,
/// sensor data from multiple sources, etc.).
///
/// The type parameter `S` is a typestate marker (`Untrained`/`Trained`);
/// no value of type `S` is stored.
#[derive(Debug, Clone)]
pub struct MultiModalGaussianMixture<S = Untrained> {
    // Hyper-parameters; validated by `MultiModalGaussianMixtureBuilder::build`.
    config: MultiModalConfig,
    // Zero-sized marker carrying the training state in the type system.
    _phantom: std::marker::PhantomData<S>,
}
77
/// Trained Multi-Modal GMM
///
/// Holds the parameters produced by fitting a `MultiModalGaussianMixture`.
/// Per-modality entries are keyed by the modality `name` from the
/// corresponding `ModalitySpec`.
#[derive(Debug, Clone)]
pub struct MultiModalGaussianMixtureTrained {
    /// Mixture component weights, length `n_components`.
    pub weights: Array1<f64>,
    /// Component means per modality: `(n_components, n_features)` each.
    pub modality_means: HashMap<String, Array2<f64>>,
    /// Component covariances per modality; the array layout depends on the
    /// modality's covariance type (e.g. `(n_components, d, d)` for full,
    /// `(n_components, d, 1)` for diagonal, `(1, d, d)` for tied).
    pub modality_covariances: HashMap<String, Array3<f64>>,
    /// Component means in the shared latent space,
    /// `(n_components, latent_dim)`; `Some` only for intermediate fusion.
    pub shared_latent_means: Option<Array2<f64>>,
    /// Per-modality projections into the shared latent space,
    /// `(latent_dim, n_features)`; empty unless intermediate fusion is used.
    pub latent_projections: HashMap<String, Array2<f64>>,
    /// Symmetric `(n_modalities, n_modalities)` coupling matrix; a `0x0`
    /// array for strategies that do not use coupling.
    pub coupling_parameters: Array2<f64>,
    /// Log-likelihood recorded at each EM iteration (placeholder `0.0` for
    /// strategies whose joint likelihood is not yet computed).
    pub log_likelihood_history: Vec<f64>,
    /// Number of iterations actually performed.
    pub n_iter: usize,
    /// The configuration the model was fitted with.
    pub config: MultiModalConfig,
}
100
/// Builder for Multi-Modal GMM
///
/// Accumulates configuration via chained setters and validates it in `build`.
#[derive(Debug, Clone)]
pub struct MultiModalGaussianMixtureBuilder {
    // Number of mixture components.
    n_components: usize,
    // Modality specifications; at least one is required to build.
    modalities: Vec<ModalitySpec>,
    // How modalities are combined during fitting.
    fusion_strategy: FusionStrategy,
    // Latent dimension; mandatory for intermediate fusion.
    shared_latent_dim: Option<usize>,
    // Cross-modality coupling strength, clamped to [0, 1] by the setter.
    coupling_strength: f64,
    // EM iteration cap.
    max_iter: usize,
    // Log-likelihood convergence tolerance.
    tol: f64,
    // Non-negative variance regularizer.
    regularization_strength: f64,
    // Optional seed for reproducibility.
    random_state: Option<u64>,
}
114
115impl MultiModalGaussianMixtureBuilder {
116    /// Create a new builder with specified number of components
117    pub fn new(n_components: usize) -> Self {
118        Self {
119            n_components,
120            modalities: Vec::new(),
121            fusion_strategy: FusionStrategy::IntermediateFusion,
122            shared_latent_dim: None,
123            coupling_strength: 0.1,
124            max_iter: 100,
125            tol: 1e-4,
126            regularization_strength: 0.01,
127            random_state: None,
128        }
129    }
130
131    /// Add a data modality
132    pub fn add_modality(mut self, modality: ModalitySpec) -> Self {
133        self.modalities.push(modality);
134        self
135    }
136
137    /// Add a data modality with default settings
138    pub fn add_modality_simple(mut self, name: &str, n_features: usize) -> Self {
139        let modality = ModalitySpec {
140            name: name.to_string(),
141            n_features,
142            covariance_type: CovarianceType::Full,
143            modality_weight: 1.0,
144        };
145        self.modalities.push(modality);
146        self
147    }
148
149    /// Set fusion strategy
150    pub fn fusion_strategy(mut self, strategy: FusionStrategy) -> Self {
151        self.fusion_strategy = strategy;
152        self
153    }
154
155    /// Set shared latent dimension (for intermediate fusion)
156    pub fn shared_latent_dim(mut self, dim: usize) -> Self {
157        self.shared_latent_dim = Some(dim);
158        self
159    }
160
161    /// Set coupling strength between modalities
162    pub fn coupling_strength(mut self, strength: f64) -> Self {
163        self.coupling_strength = strength.clamp(0.0, 1.0);
164        self
165    }
166
167    /// Set maximum iterations
168    pub fn max_iter(mut self, max_iter: usize) -> Self {
169        self.max_iter = max_iter;
170        self
171    }
172
173    /// Set convergence tolerance
174    pub fn tolerance(mut self, tol: f64) -> Self {
175        self.tol = tol;
176        self
177    }
178
179    /// Set regularization strength
180    pub fn regularization_strength(mut self, strength: f64) -> Self {
181        self.regularization_strength = strength.max(0.0);
182        self
183    }
184
185    /// Set random state for reproducibility
186    pub fn random_state(mut self, random_state: u64) -> Self {
187        self.random_state = Some(random_state);
188        self
189    }
190
191    /// Build the multi-modal GMM
192    pub fn build(self) -> SklResult<MultiModalGaussianMixture<Untrained>> {
193        if self.modalities.is_empty() {
194            return Err(SklearsError::InvalidInput(
195                "At least one modality must be specified".to_string(),
196            ));
197        }
198
199        // For intermediate fusion, ensure latent dimension is specified
200        if self.fusion_strategy == FusionStrategy::IntermediateFusion
201            && self.shared_latent_dim.is_none()
202        {
203            return Err(SklearsError::InvalidInput(
204                "Shared latent dimension must be specified for intermediate fusion".to_string(),
205            ));
206        }
207
208        let config = MultiModalConfig {
209            n_components: self.n_components,
210            modalities: self.modalities,
211            fusion_strategy: self.fusion_strategy,
212            shared_latent_dim: self.shared_latent_dim,
213            coupling_strength: self.coupling_strength,
214            max_iter: self.max_iter,
215            tol: self.tol,
216            regularization_strength: self.regularization_strength,
217            random_state: self.random_state,
218        };
219
220        Ok(MultiModalGaussianMixture {
221            config,
222            _phantom: std::marker::PhantomData,
223        })
224    }
225}
226
/// Estimator metadata for the untrained model.
impl Estimator<Untrained> for MultiModalGaussianMixture<Untrained> {
    type Config = MultiModalConfig;
    type Error = SklearsError;
    type Float = f64;

    /// Access the model configuration.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
236
/// Estimator metadata for the trained model.
impl Estimator<Trained> for MultiModalGaussianMixture<Trained> {
    type Config = MultiModalConfig;
    type Error = SklearsError;
    type Float = f64;

    /// Access the model configuration.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
246
247impl Fit<HashMap<String, Array2<f64>>, Option<Array1<usize>>>
248    for MultiModalGaussianMixture<Untrained>
249{
250    type Fitted = MultiModalGaussianMixtureTrained;
251
252    fn fit(
253        self,
254        X: &HashMap<String, Array2<f64>>,
255        y: &Option<Array1<usize>>,
256    ) -> SklResult<Self::Fitted> {
257        // Validate input data matches configured modalities
258        for modality in &self.config.modalities {
259            if !X.contains_key(&modality.name) {
260                return Err(SklearsError::InvalidInput(format!(
261                    "Missing data for modality: {}",
262                    modality.name
263                )));
264            }
265            let data = &X[&modality.name];
266            if data.ncols() != modality.n_features {
267                return Err(SklearsError::InvalidInput(format!(
268                    "Feature dimension mismatch for modality {}: expected {}, got {}",
269                    modality.name,
270                    modality.n_features,
271                    data.ncols()
272                )));
273            }
274        }
275
276        // Get sample size (assuming all modalities have same number of samples)
277        let n_samples = X.values().next().unwrap().nrows();
278        for (name, data) in X.iter() {
279            if data.nrows() != n_samples {
280                return Err(SklearsError::InvalidInput(format!(
281                    "Sample size mismatch for modality {}: expected {}, got {}",
282                    name,
283                    n_samples,
284                    data.nrows()
285                )));
286            }
287        }
288
289        match self.config.fusion_strategy {
290            FusionStrategy::EarlyFusion => self.fit_early_fusion(X, y),
291            FusionStrategy::LateFusion => self.fit_late_fusion(X, y),
292            FusionStrategy::IntermediateFusion => self.fit_intermediate_fusion(X, y),
293            FusionStrategy::CoupledFusion => self.fit_coupled_fusion(X, y),
294        }
295    }
296}
297
298impl MultiModalGaussianMixture<Untrained> {
299    /// Initialize parameters using K-means++ style initialization
300    fn initialize_parameters(
301        &self,
302        X: &HashMap<String, Array2<f64>>,
303    ) -> SklResult<(
304        Array1<f64>,
305        HashMap<String, Array2<f64>>,
306        HashMap<String, Array3<f64>>,
307    )> {
308        let n_samples = X.values().next().unwrap().nrows();
309        let n_components = self.config.n_components;
310
311        // Initialize component weights uniformly
312        let weights = Array1::ones(n_components) / n_components as f64;
313
314        // Initialize modality-specific parameters
315        let mut modality_means = HashMap::new();
316        let mut modality_covariances = HashMap::new();
317
318        let mut rng = thread_rng();
319
320        for modality in &self.config.modalities {
321            let data = &X[&modality.name];
322            let n_features = data.ncols();
323
324            // Initialize means using random samples
325            let mut means = Array2::zeros((n_components, n_features));
326            for k in 0..n_components {
327                let uniform = RandUniform::new(0, n_samples).map_err(|e| {
328                    SklearsError::InvalidInput(format!("Uniform distribution error: {}", e))
329                })?;
330                let sample_idx = rng.sample(uniform);
331                means.row_mut(k).assign(&data.row(sample_idx));
332            }
333
334            // Initialize covariances
335            let covariances = match modality.covariance_type {
336                CovarianceType::Full => {
337                    let mut cov = Array3::zeros((n_components, n_features, n_features));
338                    for k in 0..n_components {
339                        for i in 0..n_features {
340                            cov[[k, i, i]] = 1.0; // Start with identity
341                        }
342                    }
343                    cov
344                }
345                CovarianceType::Diagonal => {
346                    let mut cov = Array3::zeros((n_components, n_features, 1));
347                    for k in 0..n_components {
348                        for i in 0..n_features {
349                            cov[[k, i, 0]] = 1.0;
350                        }
351                    }
352                    cov
353                }
354                CovarianceType::Tied => {
355                    let mut cov = Array3::zeros((1, n_features, n_features));
356                    for i in 0..n_features {
357                        cov[[0, i, i]] = 1.0;
358                    }
359                    cov
360                }
361                CovarianceType::Spherical => Array3::ones((n_components, 1, 1)),
362            };
363
364            modality_means.insert(modality.name.clone(), means);
365            modality_covariances.insert(modality.name.clone(), covariances);
366        }
367
368        Ok((weights, modality_means, modality_covariances))
369    }
370
371    /// Early fusion implementation: concatenate all modality features
372    fn fit_early_fusion(
373        &self,
374        X: &HashMap<String, Array2<f64>>,
375        _y: &Option<Array1<usize>>,
376    ) -> SklResult<MultiModalGaussianMixtureTrained> {
377        let n_samples = X.values().next().unwrap().nrows();
378
379        // Concatenate all modality data
380        let mut concatenated_features = Vec::new();
381        for modality in &self.config.modalities {
382            concatenated_features.push(X[&modality.name].clone());
383        }
384
385        // Stack horizontally
386        let mut combined_data = concatenated_features[0].clone();
387        for i in 1..concatenated_features.len() {
388            let current_cols = combined_data.ncols();
389            let new_cols = concatenated_features[i].ncols();
390            let mut new_data = Array2::zeros((n_samples, current_cols + new_cols));
391            new_data
392                .slice_mut(s![.., ..current_cols])
393                .assign(&combined_data);
394            new_data
395                .slice_mut(s![.., current_cols..])
396                .assign(&concatenated_features[i]);
397            combined_data = new_data;
398        }
399
400        // Run standard GMM on concatenated data
401        let (mut weights, _means_map, _covariances_map) = self.initialize_parameters(X)?;
402        let mut log_likelihood_history = Vec::new();
403
404        // Since we're doing early fusion, we need to work with the concatenated means
405        let total_features: usize = self.config.modalities.iter().map(|m| m.n_features).sum();
406        let mut combined_means = Array2::zeros((self.config.n_components, total_features));
407        let mut combined_covariances =
408            Array3::zeros((self.config.n_components, total_features, total_features));
409
410        // Initialize with identity covariances for simplicity
411        for k in 0..self.config.n_components {
412            for i in 0..total_features {
413                combined_covariances[[k, i, i]] = 1.0;
414            }
415        }
416
417        // EM algorithm for early fusion
418        for iter in 0..self.config.max_iter {
419            let old_log_likelihood = if log_likelihood_history.is_empty() {
420                f64::NEG_INFINITY
421            } else {
422                *log_likelihood_history.last().unwrap()
423            };
424
425            // E-step: Compute responsibilities
426            let mut responsibilities = Array2::zeros((n_samples, self.config.n_components));
427            let mut log_likelihood = 0.0;
428
429            for i in 0..n_samples {
430                let sample = combined_data.row(i);
431                let mut log_probs = Array1::zeros(self.config.n_components);
432
433                for k in 0..self.config.n_components {
434                    let mean = combined_means.row(k);
435                    let diff = &sample.to_owned() - &mean.to_owned();
436                    let log_det = combined_covariances
437                        .slice(s![k, .., ..])
438                        .diag()
439                        .mapv(|x: f64| x.ln())
440                        .sum();
441                    let inv_quad = diff.dot(&diff); // Simplified - should use proper inverse
442
443                    log_probs[k] = weights[k].ln()
444                        - 0.5
445                            * (total_features as f64 * (2.0 * std::f64::consts::PI).ln()
446                                + log_det
447                                + inv_quad);
448                }
449
450                // Numerical stability
451                let max_log_prob = log_probs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
452                let log_sum_exp =
453                    (log_probs.mapv(|x| (x - max_log_prob).exp()).sum()).ln() + max_log_prob;
454                log_likelihood += log_sum_exp;
455
456                for k in 0..self.config.n_components {
457                    responsibilities[[i, k]] = ((log_probs[k] - log_sum_exp).exp()).max(1e-15);
458                }
459            }
460
461            log_likelihood_history.push(log_likelihood);
462
463            // Check convergence
464            if iter > 0 && (log_likelihood - old_log_likelihood).abs() < self.config.tol {
465                break;
466            }
467
468            // M-step: Update parameters
469            let n_k: Array1<f64> = responsibilities.sum_axis(Axis(0));
470
471            // Update weights
472            weights = &n_k / n_samples as f64;
473
474            // Update means
475            for k in 0..self.config.n_components {
476                if n_k[k] > 1e-15 {
477                    let weighted_sum = responsibilities.column(k).iter().enumerate().fold(
478                        Array1::zeros(total_features),
479                        |mut acc, (i, &resp)| {
480                            let sample = combined_data.row(i);
481                            for j in 0..total_features {
482                                acc[j] += resp * sample[j];
483                            }
484                            acc
485                        },
486                    );
487                    combined_means.row_mut(k).assign(&(weighted_sum / n_k[k]));
488                }
489            }
490
491            // Update covariances (diagonal approximation for simplicity)
492            for k in 0..self.config.n_components {
493                if n_k[k] > 1e-15 {
494                    for j in 0..total_features {
495                        let mut weighted_var = 0.0;
496                        for i in 0..n_samples {
497                            let diff = combined_data[[i, j]] - combined_means[[k, j]];
498                            weighted_var += responsibilities[[i, k]] * diff * diff;
499                        }
500                        combined_covariances[[k, j, j]] =
501                            (weighted_var / n_k[k] + self.config.regularization_strength).max(1e-6);
502                    }
503                }
504            }
505        }
506
507        // Convert back to modality-specific format for compatibility
508        let mut final_means = HashMap::new();
509        let mut final_covariances = HashMap::new();
510        let mut feature_start = 0;
511
512        for modality in &self.config.modalities {
513            let n_features = modality.n_features;
514            let modality_means = combined_means
515                .slice(s![.., feature_start..feature_start + n_features])
516                .to_owned();
517            let modality_cov_slice = combined_covariances
518                .slice(s![
519                    ..,
520                    feature_start..feature_start + n_features,
521                    feature_start..feature_start + n_features
522                ])
523                .to_owned();
524
525            final_means.insert(modality.name.clone(), modality_means);
526            final_covariances.insert(modality.name.clone(), modality_cov_slice);
527            feature_start += n_features;
528        }
529
530        let n_iter = log_likelihood_history.len();
531        Ok(MultiModalGaussianMixtureTrained {
532            weights,
533            modality_means: final_means,
534            modality_covariances: final_covariances,
535            shared_latent_means: None,
536            latent_projections: HashMap::new(),
537            coupling_parameters: Array2::zeros((0, 0)),
538            log_likelihood_history,
539            n_iter,
540            config: self.config.clone(),
541        })
542    }
543
544    /// Late fusion implementation: train separate models then combine
545    fn fit_late_fusion(
546        &self,
547        X: &HashMap<String, Array2<f64>>,
548        _y: &Option<Array1<usize>>,
549    ) -> SklResult<MultiModalGaussianMixtureTrained> {
550        let n_samples = X.values().next().unwrap().nrows();
551        let (weights, mut modality_means, mut modality_covariances) =
552            self.initialize_parameters(X)?;
553        let mut log_likelihood_history = Vec::new();
554
555        // Train each modality separately with EM
556        for modality in &self.config.modalities {
557            let data = &X[&modality.name];
558            let n_features = modality.n_features;
559            let n_components = self.config.n_components;
560
561            let mut modality_weights: Array1<f64> =
562                Array1::ones(n_components) / n_components as f64;
563            let mut means = modality_means[&modality.name].clone();
564            let mut covariances = modality_covariances[&modality.name].clone();
565
566            // EM for this modality
567            for _iter in 0..self.config.max_iter {
568                // E-step
569                let mut responsibilities = Array2::zeros((n_samples, n_components));
570
571                for i in 0..n_samples {
572                    let sample = data.row(i);
573                    let mut log_probs = Array1::zeros(n_components);
574
575                    for k in 0..n_components {
576                        let mean = means.row(k);
577                        let diff = &sample.to_owned() - &mean.to_owned();
578
579                        let log_det = match modality.covariance_type {
580                            CovarianceType::Full => covariances
581                                .slice(s![k, .., ..])
582                                .diag()
583                                .mapv(|x| x.ln())
584                                .sum(),
585                            CovarianceType::Diagonal => {
586                                covariances.slice(s![k, .., 0]).mapv(|x| x.ln()).sum()
587                            }
588                            CovarianceType::Spherical => {
589                                n_features as f64 * covariances[[k, 0, 0]].ln()
590                            }
591                            CovarianceType::Tied => covariances
592                                .slice(s![0, .., ..])
593                                .diag()
594                                .mapv(|x| x.ln())
595                                .sum(),
596                        };
597
598                        let inv_quad = match modality.covariance_type {
599                            CovarianceType::Spherical => diff.dot(&diff) / covariances[[k, 0, 0]],
600                            _ => diff.dot(&diff), // Simplified
601                        };
602
603                        log_probs[k] = modality_weights[k].ln()
604                            - 0.5
605                                * (n_features as f64 * (2.0 * std::f64::consts::PI).ln()
606                                    + log_det
607                                    + inv_quad);
608                    }
609
610                    // Normalize responsibilities
611                    let max_log_prob = log_probs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
612                    let log_sum_exp =
613                        (log_probs.mapv(|x| (x - max_log_prob).exp()).sum()).ln() + max_log_prob;
614
615                    for k in 0..n_components {
616                        responsibilities[[i, k]] = ((log_probs[k] - log_sum_exp).exp()).max(1e-15);
617                    }
618                }
619
620                // M-step
621                let n_k: Array1<f64> = responsibilities.sum_axis(Axis(0));
622                modality_weights = &n_k / n_samples as f64;
623
624                // Update means
625                for k in 0..n_components {
626                    if n_k[k] > 1e-15 {
627                        let weighted_sum = responsibilities.column(k).iter().enumerate().fold(
628                            Array1::zeros(n_features),
629                            |mut acc, (i, &resp)| {
630                                let sample = data.row(i);
631                                for j in 0..n_features {
632                                    acc[j] += resp * sample[j];
633                                }
634                                acc
635                            },
636                        );
637                        means.row_mut(k).assign(&(weighted_sum / n_k[k]));
638                    }
639                }
640
641                // Update covariances (simplified diagonal update)
642                match modality.covariance_type {
643                    CovarianceType::Spherical => {
644                        for k in 0..n_components {
645                            if n_k[k] > 1e-15 {
646                                let mut weighted_var = 0.0;
647                                for i in 0..n_samples {
648                                    let sample = data.row(i);
649                                    let mean = means.row(k);
650                                    let diff = &sample.to_owned() - &mean.to_owned();
651                                    weighted_var += responsibilities[[i, k]] * diff.dot(&diff);
652                                }
653                                covariances[[k, 0, 0]] = (weighted_var
654                                    / (n_k[k] * n_features as f64)
655                                    + self.config.regularization_strength)
656                                    .max(1e-6);
657                            }
658                        }
659                    }
660                    _ => {
661                        // Diagonal covariance update
662                        for k in 0..n_components {
663                            if n_k[k] > 1e-15 {
664                                for j in 0..n_features {
665                                    let mut weighted_var = 0.0;
666                                    for i in 0..n_samples {
667                                        let diff = data[[i, j]] - means[[k, j]];
668                                        weighted_var += responsibilities[[i, k]] * diff * diff;
669                                    }
670                                    let var_idx = match modality.covariance_type {
671                                        CovarianceType::Diagonal => (k, j, 0),
672                                        _ => (k, j, j),
673                                    };
674                                    covariances[var_idx] = (weighted_var / n_k[k]
675                                        + self.config.regularization_strength)
676                                        .max(1e-6);
677                                }
678                            }
679                        }
680                    }
681                }
682            }
683
684            // Update the modality parameters
685            modality_means.insert(modality.name.clone(), means);
686            modality_covariances.insert(modality.name.clone(), covariances);
687        }
688
689        // Combine predictions from all modalities (simple averaging)
690        log_likelihood_history.push(0.0); // Placeholder
691
692        Ok(MultiModalGaussianMixtureTrained {
693            weights,
694            modality_means,
695            modality_covariances,
696            shared_latent_means: None,
697            latent_projections: HashMap::new(),
698            coupling_parameters: Array2::zeros((0, 0)),
699            log_likelihood_history,
700            n_iter: 1,
701            config: self.config.clone(),
702        })
703    }
704
705    /// Intermediate fusion: learn shared latent representation
706    fn fit_intermediate_fusion(
707        &self,
708        X: &HashMap<String, Array2<f64>>,
709        _y: &Option<Array1<usize>>,
710    ) -> SklResult<MultiModalGaussianMixtureTrained> {
711        let _n_samples = X.values().next().unwrap().nrows();
712        let latent_dim = self.config.shared_latent_dim.unwrap();
713        let (weights, modality_means, modality_covariances) = self.initialize_parameters(X)?;
714
715        // Initialize projection matrices for each modality to latent space
716        let mut latent_projections = HashMap::new();
717        let mut rng = thread_rng();
718
719        for modality in &self.config.modalities {
720            let normal = RandNormal::new(0.0, 0.1).map_err(|e| {
721                SklearsError::InvalidInput(format!("Normal distribution error: {}", e))
722            })?;
723            let mut projection = Array2::zeros((latent_dim, modality.n_features));
724            for i in 0..latent_dim {
725                for j in 0..modality.n_features {
726                    projection[[i, j]] = rng.sample(normal);
727                }
728            }
729            latent_projections.insert(modality.name.clone(), projection);
730        }
731
732        // Initialize shared latent means
733        let mut shared_latent_means = Array2::zeros((self.config.n_components, latent_dim));
734        for k in 0..self.config.n_components {
735            for d in 0..latent_dim {
736                let normal = RandNormal::new(0.0, 1.0).map_err(|e| {
737                    SklearsError::InvalidInput(format!("Normal distribution error: {}", e))
738                })?;
739                shared_latent_means[[k, d]] = rng.sample(normal);
740            }
741        }
742
743        let mut log_likelihood_history = Vec::new();
744        let mut coupling_parameters =
745            Array2::zeros((self.config.modalities.len(), self.config.modalities.len()));
746
747        // Initialize coupling parameters
748        for i in 0..self.config.modalities.len() {
749            coupling_parameters[[i, i]] = 1.0;
750            for j in (i + 1)..self.config.modalities.len() {
751                coupling_parameters[[i, j]] = self.config.coupling_strength;
752                coupling_parameters[[j, i]] = self.config.coupling_strength;
753            }
754        }
755
756        // Simplified intermediate fusion (would need more sophisticated implementation for production)
757        log_likelihood_history.push(0.0);
758
759        Ok(MultiModalGaussianMixtureTrained {
760            weights,
761            modality_means,
762            modality_covariances,
763            shared_latent_means: Some(shared_latent_means),
764            latent_projections,
765            coupling_parameters,
766            log_likelihood_history,
767            n_iter: 1,
768            config: self.config.clone(),
769        })
770    }
771
772    /// Coupled fusion: joint optimization across modalities
773    fn fit_coupled_fusion(
774        &self,
775        X: &HashMap<String, Array2<f64>>,
776        _y: &Option<Array1<usize>>,
777    ) -> SklResult<MultiModalGaussianMixtureTrained> {
778        let n_samples = X.values().next().unwrap().nrows();
779        let (mut weights, mut modality_means, mut modality_covariances) =
780            self.initialize_parameters(X)?;
781        let mut log_likelihood_history = Vec::new();
782
783        // Initialize coupling parameters between modalities
784        let n_modalities = self.config.modalities.len();
785        let mut coupling_parameters = Array2::zeros((n_modalities, n_modalities));
786
787        for i in 0..n_modalities {
788            coupling_parameters[[i, i]] = 1.0;
789            for j in (i + 1)..n_modalities {
790                coupling_parameters[[i, j]] = self.config.coupling_strength;
791                coupling_parameters[[j, i]] = self.config.coupling_strength;
792            }
793        }
794
795        // Coupled EM algorithm
796        for iter in 0..self.config.max_iter {
797            let old_log_likelihood = if log_likelihood_history.is_empty() {
798                f64::NEG_INFINITY
799            } else {
800                *log_likelihood_history.last().unwrap()
801            };
802
803            let mut total_log_likelihood = 0.0;
804            let mut global_responsibilities = Array2::zeros((n_samples, self.config.n_components));
805
806            // E-step: Compute responsibilities for each modality and combine
807            for (modality_idx, modality) in self.config.modalities.iter().enumerate() {
808                let data = &X[&modality.name];
809                let means = &modality_means[&modality.name];
810                let covariances = &modality_covariances[&modality.name];
811                let mut modality_responsibilities =
812                    Array2::zeros((n_samples, self.config.n_components));
813
814                for i in 0..n_samples {
815                    let sample = data.row(i);
816                    let mut log_probs = Array1::zeros(self.config.n_components);
817
818                    for k in 0..self.config.n_components {
819                        let mean = means.row(k);
820                        let diff = &sample.to_owned() - &mean.to_owned();
821
822                        let (log_det, inv_quad) = match modality.covariance_type {
823                            CovarianceType::Spherical => {
824                                let variance = covariances[[k, 0, 0]];
825                                let log_det = modality.n_features as f64 * variance.ln();
826                                let inv_quad = diff.dot(&diff) / variance;
827                                (log_det, inv_quad)
828                            }
829                            _ => {
830                                // Simplified diagonal covariance
831                                let log_det = (0..modality.n_features)
832                                    .map(|j| {
833                                        covariances[[k, j, 0.min(covariances.dim().2 - 1)]].ln()
834                                    })
835                                    .sum::<f64>();
836                                let inv_quad = diff.dot(&diff); // Simplified
837                                (log_det, inv_quad)
838                            }
839                        };
840
841                        log_probs[k] = weights[k].ln()
842                            - 0.5
843                                * (modality.n_features as f64 * (2.0 * std::f64::consts::PI).ln()
844                                    + log_det
845                                    + inv_quad);
846                    }
847
848                    // Numerical stability
849                    let max_log_prob = log_probs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
850                    let log_sum_exp =
851                        (log_probs.mapv(|x| (x - max_log_prob).exp()).sum()).ln() + max_log_prob;
852                    total_log_likelihood += log_sum_exp * modality.modality_weight;
853
854                    for k in 0..self.config.n_components {
855                        modality_responsibilities[[i, k]] =
856                            ((log_probs[k] - log_sum_exp).exp()).max(1e-15);
857                    }
858                }
859
860                // Combine with coupling
861                for i in 0..n_samples {
862                    for k in 0..self.config.n_components {
863                        global_responsibilities[[i, k]] += modality.modality_weight
864                            * coupling_parameters[[modality_idx, modality_idx]]
865                            * modality_responsibilities[[i, k]];
866                    }
867                }
868            }
869
870            // Normalize global responsibilities
871            for i in 0..n_samples {
872                let sum: f64 = global_responsibilities.row(i).sum();
873                if sum > 1e-15 {
874                    global_responsibilities.row_mut(i).mapv_inplace(|x| x / sum);
875                }
876            }
877
878            log_likelihood_history.push(total_log_likelihood);
879
880            // Check convergence
881            if iter > 0 && (total_log_likelihood - old_log_likelihood).abs() < self.config.tol {
882                break;
883            }
884
885            // M-step: Update parameters using global responsibilities
886            let n_k: Array1<f64> = global_responsibilities.sum_axis(Axis(0));
887            weights = &n_k / n_samples as f64;
888
889            // Update means and covariances for each modality
890            for modality in &self.config.modalities {
891                let data = &X[&modality.name];
892                let mut means = modality_means[&modality.name].clone();
893                let mut covariances = modality_covariances[&modality.name].clone();
894
895                // Update means
896                for k in 0..self.config.n_components {
897                    if n_k[k] > 1e-15 {
898                        let weighted_sum = global_responsibilities
899                            .column(k)
900                            .iter()
901                            .enumerate()
902                            .fold(Array1::zeros(modality.n_features), |mut acc, (i, &resp)| {
903                                let sample = data.row(i);
904                                for j in 0..modality.n_features {
905                                    acc[j] += resp * sample[j];
906                                }
907                                acc
908                            });
909                        means.row_mut(k).assign(&(weighted_sum / n_k[k]));
910                    }
911                }
912
913                // Update covariances
914                match modality.covariance_type {
915                    CovarianceType::Spherical => {
916                        for k in 0..self.config.n_components {
917                            if n_k[k] > 1e-15 {
918                                let mut weighted_var = 0.0;
919                                for i in 0..n_samples {
920                                    let sample = data.row(i);
921                                    let mean = means.row(k);
922                                    let diff = &sample.to_owned() - &mean.to_owned();
923                                    weighted_var +=
924                                        global_responsibilities[[i, k]] * diff.dot(&diff);
925                                }
926                                covariances[[k, 0, 0]] = (weighted_var
927                                    / (n_k[k] * modality.n_features as f64)
928                                    + self.config.regularization_strength)
929                                    .max(1e-6);
930                            }
931                        }
932                    }
933                    _ => {
934                        for k in 0..self.config.n_components {
935                            if n_k[k] > 1e-15 {
936                                for j in 0..modality.n_features {
937                                    let mut weighted_var = 0.0;
938                                    for i in 0..n_samples {
939                                        let diff = data[[i, j]] - means[[k, j]];
940                                        weighted_var +=
941                                            global_responsibilities[[i, k]] * diff * diff;
942                                    }
943                                    let var_idx = match modality.covariance_type {
944                                        CovarianceType::Diagonal => (k, j, 0),
945                                        _ => (k, j, j),
946                                    };
947                                    covariances[var_idx] = (weighted_var / n_k[k]
948                                        + self.config.regularization_strength)
949                                        .max(1e-6);
950                                }
951                            }
952                        }
953                    }
954                }
955
956                modality_means.insert(modality.name.clone(), means);
957                modality_covariances.insert(modality.name.clone(), covariances);
958            }
959        }
960
961        let n_iter = log_likelihood_history.len();
962        Ok(MultiModalGaussianMixtureTrained {
963            weights,
964            modality_means,
965            modality_covariances,
966            shared_latent_means: None,
967            latent_projections: HashMap::new(),
968            coupling_parameters,
969            log_likelihood_history,
970            n_iter,
971            config: self.config.clone(),
972        })
973    }
974}
975
976impl Predict<HashMap<String, Array2<f64>>, Array1<usize>> for MultiModalGaussianMixtureTrained {
977    fn predict(&self, X: &HashMap<String, Array2<f64>>) -> SklResult<Array1<usize>> {
978        let probabilities = self.predict_proba(X)?;
979        let n_samples = probabilities.nrows();
980        let mut predictions = Array1::zeros(n_samples);
981
982        for i in 0..n_samples {
983            let mut max_prob = 0.0;
984            let mut best_component = 0;
985
986            for k in 0..self.config.n_components {
987                if probabilities[[i, k]] > max_prob {
988                    max_prob = probabilities[[i, k]];
989                    best_component = k;
990                }
991            }
992            predictions[i] = best_component;
993        }
994
995        Ok(predictions)
996    }
997}
998
impl MultiModalGaussianMixtureTrained {
    /// Predict per-component probabilities for multi-modal data.
    ///
    /// Validates that every configured modality is present in `X`, then
    /// dispatches on the trained fusion strategy to fill an
    /// `(n_samples, n_components)` probability matrix.
    ///
    /// # Errors
    /// Returns `InvalidInput` if any configured modality is missing from `X`.
    pub fn predict_proba(&self, X: &HashMap<String, Array2<f64>>) -> SklResult<Array2<f64>> {
        // Validate input: every configured modality must have data.
        for modality in &self.config.modalities {
            if !X.contains_key(&modality.name) {
                return Err(SklearsError::InvalidInput(format!(
                    "Missing data for modality: {}",
                    modality.name
                )));
            }
        }

        // Sample count is taken from an arbitrary modality; per-modality row
        // counts are assumed equal here (checked at fit time, not re-checked).
        let n_samples = X.values().next().unwrap().nrows();
        let mut probabilities = Array2::zeros((n_samples, self.config.n_components));

        match self.config.fusion_strategy {
            FusionStrategy::EarlyFusion => {
                // Concatenate features and compute probabilities
                self.predict_proba_early_fusion(X, &mut probabilities)?;
            }
            FusionStrategy::LateFusion => {
                // Average probabilities from each modality
                self.predict_proba_late_fusion(X, &mut probabilities)?;
            }
            FusionStrategy::IntermediateFusion => {
                // Use shared latent representation
                self.predict_proba_intermediate_fusion(X, &mut probabilities)?;
            }
            FusionStrategy::CoupledFusion => {
                // Joint prediction across modalities
                self.predict_proba_coupled_fusion(X, &mut probabilities)?;
            }
        }

        Ok(probabilities)
    }

    /// Early-fusion prediction — PLACEHOLDER implementation.
    ///
    /// Ignores `X` entirely and assigns the trained mixture weights (the
    /// component priors) to every sample. A real implementation would
    /// concatenate the modality features and evaluate the GMM densities.
    fn predict_proba_early_fusion(
        &self,
        _X: &HashMap<String, Array2<f64>>,
        probabilities: &mut Array2<f64>,
    ) -> SklResult<()> {
        // This would concatenate features and compute GMM probabilities
        // For now, simplified implementation
        let n_samples = probabilities.nrows();
        for i in 0..n_samples {
            probabilities.row_mut(i).assign(&self.weights);
        }
        Ok(())
    }

    /// Late-fusion prediction: per-modality posteriors, averaged with
    /// `modality_weight`, then renormalized row-wise.
    ///
    /// NOTE(review): the per-modality log density uses `-0.5 * ||x - mu||^2`
    /// (unit covariance, no normalizing constant) rather than the trained
    /// covariances — presumably a deliberate simplification; confirm.
    fn predict_proba_late_fusion(
        &self,
        X: &HashMap<String, Array2<f64>>,
        probabilities: &mut Array2<f64>,
    ) -> SklResult<()> {
        let n_samples = probabilities.nrows();
        probabilities.fill(0.0);

        // Average predictions from each modality
        for modality in &self.config.modalities {
            let data = &X[&modality.name];
            let means = &self.modality_means[&modality.name];

            for i in 0..n_samples {
                let sample = data.row(i);
                let mut modality_probs = Array1::zeros(self.config.n_components);

                for k in 0..self.config.n_components {
                    let mean = means.row(k);
                    let diff = &sample.to_owned() - &mean.to_owned();
                    // Prior-weighted isotropic Gaussian kernel (see NOTE above).
                    let log_prob = self.weights[k].ln() - 0.5 * diff.dot(&diff);
                    modality_probs[k] = log_prob.exp();
                }

                // Normalize this modality's posterior over components.
                let sum: f64 = modality_probs.sum();
                if sum > 1e-15 {
                    modality_probs.mapv_inplace(|x| x / sum);
                }

                // Add weighted contribution
                for k in 0..self.config.n_components {
                    probabilities[[i, k]] += modality.modality_weight * modality_probs[k];
                }
            }
        }

        // Final normalization so each row sums to 1 (skip all-zero rows).
        for i in 0..n_samples {
            let sum: f64 = probabilities.row(i).sum();
            if sum > 1e-15 {
                probabilities.row_mut(i).mapv_inplace(|x| x / sum);
            }
        }

        Ok(())
    }

    /// Intermediate-fusion prediction — PLACEHOLDER implementation.
    ///
    /// Ignores `X` and assigns the component priors to every row, like
    /// `predict_proba_early_fusion`. A real implementation would project
    /// through `latent_projections` into the shared latent space.
    fn predict_proba_intermediate_fusion(
        &self,
        _X: &HashMap<String, Array2<f64>>,
        probabilities: &mut Array2<f64>,
    ) -> SklResult<()> {
        // Project to latent space and compute probabilities
        // Simplified implementation
        let n_samples = probabilities.nrows();
        for i in 0..n_samples {
            probabilities.row_mut(i).assign(&self.weights);
        }
        Ok(())
    }

    /// Coupled-fusion prediction: late-fusion posteriors with a coupling
    /// adjustment.
    ///
    /// NOTE(review): multiplying every entry of a row by the same constant
    /// `(1 + coupling_strength)` and then renormalizing is mathematically a
    /// no-op — the result equals plain late fusion. The transformation is
    /// labeled "simplified"; confirm whether a per-component coupling was
    /// intended.
    fn predict_proba_coupled_fusion(
        &self,
        X: &HashMap<String, Array2<f64>>,
        probabilities: &mut Array2<f64>,
    ) -> SklResult<()> {
        // Use coupling parameters to combine modality predictions
        self.predict_proba_late_fusion(X, probabilities)?;

        // Apply coupling transformations (simplified)
        let n_samples = probabilities.nrows();
        for i in 0..n_samples {
            for k in 0..self.config.n_components {
                probabilities[[i, k]] *= 1.0 + self.config.coupling_strength;
            }

            // Renormalize
            let sum: f64 = probabilities.row(i).sum();
            if sum > 1e-15 {
                probabilities.row_mut(i).mapv_inplace(|x| x / sum);
            }
        }

        Ok(())
    }

    /// Compute a log-likelihood-style score of the data.
    ///
    /// NOTE(review): this sums `ln` of the *posterior* probabilities returned
    /// by `predict_proba` (which sum to 1 per row), not the mixture density of
    /// the data — so it is an entropy-like score rather than a true data
    /// log-likelihood. Values are floored at 1e-15 before `ln` to avoid -inf.
    pub fn score(&self, X: &HashMap<String, Array2<f64>>) -> SklResult<f64> {
        let probabilities = self.predict_proba(X)?;
        let log_likelihood = probabilities.mapv(|p| p.max(1e-15).ln()).sum();
        Ok(log_likelihood)
    }

    /// Compute AIC/BIC model-selection criteria from the score.
    ///
    /// Parameter counting is simplified: it treats the model as a single
    /// full-covariance GMM over the concatenated feature space, regardless of
    /// the per-modality covariance types actually configured.
    pub fn model_selection(&self, X: &HashMap<String, Array2<f64>>) -> SklResult<ModelSelection> {
        let n_samples = X.values().next().unwrap().nrows();
        let total_features: usize = self.config.modalities.iter().map(|m| m.n_features).sum();

        // Simplified parameter counting
        let n_parameters = ModelSelection::n_parameters(
            self.config.n_components,
            total_features,
            &CovarianceType::Full,
        );

        let log_likelihood = self.score(X)?;
        let aic = ModelSelection::aic(log_likelihood, n_parameters);
        let bic = ModelSelection::bic(log_likelihood, n_parameters, n_samples);

        Ok(ModelSelection {
            aic,
            bic,
            log_likelihood,
            n_parameters,
        })
    }
}
1169
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Deterministic two-modality fixture: a 100x2 "visual" view (linear ramp,
    /// step 0.1 over the flat index) and a 100x3 "textual" view (sine of
    /// 0.05 times the flat index).
    fn sample_modalities() -> HashMap<String, Array2<f64>> {
        let mut views = HashMap::new();

        let visual = Array2::from_shape_fn((100, 2), |(r, c)| (r * 2 + c) as f64 * 0.1);
        views.insert("visual".to_string(), visual);

        let textual = Array2::from_shape_fn((100, 3), |(r, c)| ((r * 3 + c) as f64 * 0.05).sin());
        views.insert("textual".to_string(), textual);

        views
    }

    #[test]
    fn test_multi_modal_builder() {
        // Builder should carry every configured option into the model config.
        let mixture = MultiModalGaussianMixtureBuilder::new(3)
            .add_modality_simple("visual", 2)
            .add_modality_simple("textual", 3)
            .fusion_strategy(FusionStrategy::EarlyFusion)
            .coupling_strength(0.2)
            .max_iter(10)
            .build()
            .unwrap();

        assert_eq!(mixture.config.n_components, 3);
        assert_eq!(mixture.config.modalities.len(), 2);
        assert_eq!(mixture.config.fusion_strategy, FusionStrategy::EarlyFusion);
        assert_abs_diff_eq!(mixture.config.coupling_strength, 0.2, epsilon = 1e-10);
    }

    #[test]
    fn test_early_fusion_fit() {
        let views = sample_modalities();
        let mixture = MultiModalGaussianMixtureBuilder::new(2)
            .add_modality_simple("visual", 2)
            .add_modality_simple("textual", 3)
            .fusion_strategy(FusionStrategy::EarlyFusion)
            .max_iter(5)
            .build()
            .unwrap();

        let fitted = mixture.fit(&views, &None).unwrap();

        // Per-modality means must exist with shape (n_components, n_features).
        assert_eq!(fitted.weights.len(), 2);
        assert!(fitted.modality_means.contains_key("visual"));
        assert!(fitted.modality_means.contains_key("textual"));
        assert_eq!(fitted.modality_means["visual"].nrows(), 2); // n_components
        assert_eq!(fitted.modality_means["visual"].ncols(), 2); // n_features
    }

    #[test]
    fn test_late_fusion_fit() {
        let views = sample_modalities();
        let mixture = MultiModalGaussianMixtureBuilder::new(2)
            .add_modality_simple("visual", 2)
            .add_modality_simple("textual", 3)
            .fusion_strategy(FusionStrategy::LateFusion)
            .max_iter(5)
            .build()
            .unwrap();

        let fitted = mixture.fit(&views, &None).unwrap();

        assert_eq!(fitted.weights.len(), 2);
        assert!(fitted.modality_means.contains_key("visual"));
        assert!(fitted.modality_means.contains_key("textual"));
    }

    #[test]
    fn test_intermediate_fusion_fit() {
        let views = sample_modalities();
        let mixture = MultiModalGaussianMixtureBuilder::new(2)
            .add_modality_simple("visual", 2)
            .add_modality_simple("textual", 3)
            .fusion_strategy(FusionStrategy::IntermediateFusion)
            .shared_latent_dim(4)
            .max_iter(5)
            .build()
            .unwrap();

        let fitted = mixture.fit(&views, &None).unwrap();

        // Intermediate fusion must populate the shared latent means with
        // shape (n_components, latent_dim).
        assert_eq!(fitted.weights.len(), 2);
        assert!(fitted.shared_latent_means.is_some());
        let latent = fitted.shared_latent_means.as_ref().unwrap();
        assert_eq!(latent.nrows(), 2); // n_components
        assert_eq!(latent.ncols(), 4); // latent_dim
    }

    #[test]
    fn test_coupled_fusion_fit() {
        let views = sample_modalities();
        let mixture = MultiModalGaussianMixtureBuilder::new(2)
            .add_modality_simple("visual", 2)
            .add_modality_simple("textual", 3)
            .fusion_strategy(FusionStrategy::CoupledFusion)
            .coupling_strength(0.3)
            .max_iter(5)
            .build()
            .unwrap();

        let fitted = mixture.fit(&views, &None).unwrap();

        // Coupling matrix is square in the number of modalities.
        assert_eq!(fitted.weights.len(), 2);
        assert_eq!(fitted.coupling_parameters.nrows(), 2); // n_modalities
        assert_eq!(fitted.coupling_parameters.ncols(), 2); // n_modalities
    }

    #[test]
    fn test_prediction() {
        let views = sample_modalities();
        let mixture = MultiModalGaussianMixtureBuilder::new(2)
            .add_modality_simple("visual", 2)
            .add_modality_simple("textual", 3)
            .fusion_strategy(FusionStrategy::LateFusion)
            .max_iter(3)
            .build()
            .unwrap();

        let fitted = mixture.fit(&views, &None).unwrap();
        let labels = fitted.predict(&views).unwrap();

        assert_eq!(labels.len(), 100);

        // With 2 components every label is 0 or 1.
        assert!(labels.iter().all(|&component| component < 2));
    }

    #[test]
    fn test_predict_proba() {
        let views = sample_modalities();
        let mixture = MultiModalGaussianMixtureBuilder::new(3)
            .add_modality_simple("visual", 2)
            .add_modality_simple("textual", 3)
            .fusion_strategy(FusionStrategy::LateFusion)
            .max_iter(3)
            .build()
            .unwrap();

        let fitted = mixture.fit(&views, &None).unwrap();
        let proba = fitted.predict_proba(&views).unwrap();

        assert_eq!(proba.nrows(), 100);
        assert_eq!(proba.ncols(), 3);

        // Each row of the posterior matrix is a probability distribution.
        for row in 0..100 {
            let total: f64 = proba.row(row).sum();
            assert_abs_diff_eq!(total, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_model_selection() {
        let views = sample_modalities();
        let mixture = MultiModalGaussianMixtureBuilder::new(2)
            .add_modality_simple("visual", 2)
            .add_modality_simple("textual", 3)
            .fusion_strategy(FusionStrategy::EarlyFusion)
            .max_iter(3)
            .build()
            .unwrap();

        let fitted = mixture.fit(&views, &None).unwrap();
        let criteria = fitted.model_selection(&views).unwrap();

        // All criteria should be finite with a positive parameter count.
        assert!(criteria.log_likelihood.is_finite());
        assert!(criteria.aic.is_finite());
        assert!(criteria.bic.is_finite());
        assert!(criteria.n_parameters > 0);
    }

    #[test]
    fn test_validation_missing_modality() {
        let mut views = sample_modalities();
        views.remove("textual"); // Drop one configured modality from the input.

        let mixture = MultiModalGaussianMixtureBuilder::new(2)
            .add_modality_simple("visual", 2)
            .add_modality_simple("textual", 3)
            .fusion_strategy(FusionStrategy::EarlyFusion) // Early fusion avoids the latent-dim requirement.
            .build()
            .unwrap();

        // Fitting with a missing modality must fail.
        assert!(mixture.fit(&views, &None).is_err());
    }

    #[test]
    fn test_validation_feature_dimension_mismatch() {
        let views = sample_modalities();

        let mixture = MultiModalGaussianMixtureBuilder::new(2)
            .add_modality_simple("visual", 5) // Declared width disagrees with the 2-column data.
            .add_modality_simple("textual", 3)
            .fusion_strategy(FusionStrategy::EarlyFusion) // Early fusion avoids the latent-dim requirement.
            .build()
            .unwrap();

        // Fitting with a dimension mismatch must fail.
        assert!(mixture.fit(&views, &None).is_err());
    }

    #[test]
    fn test_intermediate_fusion_requires_latent_dim() {
        // Intermediate fusion without shared_latent_dim must be rejected at build time.
        let outcome = MultiModalGaussianMixtureBuilder::new(2)
            .add_modality_simple("visual", 2)
            .fusion_strategy(FusionStrategy::IntermediateFusion)
            .build();

        assert!(outcome.is_err());
    }
}