1use crate::error::{NeuralError, Result};
7use crate::layers::{Dense, Dropout, Layer, LayerNorm, Sequential};
8use scirs2_core::ndarray::{Array, Axis, IxDyn, ScalarOperand};
9use scirs2_core::numeric::{Float, NumAssign};
10use scirs2_core::random::SeedableRng;
11use scirs2_core::simd_ops::SimdUnifiedOps;
12use serde::{Deserialize, Serialize};
13use std::fmt::Debug;
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
16pub enum FusionMethod {
17 Concatenation,
19 Sum,
21 Product,
23 Attention,
25 Bilinear,
27 FiLM,
29}
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct FeatureFusionConfig {
33 pub input_dims: Vec<usize>,
35 pub hidden_dim: usize,
37 pub fusion_method: FusionMethod,
39 pub dropout_rate: f64,
41 pub num_classes: usize,
43 pub include_head: bool,
45}
46
47#[derive(Debug, Clone)]
49pub struct FeatureAlignment<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign>
50where
51 F: SimdUnifiedOps,
52{
53 pub input_dim: usize,
55 pub output_dim: usize,
57 pub projection: Dense<F>,
59 pub norm: LayerNorm<F>,
61}
62
63impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> FeatureAlignment<F>
64where
65 F: SimdUnifiedOps,
66{
67 pub fn new(input_dim: usize, output_dim: usize, _name: Option<&str>) -> Result<Self> {
69 let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
70 let projection = Dense::<F>::new(input_dim, output_dim, None, &mut rng)?;
71 let norm = LayerNorm::<F>::new(output_dim, 1e-6, &mut rng)?;
72 Ok(Self {
73 input_dim,
74 output_dim,
75 projection,
76 norm,
77 })
78 }
79}
80
81impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for FeatureAlignment<F>
82where
83 F: SimdUnifiedOps,
84{
85 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
86 let x = self.projection.forward(input)?;
87 let x = self.norm.forward(&x)?;
88 Ok(x)
89 }
90
91 fn as_any(&self) -> &dyn std::any::Any {
92 self
93 }
94
95 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
96 self
97 }
98
99 fn backward(
100 &self,
101 input: &Array<F, IxDyn>,
102 grad_output: &Array<F, IxDyn>,
103 ) -> Result<Array<F, IxDyn>> {
104 let proj_output = self.projection.forward(input)?;
107 let grad_proj = self.norm.backward(&proj_output, grad_output)?;
109 let grad_input = self.projection.backward(input, &grad_proj)?;
111 Ok(grad_input)
112 }
113
114 fn update(&mut self, learning_rate: F) -> Result<()> {
115 self.projection.update(learning_rate)?;
117 self.norm.update(learning_rate)?;
119 Ok(())
120 }
121
122 fn params(&self) -> Vec<Array<F, IxDyn>> {
123 let mut params = Vec::new();
124 params.extend(self.projection.params());
125 params.extend(self.norm.params());
126 params
127 }
128
129 fn set_training(&mut self, training: bool) {
130 self.projection.set_training(training);
131 self.norm.set_training(training);
132 }
133
134 fn is_training(&self) -> bool {
135 self.projection.is_training()
136 }
137}
138
139#[derive(Debug, Clone)]
141pub struct CrossModalAttention<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
142 pub query_proj: Dense<F>,
144 pub key_proj: Dense<F>,
146 pub value_proj: Dense<F>,
148 pub output_proj: Dense<F>,
150 pub hidden_dim: usize,
152 pub scale: F,
154}
155
156impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> CrossModalAttention<F> {
157 pub fn new(query_dim: usize, key_dim: usize, hidden_dim: usize) -> Result<Self> {
159 let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
160 let query_proj = Dense::<F>::new(query_dim, hidden_dim, None, &mut rng)?;
161 let key_proj = Dense::<F>::new(key_dim, hidden_dim, None, &mut rng)?;
162 let value_proj = Dense::<F>::new(key_dim, hidden_dim, None, &mut rng)?;
163 let output_proj = Dense::<F>::new(hidden_dim, query_dim, None, &mut rng)?;
164 let scale = F::from(1.0 / (hidden_dim as f64).sqrt()).expect("Operation failed");
166 Ok(Self {
167 query_proj,
168 key_proj,
169 value_proj,
170 output_proj,
171 hidden_dim,
172 scale,
173 })
174 }
175
176 pub fn forward(
178 &self,
179 query: &Array<F, IxDyn>,
180 context: &Array<F, IxDyn>,
181 ) -> Result<Array<F, IxDyn>> {
182 let q = self.query_proj.forward(query)?;
184 let k = self.key_proj.forward(context)?;
185 let v = self.value_proj.forward(context)?;
186 let batch_size = q.shape()[0];
188 let query_len = q.shape()[1];
189 let context_len = k.shape()[1];
190 let q_2d = q
191 .clone()
192 .into_shape_with_order((batch_size * query_len, self.hidden_dim))?;
193 let k_2d = k.into_shape_with_order((batch_size * context_len, self.hidden_dim))?;
194 let v_2d = v.into_shape_with_order((batch_size * context_len, self.hidden_dim))?;
195
196 let scores = q_2d.dot(&k_2d.t()) * self.scale;
198 let scores_3d = scores.into_shape_with_order((batch_size, query_len, context_len))?;
200 let mut attention_weights = scores_3d.to_owned().into_dyn();
202 attention_weights.fill(F::zero());
203 for b in 0..batch_size {
204 for q in 0..query_len {
205 let mut row = scores_3d
206 .slice(scirs2_core::ndarray::s![b, q, ..])
207 .to_owned();
208 let max_val = row.fold(F::neg_infinity(), |m: F, &v: &F| m.max(v));
210 let mut exp_sum = F::zero();
212 for i in 0..context_len {
213 let exp_val = (row[i] - max_val).exp();
214 row[i] = exp_val;
215 exp_sum += exp_val;
216 }
217 if exp_sum > F::zero() {
219 for i in 0..context_len {
220 row[i] /= exp_sum;
221 }
222 }
223 for i in 0..context_len {
225 attention_weights[[b, q, i]] = row[i];
226 }
227 }
228 }
229
230 let attn_weights_2d = attention_weights
232 .into_shape_with_order((batch_size * query_len, batch_size * context_len))?;
233 let context_vec = attn_weights_2d.dot(&v_2d);
235 let context_vec_reshaped =
237 context_vec.into_shape_with_order((batch_size, query_len, self.hidden_dim))?;
238 let output = self.output_proj.forward(&context_vec_reshaped.into_dyn())?;
240 Ok(output)
241 }
242}
243
244impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F>
245 for CrossModalAttention<F>
246{
247 fn forward(&self, _input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
248 Err(NeuralError::ValidationError(
251 "CrossModalAttention requires separate query and context inputs. Use the dedicated forward method."
252 .to_string(),
253 ))
254 }
255
256 fn as_any(&self) -> &dyn std::any::Any {
257 self
258 }
259
260 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
261 self
262 }
263
264 fn backward(
265 &self,
266 _input: &Array<F, IxDyn>,
267 grad_output: &Array<F, IxDyn>,
268 ) -> Result<Array<F, IxDyn>> {
269 Ok(grad_output.clone())
278 }
279
280 fn update(&mut self, learning_rate: F) -> Result<()> {
281 self.query_proj.update(learning_rate)?;
283 self.key_proj.update(learning_rate)?;
284 self.value_proj.update(learning_rate)?;
285 self.output_proj.update(learning_rate)?;
286 Ok(())
287 }
288
289 fn params(&self) -> Vec<Array<F, IxDyn>> {
290 let mut params = Vec::new();
291 params.extend(self.query_proj.params());
292 params.extend(self.key_proj.params());
293 params.extend(self.value_proj.params());
294 params.extend(self.output_proj.params());
295 params
296 }
297
298 fn set_training(&mut self, training: bool) {
299 self.query_proj.set_training(training);
300 self.key_proj.set_training(training);
301 self.value_proj.set_training(training);
302 self.output_proj.set_training(training);
303 }
304
305 fn is_training(&self) -> bool {
306 self.query_proj.is_training()
307 }
308}
309
310#[derive(Debug, Clone)]
312pub struct FiLMModule<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
313 pub feature_dim: usize,
315 pub cond_dim: usize,
317 pub gamma_proj: Dense<F>,
319 pub beta_proj: Dense<F>,
321}
322
323impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> FiLMModule<F> {
324 pub fn new(feature_dim: usize, cond_dim: usize) -> Result<Self> {
326 let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
327 let gamma_proj = Dense::<F>::new(cond_dim, feature_dim, None, &mut rng)?;
328 let beta_proj = Dense::<F>::new(cond_dim, feature_dim, None, &mut rng)?;
329 Ok(Self {
330 feature_dim,
331 cond_dim,
332 gamma_proj,
333 beta_proj,
334 })
335 }
336
337 pub fn forward(
339 &self,
340 features: &Array<F, IxDyn>,
341 conditioning: &Array<F, IxDyn>,
342 ) -> Result<Array<F, IxDyn>> {
343 let gamma = self.gamma_proj.forward(conditioning)?;
345 let beta = self.beta_proj.forward(conditioning)?;
346 let modulated = &gamma * features + β
348 Ok(modulated)
349 }
350}
351
352impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for FiLMModule<F> {
353 fn forward(&self, _input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
354 Err(NeuralError::ValidationError(
356 "FiLMModule requires separate feature and conditioning inputs. Use the dedicated forward method."
357 .to_string(),
358 ))
359 }
360
361 fn as_any(&self) -> &dyn std::any::Any {
362 self
363 }
364
365 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
366 self
367 }
368
369 fn backward(
370 &self,
371 _input: &Array<F, IxDyn>,
372 grad_output: &Array<F, IxDyn>,
373 ) -> Result<Array<F, IxDyn>> {
374 Ok(grad_output.clone())
380 }
381
382 fn update(&mut self, learning_rate: F) -> Result<()> {
383 self.gamma_proj.update(learning_rate)?;
385 self.beta_proj.update(learning_rate)?;
386 Ok(())
387 }
388
389 fn params(&self) -> Vec<Array<F, IxDyn>> {
390 let mut params = Vec::new();
391 params.extend(self.gamma_proj.params());
392 params.extend(self.beta_proj.params());
393 params
394 }
395
396 fn set_training(&mut self, training: bool) {
397 self.gamma_proj.set_training(training);
398 self.beta_proj.set_training(training);
399 }
400
401 fn is_training(&self) -> bool {
402 self.gamma_proj.is_training()
403 }
404}
405
406#[derive(Debug, Clone)]
408pub struct BilinearFusion<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
409 pub dim_a: usize,
411 pub dim_b: usize,
413 pub output_dim: usize,
415 pub proj_a: Dense<F>,
417 pub proj_b: Dense<F>,
419 pub low_rank_proj: Dense<F>,
421}
422
423impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> BilinearFusion<F> {
424 pub fn new(dim_a: usize, dim_b: usize, output_dim: usize, rank: usize) -> Result<Self> {
426 let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
427 let proj_a = Dense::<F>::new(dim_a, rank, None, &mut rng)?;
428 let proj_b = Dense::<F>::new(dim_b, rank, None, &mut rng)?;
429 let low_rank_proj = Dense::<F>::new(rank, output_dim, None, &mut rng)?;
430 Ok(Self {
431 dim_a,
432 dim_b,
433 output_dim,
434 proj_a,
435 proj_b,
436 low_rank_proj,
437 })
438 }
439
440 pub fn forward(
442 &self,
443 features_a: &Array<F, IxDyn>,
444 features_b: &Array<F, IxDyn>,
445 ) -> Result<Array<F, IxDyn>> {
446 let a_proj = self.proj_a.forward(features_a)?;
448 let b_proj = self.proj_b.forward(features_b)?;
449 let bilinear = &a_proj * &b_proj;
451 let output = self.low_rank_proj.forward(&bilinear)?;
452 Ok(output)
453 }
454}
455
456impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for BilinearFusion<F> {
457 fn forward(&self, _input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
458 Err(NeuralError::ValidationError(
460 "BilinearFusion requires separate feature inputs. Use the dedicated forward method."
461 .to_string(),
462 ))
463 }
464
465 fn as_any(&self) -> &dyn std::any::Any {
466 self
467 }
468
469 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
470 self
471 }
472
473 fn backward(
474 &self,
475 _input: &Array<F, IxDyn>,
476 grad_output: &Array<F, IxDyn>,
477 ) -> Result<Array<F, IxDyn>> {
478 Ok(grad_output.clone())
484 }
485
486 fn update(&mut self, learning_rate: F) -> Result<()> {
487 self.proj_a.update(learning_rate)?;
488 self.proj_b.update(learning_rate)?;
489 self.low_rank_proj.update(learning_rate)?;
490 Ok(())
491 }
492
493 fn params(&self) -> Vec<Array<F, IxDyn>> {
494 let mut params = Vec::new();
495 params.extend(self.proj_a.params());
496 params.extend(self.proj_b.params());
497 params.extend(self.low_rank_proj.params());
498 params
499 }
500
501 fn set_training(&mut self, training: bool) {
502 self.proj_a.set_training(training);
503 self.proj_b.set_training(training);
504 self.low_rank_proj.set_training(training);
505 }
506
507 fn is_training(&self) -> bool {
508 self.proj_a.is_training()
509 }
510}
511
512pub struct FeatureFusion<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign>
514where
515 F: SimdUnifiedOps,
516{
517 pub aligners: Vec<FeatureAlignment<F>>,
519 pub fusion_module: Option<Box<dyn Layer<F> + Send + Sync>>,
521 pub post_fusion: Sequential<F>,
523 pub classifier: Option<Dense<F>>,
525 pub config: FeatureFusionConfig,
527}
528
529impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Debug for FeatureFusion<F>
531where
532 F: SimdUnifiedOps,
533{
534 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
535 f.debug_struct("FeatureFusion")
536 .field("aligners", &self.aligners)
537 .field(
538 "fusion_module",
539 &"<Box<dyn Layer<F> + Send + Sync>>".to_string(),
540 )
541 .field("post_fusion", &self.post_fusion)
542 .field("classifier", &self.classifier)
543 .field("config", &self.config)
544 .finish()
545 }
546}
547
548impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Clone for FeatureFusion<F>
550where
551 F: SimdUnifiedOps,
552{
553 fn clone(&self) -> Self {
554 Self {
559 aligners: self.aligners.clone(),
560 fusion_module: None, post_fusion: self.post_fusion.clone(),
562 classifier: self.classifier.clone(),
563 config: self.config.clone(),
564 }
565 }
566}
567
568impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> FeatureFusion<F>
569where
570 F: SimdUnifiedOps,
571{
572 pub fn new(config: FeatureFusionConfig) -> Result<Self> {
574 let mut aligners = Vec::with_capacity(config.input_dims.len());
576 for (i, &dim) in config.input_dims.iter().enumerate() {
577 aligners.push(FeatureAlignment::<F>::new(
578 dim,
579 config.hidden_dim,
580 Some(&format!("aligner_{}", i)),
581 )?);
582 }
583
584 let fusion_module: Option<Box<dyn Layer<F> + Send + Sync>> = match config.fusion_method {
586 FusionMethod::Attention => {
587 if config.input_dims.len() < 2 {
588 return Err(NeuralError::ValidationError(
589 "Attention fusion requires at least two modalities".to_string(),
590 ));
591 }
592 let attn = CrossModalAttention::<F>::new(
593 config.hidden_dim,
594 config.hidden_dim,
595 config.hidden_dim,
596 )?;
597 Some(Box::new(attn))
598 }
599 FusionMethod::Bilinear => {
600 if config.input_dims.len() != 2 {
601 return Err(NeuralError::ValidationError(
602 "Bilinear fusion requires exactly two modalities".to_string(),
603 ));
604 }
605 let bilinear = BilinearFusion::<F>::new(
606 config.hidden_dim,
607 config.hidden_dim,
608 config.hidden_dim,
609 config.hidden_dim / 4, )?;
611 Some(Box::new(bilinear))
612 }
613 FusionMethod::FiLM => {
614 if config.input_dims.len() != 2 {
615 return Err(NeuralError::ValidationError(
616 "FiLM fusion requires exactly two modalities".to_string(),
617 ));
618 }
619 let film = FiLMModule::<F>::new(config.hidden_dim, config.hidden_dim)?;
620 Some(Box::new(film))
621 }
622 _ => None,
624 };
625 let mut post_fusion = Sequential::new();
627 let post_fusion_input_dim = match config.fusion_method {
629 FusionMethod::Concatenation => config.hidden_dim * config.input_dims.len(),
630 _ => config.hidden_dim,
631 };
632
633 let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
634 post_fusion.add(Dense::<F>::new(
635 post_fusion_input_dim,
636 config.hidden_dim * 2,
637 Some("gelu"),
638 &mut rng,
639 )?);
640 if config.dropout_rate > 0.0 {
641 post_fusion.add(Dropout::<F>::new(config.dropout_rate, &mut rng)?);
642 }
643 post_fusion.add(Dense::<F>::new(
644 config.hidden_dim * 2,
645 config.hidden_dim,
646 Some("gelu"),
647 &mut rng,
648 )?);
649
650 let classifier = if config.include_head {
652 Some(Dense::<F>::new(
653 config.hidden_dim,
654 config.num_classes,
655 None,
656 &mut rng,
657 )?)
658 } else {
659 None
660 };
661
662 Ok(Self {
663 aligners,
664 fusion_module,
665 post_fusion,
666 classifier,
667 config,
668 })
669 }
670
671 pub fn forward_multi(&self, inputs: &[Array<F, IxDyn>]) -> Result<Array<F, IxDyn>> {
673 if inputs.len() != self.config.input_dims.len() {
674 return Err(NeuralError::ValidationError(format!(
675 "Expected {} inputs, got {}",
676 self.config.input_dims.len(),
677 inputs.len()
678 )));
679 }
680
681 let mut aligned_features = Vec::with_capacity(inputs.len());
683 for (i, input) in inputs.iter().enumerate() {
684 aligned_features.push(self.aligners[i].forward(input)?);
685 }
686
687 let fused = match self.config.fusion_method {
689 FusionMethod::Concatenation => {
690 let batch_size = aligned_features[0].shape()[0];
692 let mut concatenated = Vec::new();
693 for batch_idx in 0..batch_size {
694 for features in &aligned_features {
695 let batch_features = features.slice_axis(
696 Axis(0),
697 scirs2_core::ndarray::Slice::from(batch_idx..batch_idx + 1),
698 );
699 concatenated.extend(batch_features.iter().cloned());
700 }
701 }
702 Array::from_shape_vec(
703 [batch_size, self.config.hidden_dim * aligned_features.len()],
704 concatenated,
705 )?
706 .into_dyn()
707 }
708 FusionMethod::Sum => {
709 let mut result = aligned_features[0].clone();
711 for features in &aligned_features[1..] {
712 result += features;
713 }
714 result
715 }
716 FusionMethod::Product => {
717 let mut result = aligned_features[0].clone();
719 for features in &aligned_features[1..] {
720 result *= features;
721 }
722 result
723 }
724 FusionMethod::Attention => {
725 if let Some(ref module) = self.fusion_module {
727 if let Some(attn) = module.as_any().downcast_ref::<CrossModalAttention<F>>() {
729 attn.forward(&aligned_features[0], &aligned_features[1])?
730 } else {
731 return Err(NeuralError::InferenceError(
732 "Failed to cast fusion module to CrossModalAttention".to_string(),
733 ));
734 }
735 } else {
736 return Err(NeuralError::InferenceError(
737 "Attention fusion module not initialized".to_string(),
738 ));
739 }
740 }
741 FusionMethod::Bilinear => {
742 if let Some(ref module) = self.fusion_module {
744 if let Some(bilinear) = module.as_any().downcast_ref::<BilinearFusion<F>>() {
746 bilinear.forward(&aligned_features[0], &aligned_features[1])?
747 } else {
748 return Err(NeuralError::InferenceError(
749 "Failed to cast fusion module to BilinearFusion".to_string(),
750 ));
751 }
752 } else {
753 return Err(NeuralError::InferenceError(
754 "Bilinear fusion module not initialized".to_string(),
755 ));
756 }
757 }
758 FusionMethod::FiLM => {
759 if let Some(ref module) = self.fusion_module {
761 if let Some(film) = module.as_any().downcast_ref::<FiLMModule<F>>() {
763 film.forward(&aligned_features[0], &aligned_features[1])?
764 } else {
765 return Err(NeuralError::InferenceError(
766 "Failed to cast fusion module to FiLMModule".to_string(),
767 ));
768 }
769 } else {
770 return Err(NeuralError::InferenceError(
771 "FiLM fusion module not initialized".to_string(),
772 ));
773 }
774 }
775 };
776
777 let features = self.post_fusion.forward(&fused)?;
779 if let Some(ref classifier) = self.classifier {
781 classifier.forward(&features)
782 } else {
783 Ok(features)
784 }
785 }
786
787 pub fn create_early_fusion(
789 dim_a: usize,
790 dim_b: usize,
791 hidden_dim: usize,
792 num_classes: usize,
793 include_head: bool,
794 ) -> Result<Self> {
795 let config = FeatureFusionConfig {
796 input_dims: vec![dim_a, dim_b],
797 hidden_dim,
798 fusion_method: FusionMethod::Concatenation,
799 dropout_rate: 0.1,
800 num_classes,
801 include_head,
802 };
803 Self::new(config)
804 }
805
806 pub fn create_attention_fusion(
808 dim_a: usize,
809 dim_b: usize,
810 hidden_dim: usize,
811 num_classes: usize,
812 include_head: bool,
813 ) -> Result<Self> {
814 let config = FeatureFusionConfig {
815 input_dims: vec![dim_a, dim_b],
816 hidden_dim,
817 fusion_method: FusionMethod::Attention,
818 dropout_rate: 0.1,
819 num_classes,
820 include_head,
821 };
822 Self::new(config)
823 }
824
825 pub fn create_film_fusion(
827 dim_a: usize,
828 dim_b: usize,
829 hidden_dim: usize,
830 num_classes: usize,
831 include_head: bool,
832 ) -> Result<Self> {
833 let config = FeatureFusionConfig {
834 input_dims: vec![dim_a, dim_b],
835 hidden_dim,
836 fusion_method: FusionMethod::FiLM,
837 dropout_rate: 0.1,
838 num_classes,
839 include_head,
840 };
841 Self::new(config)
842 }
843}
844
845impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> Layer<F> for FeatureFusion<F>
846where
847 F: SimdUnifiedOps,
848{
849 fn forward(&self, _input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
850 Err(NeuralError::ValidationError(
854 "FeatureFusion requires multiple inputs. Use forward_multi method instead.".to_string(),
855 ))
856 }
857
858 fn as_any(&self) -> &dyn std::any::Any {
859 self
860 }
861
862 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
863 self
864 }
865
866 fn backward(
867 &self,
868 _input: &Array<F, IxDyn>,
869 grad_output: &Array<F, IxDyn>,
870 ) -> Result<Array<F, IxDyn>> {
871 Ok(grad_output.clone())
881 }
882
883 fn update(&mut self, learning_rate: F) -> Result<()> {
884 for aligner in &mut self.aligners {
886 aligner.update(learning_rate)?;
887 }
888 if let Some(ref mut module) = self.fusion_module {
890 module.update(learning_rate)?;
891 }
892 self.post_fusion.update(learning_rate)?;
894 if let Some(ref mut classifier) = self.classifier {
896 classifier.update(learning_rate)?;
897 }
898 Ok(())
899 }
900
901 fn params(&self) -> Vec<Array<F, IxDyn>> {
902 let mut params = Vec::new();
903 for aligner in &self.aligners {
904 params.extend(aligner.params());
905 }
906 if let Some(ref module) = self.fusion_module {
907 params.extend(module.params());
908 }
909 params.extend(self.post_fusion.params());
910 if let Some(ref classifier) = self.classifier {
911 params.extend(classifier.params());
912 }
913 params
914 }
915
916 fn set_training(&mut self, training: bool) {
917 for aligner in &mut self.aligners {
918 aligner.set_training(training);
919 }
920 if let Some(ref mut module) = self.fusion_module {
921 module.set_training(training);
922 }
923 self.post_fusion.set_training(training);
924 if let Some(ref mut classifier) = self.classifier {
925 classifier.set_training(training);
926 }
927 }
928
929 fn is_training(&self) -> bool {
930 self.aligners[0].is_training()
931 }
932}