1use crate::error::{NeuralError, Result};
40use crate::layers::{Dense, Dropout, Layer, LayerNorm};
41use scirs2_core::ndarray::{s, Array, Array2, Array3, Axis, IxDyn, ScalarOperand};
42use scirs2_core::numeric::{Float, FromPrimitive, NumAssign};
43use scirs2_core::random::{Rng, RngExt};
44use serde::{Deserialize, Serialize};
45use std::fmt::Debug;
46
/// Hyper-parameters describing an MLP-Mixer model.
///
/// This is plain data: it can be cloned, debug-printed and (de)serialized
/// through serde. [`MLPMixer::new`] consumes a value of this type to size
/// every layer of the network.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLPMixerConfig {
    /// Side length of the input image. The patch count is derived as
    /// `(image_size / patch_size)^2`, i.e. the image is treated as square.
    pub image_size: usize,
    /// Side length of one square patch.
    pub patch_size: usize,
    /// Output width of the classification head (number of logits).
    pub num_classes: usize,
    /// Width of the per-patch embedding carried through the mixer blocks.
    pub hidden_dim: usize,
    /// Number of mixer blocks stacked in the model.
    pub num_blocks: usize,
    /// Hidden width of the token-mixing MLP inside each block.
    pub token_mlp_dim: usize,
    /// Hidden width of the channel-mixing MLP inside each block.
    pub channel_mlp_dim: usize,
    /// Rate passed to the dropout layer of every mixing MLP.
    pub dropout_rate: f64,
    /// Number of channels of the input image.
    pub in_channels: usize,
}
69
70impl Default for MLPMixerConfig {
71 fn default() -> Self {
72 Self {
73 image_size: 224,
74 patch_size: 16,
75 num_classes: 1000,
76 hidden_dim: 512,
77 num_blocks: 8,
78 token_mlp_dim: 256,
79 channel_mlp_dim: 2048,
80 dropout_rate: 0.0,
81 in_channels: 3,
82 }
83 }
84}
85
86impl MLPMixerConfig {
87 pub fn mixer_s_32(num_classes: usize) -> Self {
89 Self {
90 image_size: 224,
91 patch_size: 32,
92 num_classes,
93 hidden_dim: 512,
94 num_blocks: 8,
95 token_mlp_dim: 256,
96 channel_mlp_dim: 2048,
97 dropout_rate: 0.0,
98 in_channels: 3,
99 }
100 }
101
102 pub fn mixer_s_16(num_classes: usize) -> Self {
104 Self {
105 image_size: 224,
106 patch_size: 16,
107 num_classes,
108 hidden_dim: 512,
109 num_blocks: 8,
110 token_mlp_dim: 256,
111 channel_mlp_dim: 2048,
112 dropout_rate: 0.0,
113 in_channels: 3,
114 }
115 }
116
117 pub fn mixer_b_32(num_classes: usize) -> Self {
119 Self {
120 image_size: 224,
121 patch_size: 32,
122 num_classes,
123 hidden_dim: 768,
124 num_blocks: 12,
125 token_mlp_dim: 384,
126 channel_mlp_dim: 3072,
127 dropout_rate: 0.0,
128 in_channels: 3,
129 }
130 }
131
132 pub fn mixer_b_16(num_classes: usize) -> Self {
134 Self {
135 image_size: 224,
136 patch_size: 16,
137 num_classes,
138 hidden_dim: 768,
139 num_blocks: 12,
140 token_mlp_dim: 384,
141 channel_mlp_dim: 3072,
142 dropout_rate: 0.0,
143 in_channels: 3,
144 }
145 }
146
147 pub fn num_patches(&self) -> usize {
149 (self.image_size / self.patch_size).pow(2)
150 }
151}
152
/// Two-layer feed-forward network used for both mixing steps of a block.
///
/// `forward` runs `fc1`, dropout, `fc2` and dropout again, in that order.
#[derive(Debug, Clone)]
pub struct MixerMLP<
    F: Float
        + Debug
        + ScalarOperand
        + Send
        + Sync
        + NumAssign
        + scirs2_core::simd_ops::SimdUnifiedOps
        + 'static,
