Skip to main content

scirs2_neural/models/architectures/
fusion.rs

//! Feature Fusion Model Architectures
//!
//! This module implements various feature fusion approaches for multi-modal learning,
//! allowing models to combine features from different modalities (e.g., vision, text, audio).
5
6use crate::error::{NeuralError, Result};
7use crate::layers::{Dense, Dropout, Layer, LayerNorm, Sequential};
8use scirs2_core::ndarray::{Array, Axis, IxDyn, ScalarOperand};
9use scirs2_core::numeric::{Float, NumAssign};
10use scirs2_core::random::SeedableRng;
11use scirs2_core::simd_ops::SimdUnifiedOps;
12use serde::{Deserialize, Serialize};
13use std::fmt::Debug;
14/// Fusion methods for multi-modal inputs
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
16pub enum FusionMethod {
17    /// Concatenate features from different modalities
18    Concatenation,
19    /// Element-wise sum of features (requires same dimensions)
20    Sum,
21    /// Element-wise product of features (requires same dimensions)
22    Product,
23    /// Gated attention mechanism between modalities
24    Attention,
25    /// Bilinear fusion (outer product)
26    Bilinear,
27    /// FiLM conditioning (Feature-wise Linear Modulation)
28    FiLM,
29}
30/// Configuration for the Feature Fusion model
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct FeatureFusionConfig {
33    /// Dimensions of each input modality
34    pub input_dims: Vec<usize>,
35    /// Hidden dimension for alignment (if needed)
36    pub hidden_dim: usize,
37    /// Fusion method to use
38    pub fusion_method: FusionMethod,
39    /// Dropout rate
40    pub dropout_rate: f64,
41    /// Number of output classes (if applicable)
42    pub num_classes: usize,
43    /// Whether to include the classifier head
44    pub include_head: bool,
45}
46
47/// Feature alignment module
48#[derive(Debug, Clone)]
49pub struct FeatureAlignment<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign>
50where
51    F: SimdUnifiedOps,
52{
53    /// Input dimension
54    pub input_dim: usize,
55    /// Output dimension for alignment
56    pub output_dim: usize,
57    /// Linear projection layer
58    pub projection: Dense<F>,
59    /// Normalization layer
60    pub norm: LayerNorm<F>,
61}
62
63impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> FeatureAlignment<F>
64where
65    F: SimdUnifiedOps,
66{
67    /// Create a new FeatureAlignment module
68    pub fn new(input_dim: usize, output_dim: usize, _name: Option<&str>) -> Result<Self> {
69        let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
70        let projection = Dense::<F>::new(input_dim, output_dim, None, &mut rng)?;
71        let norm = LayerNorm::<F>::new(output_dim, 1e-6, &mut rng)?;
72        Ok(Self {
73            input_dim,
74            output_dim,
75            projection,
76            norm,
77        })
78    }
79}
80
81impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for FeatureAlignment<F>
82where
83    F: SimdUnifiedOps,
84{
85    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
86        let x = self.projection.forward(input)?;
87        let x = self.norm.forward(&x)?;
88        Ok(x)
89    }
90
91    fn as_any(&self) -> &dyn std::any::Any {
92        self
93    }
94
95    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
96        self
97    }
98
99    fn backward(
100        &self,
101        input: &Array<F, IxDyn>,
102        grad_output: &Array<F, IxDyn>,
103    ) -> Result<Array<F, IxDyn>> {
104        // Backward pass through the alignment layer (Dense -> LayerNorm)
105        // First, get the intermediate output from the projection
106        let proj_output = self.projection.forward(input)?;
107        // Backward through LayerNorm
108        let grad_proj = self.norm.backward(&proj_output, grad_output)?;
109        // Backward through Dense projection
110        let grad_input = self.projection.backward(input, &grad_proj)?;
111        Ok(grad_input)
112    }
113
114    fn update(&mut self, learning_rate: F) -> Result<()> {
115        // Update the Dense projection layer
116        self.projection.update(learning_rate)?;
117        // Update the LayerNorm layer
118        self.norm.update(learning_rate)?;
119        Ok(())
120    }
121
122    fn params(&self) -> Vec<Array<F, IxDyn>> {
123        let mut params = Vec::new();
124        params.extend(self.projection.params());
125        params.extend(self.norm.params());
126        params
127    }
128
129    fn set_training(&mut self, training: bool) {
130        self.projection.set_training(training);
131        self.norm.set_training(training);
132    }
133
134    fn is_training(&self) -> bool {
135        self.projection.is_training()
136    }
137}
138
139/// Cross-Modal Attention module
140#[derive(Debug, Clone)]
141pub struct CrossModalAttention<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
142    /// Query projection
143    pub query_proj: Dense<F>,
144    /// Key projection
145    pub key_proj: Dense<F>,
146    /// Value projection
147    pub value_proj: Dense<F>,
148    /// Output projection
149    pub output_proj: Dense<F>,
150    /// Hidden dimension
151    pub hidden_dim: usize,
152    /// Scale factor for attention
153    pub scale: F,
154}
155
156impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> CrossModalAttention<F> {
157    /// Create a new CrossModalAttention module
158    pub fn new(query_dim: usize, key_dim: usize, hidden_dim: usize) -> Result<Self> {
159        let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
160        let query_proj = Dense::<F>::new(query_dim, hidden_dim, None, &mut rng)?;
161        let key_proj = Dense::<F>::new(key_dim, hidden_dim, None, &mut rng)?;
162        let value_proj = Dense::<F>::new(key_dim, hidden_dim, None, &mut rng)?;
163        let output_proj = Dense::<F>::new(hidden_dim, query_dim, None, &mut rng)?;
164        // Scale factor for dot product attention
165        let scale = F::from(1.0 / (hidden_dim as f64).sqrt()).expect("Operation failed");
166        Ok(Self {
167            query_proj,
168            key_proj,
169            value_proj,
170            output_proj,
171            hidden_dim,
172            scale,
173        })
174    }
175
176    /// Forward pass for cross-modal attention
177    pub fn forward(
178        &self,
179        query: &Array<F, IxDyn>,
180        context: &Array<F, IxDyn>,
181    ) -> Result<Array<F, IxDyn>> {
182        // Project query, key, and value
183        let q = self.query_proj.forward(query)?;
184        let k = self.key_proj.forward(context)?;
185        let v = self.value_proj.forward(context)?;
186        // Reshape for easier computation
187        let batch_size = q.shape()[0];
188        let query_len = q.shape()[1];
189        let context_len = k.shape()[1];
190        let q_2d = q
191            .clone()
192            .into_shape_with_order((batch_size * query_len, self.hidden_dim))?;
193        let k_2d = k.into_shape_with_order((batch_size * context_len, self.hidden_dim))?;
194        let v_2d = v.into_shape_with_order((batch_size * context_len, self.hidden_dim))?;
195
196        // Compute attention scores
197        let scores = q_2d.dot(&k_2d.t()) * self.scale;
198        // Reshape scores to (batch_size, query_len, context_len)
199        let scores_3d = scores.into_shape_with_order((batch_size, query_len, context_len))?;
200        // Apply softmax along the context dimension
201        let mut attention_weights = scores_3d.to_owned().into_dyn();
202        attention_weights.fill(F::zero());
203        for b in 0..batch_size {
204            for q in 0..query_len {
205                let mut row = scores_3d
206                    .slice(scirs2_core::ndarray::s![b, q, ..])
207                    .to_owned();
208                // Find max for numerical stability
209                let max_val = row.fold(F::neg_infinity(), |m: F, &v: &F| m.max(v));
210                // Compute exp and sum
211                let mut exp_sum = F::zero();
212                for i in 0..context_len {
213                    let exp_val = (row[i] - max_val).exp();
214                    row[i] = exp_val;
215                    exp_sum += exp_val;
216                }
217                // Normalize
218                if exp_sum > F::zero() {
219                    for i in 0..context_len {
220                        row[i] /= exp_sum;
221                    }
222                }
223                // Copy normalized weights
224                for i in 0..context_len {
225                    attention_weights[[b, q, i]] = row[i];
226                }
227            }
228        }
229
230        // Reshape attention weights for matrix multiplication
231        let attn_weights_2d = attention_weights
232            .into_shape_with_order((batch_size * query_len, batch_size * context_len))?;
233        // Apply attention weights to values
234        let context_vec = attn_weights_2d.dot(&v_2d);
235        // Reshape and project output
236        let context_vec_reshaped =
237            context_vec.into_shape_with_order((batch_size, query_len, self.hidden_dim))?;
238        // Final projection
239        let output = self.output_proj.forward(&context_vec_reshaped.into_dyn())?;
240        Ok(output)
241    }
242}
243
244impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F>
245    for CrossModalAttention<F>
246{
247    fn forward(&self, _input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
248        // This assumes the input contains both query and context packed together
249        // In practical use, use the dedicated forward method with separate inputs
250        Err(NeuralError::ValidationError(
251            "CrossModalAttention requires separate query and context inputs. Use the dedicated forward method."
252                .to_string(),
253        ))
254    }
255
256    fn as_any(&self) -> &dyn std::any::Any {
257        self
258    }
259
260    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
261        self
262    }
263
264    fn backward(
265        &self,
266        _input: &Array<F, IxDyn>,
267        grad_output: &Array<F, IxDyn>,
268    ) -> Result<Array<F, IxDyn>> {
269        // For CrossModalAttention, the backward pass is complex because it involves
270        // two separate inputs (query and context). Since the Layer trait only provides
271        // one input, we cannot properly implement backward for the general case.
272        // This would require a custom backward method that takes both query and context.
273        // For now, we return a gradient with the same shape as the expected query input.
274        // Create a gradient tensor with appropriate shape
275        // This is a simplified implementation - a proper implementation would need
276        // to propagate gradients through the attention mechanism
277        Ok(grad_output.clone())
278    }
279
280    fn update(&mut self, learning_rate: F) -> Result<()> {
281        // Update all projection layers
282        self.query_proj.update(learning_rate)?;
283        self.key_proj.update(learning_rate)?;
284        self.value_proj.update(learning_rate)?;
285        self.output_proj.update(learning_rate)?;
286        Ok(())
287    }
288
289    fn params(&self) -> Vec<Array<F, IxDyn>> {
290        let mut params = Vec::new();
291        params.extend(self.query_proj.params());
292        params.extend(self.key_proj.params());
293        params.extend(self.value_proj.params());
294        params.extend(self.output_proj.params());
295        params
296    }
297
298    fn set_training(&mut self, training: bool) {
299        self.query_proj.set_training(training);
300        self.key_proj.set_training(training);
301        self.value_proj.set_training(training);
302        self.output_proj.set_training(training);
303    }
304
305    fn is_training(&self) -> bool {
306        self.query_proj.is_training()
307    }
308}
309
310/// FiLM (Feature-wise Linear Modulation) conditioning module
311#[derive(Debug, Clone)]
312pub struct FiLMModule<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
313    /// Feature dimension to be modulated
314    pub feature_dim: usize,
315    /// Conditioning input dimension
316    pub cond_dim: usize,
317    /// Gamma (scale) projection
318    pub gamma_proj: Dense<F>,
319    /// Beta (shift) projection
320    pub beta_proj: Dense<F>,
321}
322
323impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> FiLMModule<F> {
324    /// Create a new FiLMModule
325    pub fn new(feature_dim: usize, cond_dim: usize) -> Result<Self> {
326        let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
327        let gamma_proj = Dense::<F>::new(cond_dim, feature_dim, None, &mut rng)?;
328        let beta_proj = Dense::<F>::new(cond_dim, feature_dim, None, &mut rng)?;
329        Ok(Self {
330            feature_dim,
331            cond_dim,
332            gamma_proj,
333            beta_proj,
334        })
335    }
336
337    /// Forward pass with separate feature and conditioning inputs
338    pub fn forward(
339        &self,
340        features: &Array<F, IxDyn>,
341        conditioning: &Array<F, IxDyn>,
342    ) -> Result<Array<F, IxDyn>> {
343        // Generate gamma and beta for modulation
344        let gamma = self.gamma_proj.forward(conditioning)?;
345        let beta = self.beta_proj.forward(conditioning)?;
346        // Apply FiLM: gamma * features + beta
347        let modulated = &gamma * features + &beta;
348        Ok(modulated)
349    }
350}
351
352impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for FiLMModule<F> {
353    fn forward(&self, _input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
354        // This assumes the input contains both features and conditioning packed together
355        Err(NeuralError::ValidationError(
356            "FiLMModule requires separate feature and conditioning inputs. Use the dedicated forward method."
357                .to_string(),
358        ))
359    }
360
361    fn as_any(&self) -> &dyn std::any::Any {
362        self
363    }
364
365    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
366        self
367    }
368
369    fn backward(
370        &self,
371        _input: &Array<F, IxDyn>,
372        grad_output: &Array<F, IxDyn>,
373    ) -> Result<Array<F, IxDyn>> {
374        // For FiLMModule, the backward pass is complex because it involves
375        // two separate inputs (features and conditioning). Since the Layer trait only provides
376        // This would require a custom backward method that takes both inputs.
377        // For now, we return a gradient with the same shape as the expected feature input.
378        // to propagate gradients through the FiLM operation (gamma * features + beta)
379        Ok(grad_output.clone())
380    }
381
382    fn update(&mut self, learning_rate: F) -> Result<()> {
383        // Update gamma and beta projection layers
384        self.gamma_proj.update(learning_rate)?;
385        self.beta_proj.update(learning_rate)?;
386        Ok(())
387    }
388
389    fn params(&self) -> Vec<Array<F, IxDyn>> {
390        let mut params = Vec::new();
391        params.extend(self.gamma_proj.params());
392        params.extend(self.beta_proj.params());
393        params
394    }
395
396    fn set_training(&mut self, training: bool) {
397        self.gamma_proj.set_training(training);
398        self.beta_proj.set_training(training);
399    }
400
401    fn is_training(&self) -> bool {
402        self.gamma_proj.is_training()
403    }
404}
405
406/// Bilinear Fusion module for pairwise interactions between modalities
407#[derive(Debug, Clone)]
408pub struct BilinearFusion<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
409    /// First modality dimension
410    pub dim_a: usize,
411    /// Second modality dimension
412    pub dim_b: usize,
413    /// Output dimension
414    pub output_dim: usize,
415    /// Projection from A
416    pub proj_a: Dense<F>,
417    /// Projection from B
418    pub proj_b: Dense<F>,
419    /// Low-rank projection to output
420    pub low_rank_proj: Dense<F>,
421}
422
423impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> BilinearFusion<F> {
424    /// Create a new BilinearFusion module
425    pub fn new(dim_a: usize, dim_b: usize, output_dim: usize, rank: usize) -> Result<Self> {
426        let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
427        let proj_a = Dense::<F>::new(dim_a, rank, None, &mut rng)?;
428        let proj_b = Dense::<F>::new(dim_b, rank, None, &mut rng)?;
429        let low_rank_proj = Dense::<F>::new(rank, output_dim, None, &mut rng)?;
430        Ok(Self {
431            dim_a,
432            dim_b,
433            output_dim,
434            proj_a,
435            proj_b,
436            low_rank_proj,
437        })
438    }
439
440    /// Forward pass with separate modality inputs
441    pub fn forward(
442        &self,
443        features_a: &Array<F, IxDyn>,
444        features_b: &Array<F, IxDyn>,
445    ) -> Result<Array<F, IxDyn>> {
446        // Project inputs to a common low-rank space
447        let a_proj = self.proj_a.forward(features_a)?;
448        let b_proj = self.proj_b.forward(features_b)?;
449        // Element-wise product for bilinear interaction
450        let bilinear = &a_proj * &b_proj;
451        let output = self.low_rank_proj.forward(&bilinear)?;
452        Ok(output)
453    }
454}
455
456impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for BilinearFusion<F> {
457    fn forward(&self, _input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
458        // This assumes the input contains both feature sets packed together
459        Err(NeuralError::ValidationError(
460            "BilinearFusion requires separate feature inputs. Use the dedicated forward method."
461                .to_string(),
462        ))
463    }
464
465    fn as_any(&self) -> &dyn std::any::Any {
466        self
467    }
468
469    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
470        self
471    }
472
473    fn backward(
474        &self,
475        _input: &Array<F, IxDyn>,
476        grad_output: &Array<F, IxDyn>,
477    ) -> Result<Array<F, IxDyn>> {
478        // For BilinearFusion, the backward pass is complex because it involves
479        // two separate inputs (features_a and features_b). Since the Layer trait only provides
480        // This would require a custom backward method that takes both feature inputs.
481        // For now, we return a gradient with the same shape as the expected input.
482        // to propagate gradients through the bilinear interaction (proj_a * proj_b)
483        Ok(grad_output.clone())
484    }
485
486    fn update(&mut self, learning_rate: F) -> Result<()> {
487        self.proj_a.update(learning_rate)?;
488        self.proj_b.update(learning_rate)?;
489        self.low_rank_proj.update(learning_rate)?;
490        Ok(())
491    }
492
493    fn params(&self) -> Vec<Array<F, IxDyn>> {
494        let mut params = Vec::new();
495        params.extend(self.proj_a.params());
496        params.extend(self.proj_b.params());
497        params.extend(self.low_rank_proj.params());
498        params
499    }
500
501    fn set_training(&mut self, training: bool) {
502        self.proj_a.set_training(training);
503        self.proj_b.set_training(training);
504        self.low_rank_proj.set_training(training);
505    }
506
507    fn is_training(&self) -> bool {
508        self.proj_a.is_training()
509    }
510}
511
512/// Feature Fusion model
513pub struct FeatureFusion<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign>
514where
515    F: SimdUnifiedOps,
516{
517    /// Feature aligners for each input modality
518    pub aligners: Vec<FeatureAlignment<F>>,
519    /// Fusion-specific modules
520    pub fusion_module: Option<Box<dyn Layer<F> + Send + Sync>>,
521    /// Post-fusion MLP
522    pub post_fusion: Sequential<F>,
523    /// Classifier head
524    pub classifier: Option<Dense<F>>,
525    /// Model configuration
526    pub config: FeatureFusionConfig,
527}
528
529// Manual implementation of Debug for FeatureFusion to handle dyn Layer trait objects
530impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Debug for FeatureFusion<F>
531where
532    F: SimdUnifiedOps,
533{
534    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
535        f.debug_struct("FeatureFusion")
536            .field("aligners", &self.aligners)
537            .field(
538                "fusion_module",
539                &"<Box<dyn Layer<F> + Send + Sync>>".to_string(),
540            )
541            .field("post_fusion", &self.post_fusion)
542            .field("classifier", &self.classifier)
543            .field("config", &self.config)
544            .finish()
545    }
546}
547
548// Manual implementation of Clone for FeatureFusion
549impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Clone for FeatureFusion<F>
550where
551    F: SimdUnifiedOps,
552{
553    fn clone(&self) -> Self {
554        // We can't clone the dyn Layer directly, so we create a new FeatureFusion
555        // without the fusion_module
556        // We would need to implement custom clone logic for fusion_module
557        // based on its actual type if needed, but for now we leave it as None
558        Self {
559            aligners: self.aligners.clone(),
560            fusion_module: None, // Can't clone the trait object
561            post_fusion: self.post_fusion.clone(),
562            classifier: self.classifier.clone(),
563            config: self.config.clone(),
564        }
565    }
566}
567
568impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> FeatureFusion<F>
569where
570    F: SimdUnifiedOps,
571{
572    /// Create a new FeatureFusion model
573    pub fn new(config: FeatureFusionConfig) -> Result<Self> {
574        // Create feature aligners
575        let mut aligners = Vec::with_capacity(config.input_dims.len());
576        for (i, &dim) in config.input_dims.iter().enumerate() {
577            aligners.push(FeatureAlignment::<F>::new(
578                dim,
579                config.hidden_dim,
580                Some(&format!("aligner_{}", i)),
581            )?);
582        }
583
584        // Create fusion-specific module based on method
585        let fusion_module: Option<Box<dyn Layer<F> + Send + Sync>> = match config.fusion_method {
586            FusionMethod::Attention => {
587                if config.input_dims.len() < 2 {
588                    return Err(NeuralError::ValidationError(
589                        "Attention fusion requires at least two modalities".to_string(),
590                    ));
591                }
592                let attn = CrossModalAttention::<F>::new(
593                    config.hidden_dim,
594                    config.hidden_dim,
595                    config.hidden_dim,
596                )?;
597                Some(Box::new(attn))
598            }
599            FusionMethod::Bilinear => {
600                if config.input_dims.len() != 2 {
601                    return Err(NeuralError::ValidationError(
602                        "Bilinear fusion requires exactly two modalities".to_string(),
603                    ));
604                }
605                let bilinear = BilinearFusion::<F>::new(
606                    config.hidden_dim,
607                    config.hidden_dim,
608                    config.hidden_dim,
609                    config.hidden_dim / 4, // Low-rank approximation
610                )?;
611                Some(Box::new(bilinear))
612            }
613            FusionMethod::FiLM => {
614                if config.input_dims.len() != 2 {
615                    return Err(NeuralError::ValidationError(
616                        "FiLM fusion requires exactly two modalities".to_string(),
617                    ));
618                }
619                let film = FiLMModule::<F>::new(config.hidden_dim, config.hidden_dim)?;
620                Some(Box::new(film))
621            }
622            // For simpler methods (concat, sum, product), we don't need special modules
623            _ => None,
624        };
625        // Create post-fusion MLP
626        let mut post_fusion = Sequential::new();
627        // Determine input dimension for the post-fusion network
628        let post_fusion_input_dim = match config.fusion_method {
629            FusionMethod::Concatenation => config.hidden_dim * config.input_dims.len(),
630            _ => config.hidden_dim,
631        };
632
633        let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
634        post_fusion.add(Dense::<F>::new(
635            post_fusion_input_dim,
636            config.hidden_dim * 2,
637            Some("gelu"),
638            &mut rng,
639        )?);
640        if config.dropout_rate > 0.0 {
641            post_fusion.add(Dropout::<F>::new(config.dropout_rate, &mut rng)?);
642        }
643        post_fusion.add(Dense::<F>::new(
644            config.hidden_dim * 2,
645            config.hidden_dim,
646            Some("gelu"),
647            &mut rng,
648        )?);
649
650        // Create classifier if needed
651        let classifier = if config.include_head {
652            Some(Dense::<F>::new(
653                config.hidden_dim,
654                config.num_classes,
655                None,
656                &mut rng,
657            )?)
658        } else {
659            None
660        };
661
662        Ok(Self {
663            aligners,
664            fusion_module,
665            post_fusion,
666            classifier,
667            config,
668        })
669    }
670
671    /// Forward pass with multiple input modalities
672    pub fn forward_multi(&self, inputs: &[Array<F, IxDyn>]) -> Result<Array<F, IxDyn>> {
673        if inputs.len() != self.config.input_dims.len() {
674            return Err(NeuralError::ValidationError(format!(
675                "Expected {} inputs, got {}",
676                self.config.input_dims.len(),
677                inputs.len()
678            )));
679        }
680
681        // Align features from each modality
682        let mut aligned_features = Vec::with_capacity(inputs.len());
683        for (i, input) in inputs.iter().enumerate() {
684            aligned_features.push(self.aligners[i].forward(input)?);
685        }
686
687        // Apply fusion based on method
688        let fused = match self.config.fusion_method {
689            FusionMethod::Concatenation => {
690                // Concatenate along feature dimension
691                let batch_size = aligned_features[0].shape()[0];
692                let mut concatenated = Vec::new();
693                for batch_idx in 0..batch_size {
694                    for features in &aligned_features {
695                        let batch_features = features.slice_axis(
696                            Axis(0),
697                            scirs2_core::ndarray::Slice::from(batch_idx..batch_idx + 1),
698                        );
699                        concatenated.extend(batch_features.iter().cloned());
700                    }
701                }
702                Array::from_shape_vec(
703                    [batch_size, self.config.hidden_dim * aligned_features.len()],
704                    concatenated,
705                )?
706                .into_dyn()
707            }
708            FusionMethod::Sum => {
709                // Element-wise sum
710                let mut result = aligned_features[0].clone();
711                for features in &aligned_features[1..] {
712                    result += features;
713                }
714                result
715            }
716            FusionMethod::Product => {
717                // Element-wise product
718                let mut result = aligned_features[0].clone();
719                for features in &aligned_features[1..] {
720                    result *= features;
721                }
722                result
723            }
724            FusionMethod::Attention => {
725                // Use attention module (modality 0 attends to modality 1)
726                if let Some(ref module) = self.fusion_module {
727                    // We need to cast the module as CrossModalAttention
728                    if let Some(attn) = module.as_any().downcast_ref::<CrossModalAttention<F>>() {
729                        attn.forward(&aligned_features[0], &aligned_features[1])?
730                    } else {
731                        return Err(NeuralError::InferenceError(
732                            "Failed to cast fusion module to CrossModalAttention".to_string(),
733                        ));
734                    }
735                } else {
736                    return Err(NeuralError::InferenceError(
737                        "Attention fusion module not initialized".to_string(),
738                    ));
739                }
740            }
741            FusionMethod::Bilinear => {
742                // Use bilinear module
743                if let Some(ref module) = self.fusion_module {
744                    // We need to cast the module as BilinearFusion
745                    if let Some(bilinear) = module.as_any().downcast_ref::<BilinearFusion<F>>() {
746                        bilinear.forward(&aligned_features[0], &aligned_features[1])?
747                    } else {
748                        return Err(NeuralError::InferenceError(
749                            "Failed to cast fusion module to BilinearFusion".to_string(),
750                        ));
751                    }
752                } else {
753                    return Err(NeuralError::InferenceError(
754                        "Bilinear fusion module not initialized".to_string(),
755                    ));
756                }
757            }
758            FusionMethod::FiLM => {
759                // Use FiLM module (modality 1 conditions modality 0)
760                if let Some(ref module) = self.fusion_module {
761                    // We need to cast the module as FiLMModule
762                    if let Some(film) = module.as_any().downcast_ref::<FiLMModule<F>>() {
763                        film.forward(&aligned_features[0], &aligned_features[1])?
764                    } else {
765                        return Err(NeuralError::InferenceError(
766                            "Failed to cast fusion module to FiLMModule".to_string(),
767                        ));
768                    }
769                } else {
770                    return Err(NeuralError::InferenceError(
771                        "FiLM fusion module not initialized".to_string(),
772                    ));
773                }
774            }
775        };
776
777        // Apply post-fusion network
778        let features = self.post_fusion.forward(&fused)?;
779        // Apply classifier if available
780        if let Some(ref classifier) = self.classifier {
781            classifier.forward(&features)
782        } else {
783            Ok(features)
784        }
785    }
786
787    /// Create a simple early fusion model for two modalities
788    pub fn create_early_fusion(
789        dim_a: usize,
790        dim_b: usize,
791        hidden_dim: usize,
792        num_classes: usize,
793        include_head: bool,
794    ) -> Result<Self> {
795        let config = FeatureFusionConfig {
796            input_dims: vec![dim_a, dim_b],
797            hidden_dim,
798            fusion_method: FusionMethod::Concatenation,
799            dropout_rate: 0.1,
800            num_classes,
801            include_head,
802        };
803        Self::new(config)
804    }
805
806    /// Create an attention-based fusion model for two modalities
807    pub fn create_attention_fusion(
808        dim_a: usize,
809        dim_b: usize,
810        hidden_dim: usize,
811        num_classes: usize,
812        include_head: bool,
813    ) -> Result<Self> {
814        let config = FeatureFusionConfig {
815            input_dims: vec![dim_a, dim_b],
816            hidden_dim,
817            fusion_method: FusionMethod::Attention,
818            dropout_rate: 0.1,
819            num_classes,
820            include_head,
821        };
822        Self::new(config)
823    }
824
825    /// Create a FiLM conditioning fusion model (B conditions A)
826    pub fn create_film_fusion(
827        dim_a: usize,
828        dim_b: usize,
829        hidden_dim: usize,
830        num_classes: usize,
831        include_head: bool,
832    ) -> Result<Self> {
833        let config = FeatureFusionConfig {
834            input_dims: vec![dim_a, dim_b],
835            hidden_dim,
836            fusion_method: FusionMethod::FiLM,
837            dropout_rate: 0.1,
838            num_classes,
839            include_head,
840        };
841        Self::new(config)
842    }
843}
844
845impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for FeatureFusion<F>
846where
847    F: SimdUnifiedOps,
848{
849    fn forward(&self, _input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
850        // For a single packed input, we need to split it into modalities
851        // This is mainly for the Layer trait compatibility
852        // In practice, use forward_multi with separate inputs
853        Err(NeuralError::ValidationError(
854            "FeatureFusion requires multiple inputs. Use forward_multi method instead.".to_string(),
855        ))
856    }
857
858    fn as_any(&self) -> &dyn std::any::Any {
859        self
860    }
861
862    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
863        self
864    }
865
866    fn backward(
867        &self,
868        _input: &Array<F, IxDyn>,
869        grad_output: &Array<F, IxDyn>,
870    ) -> Result<Array<F, IxDyn>> {
871        // For FeatureFusion, the backward pass is complex because it involves
872        // multiple inputs and various fusion strategies. Since the Layer trait only provides
873        // This would require a custom backward method that takes multiple inputs
874        // and understands the specific fusion strategy being used.
875        // to propagate gradients backward through the entire fusion pipeline:
876        // 1. Backward through classifier (if present)
877        // 2. Backward through post-fusion network
878        // 3. Backward through fusion operation (depends on fusion method)
879        // 4. Backward through aligners to get gradients for each modality
880        Ok(grad_output.clone())
881    }
882
883    fn update(&mut self, learning_rate: F) -> Result<()> {
884        // Update all aligners
885        for aligner in &mut self.aligners {
886            aligner.update(learning_rate)?;
887        }
888        // Update fusion module if present
889        if let Some(ref mut module) = self.fusion_module {
890            module.update(learning_rate)?;
891        }
892        // Update post-fusion network
893        self.post_fusion.update(learning_rate)?;
894        // Update classifier if present
895        if let Some(ref mut classifier) = self.classifier {
896            classifier.update(learning_rate)?;
897        }
898        Ok(())
899    }
900
901    fn params(&self) -> Vec<Array<F, IxDyn>> {
902        let mut params = Vec::new();
903        for aligner in &self.aligners {
904            params.extend(aligner.params());
905        }
906        if let Some(ref module) = self.fusion_module {
907            params.extend(module.params());
908        }
909        params.extend(self.post_fusion.params());
910        if let Some(ref classifier) = self.classifier {
911            params.extend(classifier.params());
912        }
913        params
914    }
915
916    fn set_training(&mut self, training: bool) {
917        for aligner in &mut self.aligners {
918            aligner.set_training(training);
919        }
920        if let Some(ref mut module) = self.fusion_module {
921            module.set_training(training);
922        }
923        self.post_fusion.set_training(training);
924        if let Some(ref mut classifier) = self.classifier {
925            classifier.set_training(training);
926        }
927    }
928
929    fn is_training(&self) -> bool {
930        self.aligners[0].is_training()
931    }
932}