> {
    /// First dense layer; `new` builds it with the `"gelu"` activation.
    fc1: Dense<F>,
    /// Second dense layer; `new` builds it without an activation.
    fc2: Dense<F>,
    /// Dropout layer shared by both dropout applications in `forward`.
    dropout: Dropout<F>,
}
175
176impl<
177 F: Float
178 + Debug
179 + ScalarOperand
180 + Send
181 + Sync
182 + NumAssign
183 + scirs2_core::simd_ops::SimdUnifiedOps
184 + 'static,
185 > MixerMLP<F>
186{
187 pub fn new<R: Rng + Clone + Send + Sync + 'static>(
196 in_features: usize,
197 hidden_features: usize,
198 out_features: usize,
199 dropout_rate: f64,
200 rng: &mut R,
201 ) -> Result<Self> {
202 let fc1 = Dense::new(in_features, hidden_features, Some("gelu"), rng)?;
203 let fc2 = Dense::new(hidden_features, out_features, None, rng)?;
204 let dropout = Dropout::new(dropout_rate, rng)?;
205
206 Ok(Self { fc1, fc2, dropout })
207 }
208
209 pub fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
211 let x = self.fc1.forward(input)?;
212 let x = self.dropout.forward(&x)?;
213 let x = self.fc2.forward(&x)?;
214 self.dropout.forward(&x)
215 }
216}
217
/// One mixer block: a token-mixing step followed by a channel-mixing step,
/// each preceded by a layer norm and wrapped in a residual connection.
#[derive(Debug, Clone)]
pub struct MixerBlock<
    F: Float
        + Debug
        + ScalarOperand
        + Send
        + Sync
        + NumAssign
        + scirs2_core::simd_ops::SimdUnifiedOps
        + 'static,
> {
    /// Layer norm (over `hidden_dim`) applied before token mixing.
    norm1: LayerNorm<F>,
    /// MLP mapping `num_patches -> token_mlp_dim -> num_patches`.
    token_mixing: MixerMLP<F>,
    /// Layer norm (over `hidden_dim`) applied before channel mixing.
    norm2: LayerNorm<F>,
    /// MLP mapping `hidden_dim -> channel_mlp_dim -> hidden_dim`.
    channel_mixing: MixerMLP<F>,
    /// Patch count supplied at construction (feature width of `token_mixing`).
    num_patches: usize,
    /// Embedding width supplied at construction (feature width of `channel_mixing`).
    hidden_dim: usize,
}
251
252impl<
253 F: Float
254 + Debug
255 + ScalarOperand
256 + Send
257 + Sync
258 + NumAssign
259 + scirs2_core::simd_ops::SimdUnifiedOps
260 + 'static,
261 > MixerBlock<F>
262{
263 pub fn new<R: Rng + Clone + Send + Sync + 'static>(
273 num_patches: usize,
274 hidden_dim: usize,
275 token_mlp_dim: usize,
276 channel_mlp_dim: usize,
277 dropout_rate: f64,
278 rng: &mut R,
279 ) -> Result<Self> {
280 let norm1 = LayerNorm::new(hidden_dim, 1e-6, rng)?;
281 let token_mixing =
282 MixerMLP::new(num_patches, token_mlp_dim, num_patches, dropout_rate, rng)?;
283 let norm2 = LayerNorm::new(hidden_dim, 1e-6, rng)?;
284 let channel_mixing =
285 MixerMLP::new(hidden_dim, channel_mlp_dim, hidden_dim, dropout_rate, rng)?;
286
287 Ok(Self {
288 norm1,
289 token_mixing,
290 norm2,
291 channel_mixing,
292 num_patches,
293 hidden_dim,
294 })
295 }
296
297 pub fn forward(&self, input: &Array3<F>) -> Result<Array3<F>> {
302 let batch_size = input.shape()[0];
303
304 let normed1 = self.apply_layer_norm(&self.norm1, input)?;
309
310 let transposed = normed1.permuted_axes([0, 2, 1]);
312
313 let mut token_mixed = Array3::zeros(transposed.raw_dim());
315 for b in 0..batch_size {
316 let slice = transposed.slice(s![b, .., ..]).to_owned().into_dyn();
317 let mixed = self.token_mixing.forward(&slice)?;
318 let mixed_2d = mixed
319 .into_dimensionality::<scirs2_core::ndarray::Ix2>()
320 .map_err(|e| {
321 NeuralError::InferenceError(format!("Failed to convert mixed to 2D: {}", e))
322 })?;
323 token_mixed.slice_mut(s![b, .., ..]).assign(&mixed_2d);
324 }
325
326 let token_mixed = token_mixed.permuted_axes([0, 2, 1]);
328
329 let x = input + &token_mixed;
331
332 let normed2 = self.apply_layer_norm(&self.norm2, &x)?;
334
335 let mut channel_mixed = Array3::zeros(normed2.raw_dim());
337 for b in 0..batch_size {
338 let slice = normed2.slice(s![b, .., ..]).to_owned().into_dyn();
339 let mixed = self.channel_mixing.forward(&slice)?;
340 let mixed_2d = mixed
341 .into_dimensionality::<scirs2_core::ndarray::Ix2>()
342 .map_err(|e| {
343 NeuralError::InferenceError(format!("Failed to convert mixed to 2D: {}", e))
344 })?;
345 channel_mixed.slice_mut(s![b, .., ..]).assign(&mixed_2d);
346 }
347
348 Ok(&x + &channel_mixed)
350 }
351
352 fn apply_layer_norm(&self, norm: &LayerNorm<F>, input: &Array3<F>) -> Result<Array3<F>> {
354 let batch_size = input.shape()[0];
355 let seq_len = input.shape()[1];
356 let hidden_dim = input.shape()[2];
357
358 let mut output = Array3::zeros(input.raw_dim());
359
360 for b in 0..batch_size {
361 for s in 0..seq_len {
362 let slice = input.slice(s![b, s, ..]).to_owned().into_dyn();
363 let normed = norm.forward(&slice)?;
364 let normed_1d = normed
365 .into_dimensionality::<scirs2_core::ndarray::Ix1>()
366 .map_err(|e| {
367 NeuralError::InferenceError(format!(
368 "Failed to convert normed to 1D: {}",
369 e
370 ))
371 })?;
372 output.slice_mut(s![b, s, ..]).assign(&normed_1d);
373 }
374 }
375
376 Ok(output)
377 }
378}
379
/// MLP-Mixer image classifier: patch embedding, a stack of mixer blocks,
/// mean pooling over patches, a final layer norm and a linear head.
#[derive(Debug)]
pub struct MLPMixer<
    F: Float
        + Debug
        + ScalarOperand
        + Send
        + Sync
        + NumAssign
        + scirs2_core::simd_ops::SimdUnifiedOps
        + 'static,
> {
    /// Configuration the model was built from.
    config: MLPMixerConfig,
    /// Dense layer mapping a flattened patch
    /// (`in_channels * patch_size * patch_size` values) to `hidden_dim`.
    patch_embed: Dense<F>,
    /// The `num_blocks` mixer blocks, applied in order.
    blocks: Vec<MixerBlock<F>>,
    /// Layer norm over `hidden_dim`, applied to the pooled features.
    norm: LayerNorm<F>,
    /// Dense layer mapping `hidden_dim` to `num_classes` logits.
    head: Dense<F>,
}
408
409impl<
410 F: Float
411 + Debug
412 + ScalarOperand
413 + Send
414 + Sync
415 + NumAssign
416 + FromPrimitive
417 + scirs2_core::simd_ops::SimdUnifiedOps
418 + 'static,
419 > MLPMixer<F>
420{
421 pub fn new<R: Rng + Clone + Send + Sync + 'static>(
427 config: MLPMixerConfig,
428 rng: &mut R,
429 ) -> Result<Self> {
430 let num_patches = config.num_patches();
431 let patch_dim = config.in_channels * config.patch_size * config.patch_size;
432
433 let patch_embed = Dense::new(patch_dim, config.hidden_dim, None, rng)?;
435
436 let mut blocks = Vec::with_capacity(config.num_blocks);
438 for _ in 0..config.num_blocks {
439 blocks.push(MixerBlock::new(
440 num_patches,
441 config.hidden_dim,
442 config.token_mlp_dim,
443 config.channel_mlp_dim,
444 config.dropout_rate,
445 rng,
446 )?);
447 }
448
449 let norm = LayerNorm::new(config.hidden_dim, 1e-6, rng)?;
451
452 let head = Dense::new(config.hidden_dim, config.num_classes, None, rng)?;
454
455 Ok(Self {
456 config,
457 patch_embed,
458 blocks,
459 norm,
460 head,
461 })
462 }
463
464 fn extract_patches(&self, images: &Array<F, IxDyn>) -> Result<Array3<F>> {
472 let shape = images.shape();
473 if shape.len() != 4 {
474 return Err(NeuralError::InvalidArchitecture(format!(
475 "Expected 4D input [B, C, H, W], got {:?}",
476 shape
477 )));
478 }
479
480 let batch_size = shape[0];
481 let channels = shape[1];
482 let height = shape[2];
483 let width = shape[3];
484
485 let patch_size = self.config.patch_size;
486 let patches_h = height / patch_size;
487 let patches_w = width / patch_size;
488 let num_patches = patches_h * patches_w;
489 let patch_dim = channels * patch_size * patch_size;
490
491 let mut patches = Array3::zeros((batch_size, num_patches, patch_dim));
492
493 for b in 0..batch_size {
494 for ph in 0..patches_h {
495 for pw in 0..patches_w {
496 let patch_idx = ph * patches_w + pw;
497 let h_start = ph * patch_size;
498 let w_start = pw * patch_size;
499
500 let mut flat_idx = 0;
502 for c in 0..channels {
503 for h in 0..patch_size {
504 for w in 0..patch_size {
505 patches[[b, patch_idx, flat_idx]] =
506 images[[b, c, h_start + h, w_start + w]];
507 flat_idx += 1;
508 }
509 }
510 }
511 }
512 }
513 }
514
515 Ok(patches)
516 }
517
518 pub fn forward(&self, images: &Array<F, IxDyn>) -> Result<Array2<F>> {
526 let batch_size = images.shape()[0];
527
528 let patches = self.extract_patches(images)?;
530
531 let mut embedded = Array3::zeros((
533 batch_size,
534 self.config.num_patches(),
535 self.config.hidden_dim,
536 ));
537 for b in 0..batch_size {
538 let patch_slice = patches.slice(s![b, .., ..]).to_owned().into_dyn();
539 let emb = self.patch_embed.forward(&patch_slice)?;
540 let emb_2d = emb
541 .into_dimensionality::<scirs2_core::ndarray::Ix2>()
542 .map_err(|e| {
543 NeuralError::InferenceError(format!("Failed to convert embedding to 2D: {}", e))
544 })?;
545 embedded.slice_mut(s![b, .., ..]).assign(&emb_2d);
546 }
547
548 let mut x = embedded;
550 for block in &self.blocks {
551 x = block.forward(&x)?;
552 }
553
554 let pooled = x.mean_axis(Axis(1)).ok_or_else(|| {
557 NeuralError::InferenceError("Failed to compute mean across patches".to_string())
558 })?;
559
560 let mut normed = Array2::zeros(pooled.raw_dim());
562 for b in 0..batch_size {
563 let slice = pooled.slice(s![b, ..]).to_owned().into_dyn();
564 let n = self.norm.forward(&slice)?;
565 let n_1d = n
566 .into_dimensionality::<scirs2_core::ndarray::Ix1>()
567 .map_err(|e| {
568 NeuralError::InferenceError(format!("Failed to convert normed to 1D: {}", e))
569 })?;
570 normed.slice_mut(s![b, ..]).assign(&n_1d);
571 }
572
573 let mut output = Array2::zeros((batch_size, self.config.num_classes));
575 for b in 0..batch_size {
576 let slice = normed.slice(s![b, ..]).to_owned().into_dyn();
577 let logits = self.head.forward(&slice)?;
578 if logits.ndim() == 2 && logits.shape()[0] == 1 {
580 let logits_1d = logits
581 .into_shape_with_order(scirs2_core::ndarray::IxDyn(&[self.config.num_classes]))
582 .map_err(|e| {
583 NeuralError::InferenceError(format!(
584 "Failed to reshape logits to 1D: {}",
585 e
586 ))
587 })?
588 .into_dimensionality::<scirs2_core::ndarray::Ix1>()
589 .map_err(|e| {
590 NeuralError::InferenceError(format!(
591 "Failed to convert logits to 1D: {}",
592 e
593 ))
594 })?;
595 output.slice_mut(s![b, ..]).assign(&logits_1d);
596 } else {
597 let logits_1d = logits
598 .into_dimensionality::<scirs2_core::ndarray::Ix1>()
599 .map_err(|e| {
600 NeuralError::InferenceError(format!(
601 "Failed to convert logits to 1D: {}",
602 e
603 ))
604 })?;
605 output.slice_mut(s![b, ..]).assign(&logits_1d);
606 }
607 }
608
609 Ok(output)
610 }
611
612 pub fn config(&self) -> &MLPMixerConfig {
614 &self.config
615 }
616
617 pub fn num_parameters(&self) -> usize {
619 let num_patches = self.config.num_patches();
620 let patch_dim = self.config.in_channels * self.config.patch_size * self.config.patch_size;
621 let hidden_dim = self.config.hidden_dim;
622
623 let patch_embed_params = patch_dim * hidden_dim + hidden_dim;
625
626 let token_mlp_params = (num_patches * self.config.token_mlp_dim
628 + self.config.token_mlp_dim)
629 + (self.config.token_mlp_dim * num_patches + num_patches);
630 let channel_mlp_params = (hidden_dim * self.config.channel_mlp_dim
631 + self.config.channel_mlp_dim)
632 + (self.config.channel_mlp_dim * hidden_dim + hidden_dim);
633 let norm_params = 2 * hidden_dim; let block_params = 2 * norm_params + token_mlp_params + channel_mlp_params;
635 let all_blocks_params = self.config.num_blocks * block_params;
636
637 let head_params = hidden_dim * self.config.num_classes + self.config.num_classes;
639
640 let final_norm_params = 2 * hidden_dim;
642
643 patch_embed_params + all_blocks_params + head_params + final_norm_params
644 }
645}
646
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array4;
    use scirs2_core::random::rngs::SmallRng;
    use scirs2_core::random::SeedableRng;

    /// The default configuration describes a 224-pixel image with 16-pixel
    /// patches, i.e. a 14 x 14 grid.
    #[test]
    fn test_mlp_mixer_config_default() {
        let config = MLPMixerConfig::default();
        assert_eq!(config.image_size, 224);
        assert_eq!(config.patch_size, 16);
        assert_eq!(config.num_patches(), 196);
    }

    /// Every preset reports its own patch size, width and depth.
    #[test]
    fn test_mlp_mixer_config_variants() {
        let s32 = MLPMixerConfig::mixer_s_32(10);
        assert_eq!(s32.patch_size, 32);
        assert_eq!(s32.hidden_dim, 512);
        assert_eq!(s32.num_classes, 10);
        assert_eq!(s32.num_patches(), 49);

        let s16 = MLPMixerConfig::mixer_s_16(10);
        assert_eq!(s16.patch_size, 16);
        assert_eq!(s16.hidden_dim, 512);
        assert_eq!(s16.num_blocks, 8);
        assert_eq!(s16.num_patches(), 196);

        let b32 = MLPMixerConfig::mixer_b_32(100);
        assert_eq!(b32.patch_size, 32);
        assert_eq!(b32.hidden_dim, 768);
        assert_eq!(b32.num_blocks, 12);
        assert_eq!(b32.num_patches(), 49);

        let b16 = MLPMixerConfig::mixer_b_16(100);
        assert_eq!(b16.patch_size, 16);
        assert_eq!(b16.hidden_dim, 768);
        assert_eq!(b16.num_blocks, 12);
        assert_eq!(b16.num_classes, 100);
    }

    /// The mixing MLP keeps the leading axis and maps the feature axis to
    /// `out_features`.
    #[test]
    fn test_mixer_mlp() {
        let mut rng = SmallRng::seed_from_u64(42);
        let mlp = MixerMLP::<f32>::new(64, 128, 64, 0.0, &mut rng).expect("Operation failed");

        let input = Array2::<f32>::zeros((10, 64)).into_dyn();
        let output = mlp.forward(&input).expect("Operation failed");

        assert_eq!(output.shape(), &[10, 64]);
    }

    /// A mixer block returns an array with the shape of its input.
    #[test]
    fn test_mixer_block() {
        let mut rng = SmallRng::seed_from_u64(42);
        let block =
            MixerBlock::<f32>::new(16, 64, 32, 128, 0.0, &mut rng).expect("Operation failed");

        let input = Array3::<f32>::zeros((2, 16, 64));
        let output = block.forward(&input).expect("Operation failed");

        assert_eq!(output.shape(), input.shape());
    }

    /// End-to-end forward pass of a tiny model yields `[batch, num_classes]`.
    #[test]
    fn test_mlp_mixer_small() {
        let mut rng = SmallRng::seed_from_u64(42);

        let config = MLPMixerConfig {
            image_size: 32,
            patch_size: 8,
            num_classes: 10,
            hidden_dim: 32,
            num_blocks: 2,
            token_mlp_dim: 16,
            channel_mlp_dim: 64,
            dropout_rate: 0.0,
            in_channels: 3,
        };

        let mixer = MLPMixer::<f32>::new(config, &mut rng).expect("Operation failed");
        assert_eq!(mixer.config().num_patches(), 16);

        let images = Array4::<f32>::zeros((2, 3, 32, 32)).into_dyn();
        let output = mixer.forward(&images).expect("Operation failed");

        assert_eq!(output.shape(), &[2, 10]);
    }

    /// Patches are numbered row by row and flattened row-major inside each
    /// patch; pixel `(h, w)` of the 8 x 8 test image holds `h * 8 + w`.
    #[test]
    fn test_extract_patches() {
        let mut rng = SmallRng::seed_from_u64(42);

        let config = MLPMixerConfig {
            image_size: 8,
            patch_size: 4,
            num_classes: 2,
            hidden_dim: 16,
            num_blocks: 1,
            token_mlp_dim: 8,
            channel_mlp_dim: 32,
            dropout_rate: 0.0,
            in_channels: 1,
        };

        let mixer = MLPMixer::<f32>::new(config, &mut rng).expect("Operation failed");

        let mut images = Array4::<f32>::zeros((1, 1, 8, 8));
        for h in 0..8 {
            for w in 0..8 {
                images[[0, 0, h, w]] = (h * 8 + w) as f32;
            }
        }

        let patches = mixer
            .extract_patches(&images.into_dyn())
            .expect("Operation failed");

        assert_eq!(patches.shape(), &[1, 4, 16]);

        // Top-left patch: pixels (0,0), (0,1), (1,0) and (3,3).
        assert_eq!(patches[[0, 0, 0]], 0.0);
        assert_eq!(patches[[0, 0, 1]], 1.0);
        assert_eq!(patches[[0, 0, 4]], 8.0);
        assert_eq!(patches[[0, 0, 15]], 27.0);

        // First element of the remaining patches: (0,4), (4,0) and (4,4).
        assert_eq!(patches[[0, 1, 0]], 4.0);
        assert_eq!(patches[[0, 2, 0]], 32.0);
        assert_eq!(patches[[0, 3, 0]], 36.0);

        // Last element of the bottom-right patch is pixel (7,7).
        assert_eq!(patches[[0, 3, 15]], 63.0);
    }

    /// The estimate for Mixer-S/16 with 1000 classes is pinned exactly:
    /// 393_728 (patch embedding) + 8 * 2_202_564 (blocks)
    /// + 513_000 (head) + 1_024 (final norm).
    #[test]
    fn test_num_parameters() {
        let config = MLPMixerConfig::mixer_s_16(1000);
        let mut rng = SmallRng::seed_from_u64(42);
        let mixer = MLPMixer::<f32>::new(config, &mut rng).expect("Operation failed");

        assert_eq!(mixer.num_parameters(), 18_528_264);
    }
}