Skip to main content

scirs2_neural/models/architectures/
mlp_mixer.rs

1//! MLP-Mixer Architecture Implementation
2//!
3//! This module implements the MLP-Mixer architecture as described in:
4//! "MLP-Mixer: An all-MLP Architecture for Vision" (Tolstikhin et al., 2021)
5//!
//! MLP-Mixer is an architecture based purely on multi-layer perceptrons (MLPs).
//! It contains two types of mixing layers:
8//! - Token-mixing MLPs: Allow communication between different spatial locations
9//! - Channel-mixing MLPs: Allow communication between different channels/features
10//!
11//! # Architecture Overview
12//!
13//! 1. **Patch Embedding**: Image is split into patches, each linearly projected
14//! 2. **Mixer Layers**: Alternating token-mixing and channel-mixing MLPs
//! 3. **Classification Head**: Global average pooling over patches, layer normalization, and a linear classifier
16//!
17//! # Examples
18//!
19//! ```rust
20//! use scirs2_neural::models::architectures::{MLPMixer, MLPMixerConfig};
21//! use scirs2_core::random::SeedableRng;
22//!
23//! let config = MLPMixerConfig {
24//!     image_size: 224,
25//!     patch_size: 16,
26//!     num_classes: 1000,
27//!     hidden_dim: 512,
28//!     num_blocks: 8,
29//!     token_mlp_dim: 256,
30//!     channel_mlp_dim: 2048,
31//!     dropout_rate: 0.0,
32//!     in_channels: 3,
33//! };
34//!
35//! let mut rng = scirs2_core::random::rngs::SmallRng::seed_from_u64(42);
36//! let mixer = MLPMixer::<f32>::new(config, &mut rng).expect("Operation failed");
37//! ```
38
39use crate::error::{NeuralError, Result};
40use crate::layers::{Dense, Dropout, Layer, LayerNorm};
41use scirs2_core::ndarray::{s, Array, Array2, Array3, Axis, IxDyn, ScalarOperand};
42use scirs2_core::numeric::{Float, FromPrimitive, NumAssign};
43use scirs2_core::random::{Rng, RngExt};
44use serde::{Deserialize, Serialize};
45use std::fmt::Debug;
46
47/// Configuration for the MLP-Mixer model
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct MLPMixerConfig {
50    /// Input image size (assumes square images)
51    pub image_size: usize,
52    /// Size of each patch
53    pub patch_size: usize,
54    /// Number of output classes
55    pub num_classes: usize,
56    /// Hidden dimension (channel dimension after patch embedding)
57    pub hidden_dim: usize,
58    /// Number of Mixer blocks
59    pub num_blocks: usize,
60    /// Dimension of the token-mixing MLP
61    pub token_mlp_dim: usize,
62    /// Dimension of the channel-mixing MLP
63    pub channel_mlp_dim: usize,
64    /// Dropout rate
65    pub dropout_rate: f64,
66    /// Number of input channels (3 for RGB images)
67    pub in_channels: usize,
68}
69
70impl Default for MLPMixerConfig {
71    fn default() -> Self {
72        Self {
73            image_size: 224,
74            patch_size: 16,
75            num_classes: 1000,
76            hidden_dim: 512,
77            num_blocks: 8,
78            token_mlp_dim: 256,
79            channel_mlp_dim: 2048,
80            dropout_rate: 0.0,
81            in_channels: 3,
82        }
83    }
84}
85
86impl MLPMixerConfig {
87    /// Create a Mixer-S/32 configuration
88    pub fn mixer_s_32(num_classes: usize) -> Self {
89        Self {
90            image_size: 224,
91            patch_size: 32,
92            num_classes,
93            hidden_dim: 512,
94            num_blocks: 8,
95            token_mlp_dim: 256,
96            channel_mlp_dim: 2048,
97            dropout_rate: 0.0,
98            in_channels: 3,
99        }
100    }
101
102    /// Create a Mixer-S/16 configuration
103    pub fn mixer_s_16(num_classes: usize) -> Self {
104        Self {
105            image_size: 224,
106            patch_size: 16,
107            num_classes,
108            hidden_dim: 512,
109            num_blocks: 8,
110            token_mlp_dim: 256,
111            channel_mlp_dim: 2048,
112            dropout_rate: 0.0,
113            in_channels: 3,
114        }
115    }
116
117    /// Create a Mixer-B/32 configuration
118    pub fn mixer_b_32(num_classes: usize) -> Self {
119        Self {
120            image_size: 224,
121            patch_size: 32,
122            num_classes,
123            hidden_dim: 768,
124            num_blocks: 12,
125            token_mlp_dim: 384,
126            channel_mlp_dim: 3072,
127            dropout_rate: 0.0,
128            in_channels: 3,
129        }
130    }
131
132    /// Create a Mixer-B/16 configuration
133    pub fn mixer_b_16(num_classes: usize) -> Self {
134        Self {
135            image_size: 224,
136            patch_size: 16,
137            num_classes,
138            hidden_dim: 768,
139            num_blocks: 12,
140            token_mlp_dim: 384,
141            channel_mlp_dim: 3072,
142            dropout_rate: 0.0,
143            in_channels: 3,
144        }
145    }
146
147    /// Get the number of patches
148    pub fn num_patches(&self) -> usize {
149        (self.image_size / self.patch_size).pow(2)
150    }
151}
152
153/// A simple MLP block with GELU activation
154///
155/// This is the building block for both token-mixing and channel-mixing operations.
156/// Structure: Linear -> GELU -> Dropout -> Linear -> Dropout
157#[derive(Debug, Clone)]
158pub struct MixerMLP<
159    F: Float
160        + Debug
161        + ScalarOperand
162        + Send
163        + Sync
164        + NumAssign
165        + scirs2_core::simd_ops::SimdUnifiedOps
166        + 'static,
167> {
168    /// First linear layer
169    fc1: Dense<F>,
170    /// Second linear layer
171    fc2: Dense<F>,
172    /// Dropout layer
173    dropout: Dropout<F>,
174}
175
176impl<
177        F: Float
178            + Debug
179            + ScalarOperand
180            + Send
181            + Sync
182            + NumAssign
183            + scirs2_core::simd_ops::SimdUnifiedOps
184            + 'static,
185    > MixerMLP<F>
186{
187    /// Create a new MixerMLP
188    ///
189    /// # Arguments
190    /// * `in_features` - Input dimension
191    /// * `hidden_features` - Hidden dimension
192    /// * `out_features` - Output dimension
193    /// * `dropout_rate` - Dropout probability
194    /// * `rng` - Random number generator
195    pub fn new<R: Rng + Clone + Send + Sync + 'static>(
196        in_features: usize,
197        hidden_features: usize,
198        out_features: usize,
199        dropout_rate: f64,
200        rng: &mut R,
201    ) -> Result<Self> {
202        let fc1 = Dense::new(in_features, hidden_features, Some("gelu"), rng)?;
203        let fc2 = Dense::new(hidden_features, out_features, None, rng)?;
204        let dropout = Dropout::new(dropout_rate, rng)?;
205
206        Ok(Self { fc1, fc2, dropout })
207    }
208
209    /// Forward pass through the MLP
210    pub fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
211        let x = self.fc1.forward(input)?;
212        let x = self.dropout.forward(&x)?;
213        let x = self.fc2.forward(&x)?;
214        self.dropout.forward(&x)
215    }
216}
217
218/// A single Mixer block containing token-mixing and channel-mixing
219///
220/// Each block consists of:
221/// 1. Layer normalization
222/// 2. Token-mixing MLP (across spatial dimension)
223/// 3. Skip connection
224/// 4. Layer normalization
225/// 5. Channel-mixing MLP (across channel dimension)
226/// 6. Skip connection
227#[derive(Debug, Clone)]
228pub struct MixerBlock<
229    F: Float
230        + Debug
231        + ScalarOperand
232        + Send
233        + Sync
234        + NumAssign
235        + scirs2_core::simd_ops::SimdUnifiedOps
236        + 'static,
237> {
238    /// Layer norm before token-mixing
239    norm1: LayerNorm<F>,
240    /// Token-mixing MLP
241    token_mixing: MixerMLP<F>,
242    /// Layer norm before channel-mixing
243    norm2: LayerNorm<F>,
244    /// Channel-mixing MLP
245    channel_mixing: MixerMLP<F>,
246    /// Number of patches (tokens)
247    num_patches: usize,
248    /// Hidden dimension
249    hidden_dim: usize,
250}
251
252impl<
253        F: Float
254            + Debug
255            + ScalarOperand
256            + Send
257            + Sync
258            + NumAssign
259            + scirs2_core::simd_ops::SimdUnifiedOps
260            + 'static,
261    > MixerBlock<F>
262{
263    /// Create a new MixerBlock
264    ///
265    /// # Arguments
266    /// * `num_patches` - Number of patches (spatial tokens)
267    /// * `hidden_dim` - Hidden/channel dimension
268    /// * `token_mlp_dim` - Token-mixing MLP hidden dimension
269    /// * `channel_mlp_dim` - Channel-mixing MLP hidden dimension
270    /// * `dropout_rate` - Dropout probability
271    /// * `rng` - Random number generator
272    pub fn new<R: Rng + Clone + Send + Sync + 'static>(
273        num_patches: usize,
274        hidden_dim: usize,
275        token_mlp_dim: usize,
276        channel_mlp_dim: usize,
277        dropout_rate: f64,
278        rng: &mut R,
279    ) -> Result<Self> {
280        let norm1 = LayerNorm::new(hidden_dim, 1e-6, rng)?;
281        let token_mixing =
282            MixerMLP::new(num_patches, token_mlp_dim, num_patches, dropout_rate, rng)?;
283        let norm2 = LayerNorm::new(hidden_dim, 1e-6, rng)?;
284        let channel_mixing =
285            MixerMLP::new(hidden_dim, channel_mlp_dim, hidden_dim, dropout_rate, rng)?;
286
287        Ok(Self {
288            norm1,
289            token_mixing,
290            norm2,
291            channel_mixing,
292            num_patches,
293            hidden_dim,
294        })
295    }
296
297    /// Forward pass through the Mixer block
298    ///
299    /// # Arguments
300    /// * `input` - Input tensor of shape [batch_size, num_patches, hidden_dim]
301    pub fn forward(&self, input: &Array3<F>) -> Result<Array3<F>> {
302        let batch_size = input.shape()[0];
303
304        // Token-mixing: transpose, apply MLP, transpose back
305        // Input: [B, S, C] -> [B, C, S] -> MLP -> [B, C, S] -> [B, S, C]
306
307        // First, apply layer norm along the last axis
308        let normed1 = self.apply_layer_norm(&self.norm1, input)?;
309
310        // Transpose for token-mixing: [B, S, C] -> [B, C, S]
311        let transposed = normed1.permuted_axes([0, 2, 1]);
312
313        // Apply token-mixing MLP
314        let mut token_mixed = Array3::zeros(transposed.raw_dim());
315        for b in 0..batch_size {
316            let slice = transposed.slice(s![b, .., ..]).to_owned().into_dyn();
317            let mixed = self.token_mixing.forward(&slice)?;
318            let mixed_2d = mixed
319                .into_dimensionality::<scirs2_core::ndarray::Ix2>()
320                .map_err(|e| {
321                    NeuralError::InferenceError(format!("Failed to convert mixed to 2D: {}", e))
322                })?;
323            token_mixed.slice_mut(s![b, .., ..]).assign(&mixed_2d);
324        }
325
326        // Transpose back: [B, C, S] -> [B, S, C]
327        let token_mixed = token_mixed.permuted_axes([0, 2, 1]);
328
329        // Skip connection
330        let x = input + &token_mixed;
331
332        // Channel-mixing
333        let normed2 = self.apply_layer_norm(&self.norm2, &x)?;
334
335        // Apply channel-mixing MLP (operates on last dimension)
336        let mut channel_mixed = Array3::zeros(normed2.raw_dim());
337        for b in 0..batch_size {
338            let slice = normed2.slice(s![b, .., ..]).to_owned().into_dyn();
339            let mixed = self.channel_mixing.forward(&slice)?;
340            let mixed_2d = mixed
341                .into_dimensionality::<scirs2_core::ndarray::Ix2>()
342                .map_err(|e| {
343                    NeuralError::InferenceError(format!("Failed to convert mixed to 2D: {}", e))
344                })?;
345            channel_mixed.slice_mut(s![b, .., ..]).assign(&mixed_2d);
346        }
347
348        // Skip connection
349        Ok(&x + &channel_mixed)
350    }
351
352    /// Apply layer norm to a 3D tensor
353    fn apply_layer_norm(&self, norm: &LayerNorm<F>, input: &Array3<F>) -> Result<Array3<F>> {
354        let batch_size = input.shape()[0];
355        let seq_len = input.shape()[1];
356        let hidden_dim = input.shape()[2];
357
358        let mut output = Array3::zeros(input.raw_dim());
359
360        for b in 0..batch_size {
361            for s in 0..seq_len {
362                let slice = input.slice(s![b, s, ..]).to_owned().into_dyn();
363                let normed = norm.forward(&slice)?;
364                let normed_1d = normed
365                    .into_dimensionality::<scirs2_core::ndarray::Ix1>()
366                    .map_err(|e| {
367                        NeuralError::InferenceError(format!(
368                            "Failed to convert normed to 1D: {}",
369                            e
370                        ))
371                    })?;
372                output.slice_mut(s![b, s, ..]).assign(&normed_1d);
373            }
374        }
375
376        Ok(output)
377    }
378}
379
380/// MLP-Mixer model for image classification
381///
382/// The model consists of:
383/// 1. Patch embedding layer
384/// 2. Multiple Mixer blocks
385/// 3. Classification head
386#[derive(Debug)]
387pub struct MLPMixer<
388    F: Float
389        + Debug
390        + ScalarOperand
391        + Send
392        + Sync
393        + NumAssign
394        + scirs2_core::simd_ops::SimdUnifiedOps
395        + 'static,
396> {
397    /// Model configuration
398    config: MLPMixerConfig,
399    /// Patch embedding projection
400    patch_embed: Dense<F>,
401    /// Mixer blocks
402    blocks: Vec<MixerBlock<F>>,
403    /// Final layer norm
404    norm: LayerNorm<F>,
405    /// Classification head
406    head: Dense<F>,
407}
408
409impl<
410        F: Float
411            + Debug
412            + ScalarOperand
413            + Send
414            + Sync
415            + NumAssign
416            + FromPrimitive
417            + scirs2_core::simd_ops::SimdUnifiedOps
418            + 'static,
419    > MLPMixer<F>
420{
421    /// Create a new MLPMixer model
422    ///
423    /// # Arguments
424    /// * `config` - Model configuration
425    /// * `rng` - Random number generator
426    pub fn new<R: Rng + Clone + Send + Sync + 'static>(
427        config: MLPMixerConfig,
428        rng: &mut R,
429    ) -> Result<Self> {
430        let num_patches = config.num_patches();
431        let patch_dim = config.in_channels * config.patch_size * config.patch_size;
432
433        // Patch embedding: flatten patch -> hidden_dim
434        let patch_embed = Dense::new(patch_dim, config.hidden_dim, None, rng)?;
435
436        // Create Mixer blocks
437        let mut blocks = Vec::with_capacity(config.num_blocks);
438        for _ in 0..config.num_blocks {
439            blocks.push(MixerBlock::new(
440                num_patches,
441                config.hidden_dim,
442                config.token_mlp_dim,
443                config.channel_mlp_dim,
444                config.dropout_rate,
445                rng,
446            )?);
447        }
448
449        // Final layer norm
450        let norm = LayerNorm::new(config.hidden_dim, 1e-6, rng)?;
451
452        // Classification head
453        let head = Dense::new(config.hidden_dim, config.num_classes, None, rng)?;
454
455        Ok(Self {
456            config,
457            patch_embed,
458            blocks,
459            norm,
460            head,
461        })
462    }
463
464    /// Extract patches from an image batch
465    ///
466    /// # Arguments
467    /// * `images` - Image batch of shape [B, C, H, W]
468    ///
469    /// # Returns
470    /// Patches of shape [B, num_patches, patch_dim]
471    fn extract_patches(&self, images: &Array<F, IxDyn>) -> Result<Array3<F>> {
472        let shape = images.shape();
473        if shape.len() != 4 {
474            return Err(NeuralError::InvalidArchitecture(format!(
475                "Expected 4D input [B, C, H, W], got {:?}",
476                shape
477            )));
478        }
479
480        let batch_size = shape[0];
481        let channels = shape[1];
482        let height = shape[2];
483        let width = shape[3];
484
485        let patch_size = self.config.patch_size;
486        let patches_h = height / patch_size;
487        let patches_w = width / patch_size;
488        let num_patches = patches_h * patches_w;
489        let patch_dim = channels * patch_size * patch_size;
490
491        let mut patches = Array3::zeros((batch_size, num_patches, patch_dim));
492
493        for b in 0..batch_size {
494            for ph in 0..patches_h {
495                for pw in 0..patches_w {
496                    let patch_idx = ph * patches_w + pw;
497                    let h_start = ph * patch_size;
498                    let w_start = pw * patch_size;
499
500                    // Extract and flatten the patch
501                    let mut flat_idx = 0;
502                    for c in 0..channels {
503                        for h in 0..patch_size {
504                            for w in 0..patch_size {
505                                patches[[b, patch_idx, flat_idx]] =
506                                    images[[b, c, h_start + h, w_start + w]];
507                                flat_idx += 1;
508                            }
509                        }
510                    }
511                }
512            }
513        }
514
515        Ok(patches)
516    }
517
518    /// Forward pass through the model
519    ///
520    /// # Arguments
521    /// * `images` - Image batch of shape [B, C, H, W]
522    ///
523    /// # Returns
524    /// Logits of shape [B, num_classes]
525    pub fn forward(&self, images: &Array<F, IxDyn>) -> Result<Array2<F>> {
526        let batch_size = images.shape()[0];
527
528        // Extract patches: [B, num_patches, patch_dim]
529        let patches = self.extract_patches(images)?;
530
531        // Patch embedding: [B, num_patches, hidden_dim]
532        let mut embedded = Array3::zeros((
533            batch_size,
534            self.config.num_patches(),
535            self.config.hidden_dim,
536        ));
537        for b in 0..batch_size {
538            let patch_slice = patches.slice(s![b, .., ..]).to_owned().into_dyn();
539            let emb = self.patch_embed.forward(&patch_slice)?;
540            let emb_2d = emb
541                .into_dimensionality::<scirs2_core::ndarray::Ix2>()
542                .map_err(|e| {
543                    NeuralError::InferenceError(format!("Failed to convert embedding to 2D: {}", e))
544                })?;
545            embedded.slice_mut(s![b, .., ..]).assign(&emb_2d);
546        }
547
548        // Apply Mixer blocks
549        let mut x = embedded;
550        for block in &self.blocks {
551            x = block.forward(&x)?;
552        }
553
554        // Global average pooling over spatial dimension
555        // [B, num_patches, hidden_dim] -> [B, hidden_dim]
556        let pooled = x.mean_axis(Axis(1)).ok_or_else(|| {
557            NeuralError::InferenceError("Failed to compute mean across patches".to_string())
558        })?;
559
560        // Apply final layer norm
561        let mut normed = Array2::zeros(pooled.raw_dim());
562        for b in 0..batch_size {
563            let slice = pooled.slice(s![b, ..]).to_owned().into_dyn();
564            let n = self.norm.forward(&slice)?;
565            let n_1d = n
566                .into_dimensionality::<scirs2_core::ndarray::Ix1>()
567                .map_err(|e| {
568                    NeuralError::InferenceError(format!("Failed to convert normed to 1D: {}", e))
569                })?;
570            normed.slice_mut(s![b, ..]).assign(&n_1d);
571        }
572
573        // Classification head
574        let mut output = Array2::zeros((batch_size, self.config.num_classes));
575        for b in 0..batch_size {
576            let slice = normed.slice(s![b, ..]).to_owned().into_dyn();
577            let logits = self.head.forward(&slice)?;
578            // Dense layer may return [1, num_classes] for 1D input, so handle both shapes
579            if logits.ndim() == 2 && logits.shape()[0] == 1 {
580                let logits_1d = logits
581                    .into_shape_with_order(scirs2_core::ndarray::IxDyn(&[self.config.num_classes]))
582                    .map_err(|e| {
583                        NeuralError::InferenceError(format!(
584                            "Failed to reshape logits to 1D: {}",
585                            e
586                        ))
587                    })?
588                    .into_dimensionality::<scirs2_core::ndarray::Ix1>()
589                    .map_err(|e| {
590                        NeuralError::InferenceError(format!(
591                            "Failed to convert logits to 1D: {}",
592                            e
593                        ))
594                    })?;
595                output.slice_mut(s![b, ..]).assign(&logits_1d);
596            } else {
597                let logits_1d = logits
598                    .into_dimensionality::<scirs2_core::ndarray::Ix1>()
599                    .map_err(|e| {
600                        NeuralError::InferenceError(format!(
601                            "Failed to convert logits to 1D: {}",
602                            e
603                        ))
604                    })?;
605                output.slice_mut(s![b, ..]).assign(&logits_1d);
606            }
607        }
608
609        Ok(output)
610    }
611
612    /// Get the configuration
613    pub fn config(&self) -> &MLPMixerConfig {
614        &self.config
615    }
616
617    /// Get the number of parameters (approximate)
618    pub fn num_parameters(&self) -> usize {
619        let num_patches = self.config.num_patches();
620        let patch_dim = self.config.in_channels * self.config.patch_size * self.config.patch_size;
621        let hidden_dim = self.config.hidden_dim;
622
623        // Patch embedding
624        let patch_embed_params = patch_dim * hidden_dim + hidden_dim;
625
626        // Mixer blocks
627        let token_mlp_params = (num_patches * self.config.token_mlp_dim
628            + self.config.token_mlp_dim)
629            + (self.config.token_mlp_dim * num_patches + num_patches);
630        let channel_mlp_params = (hidden_dim * self.config.channel_mlp_dim
631            + self.config.channel_mlp_dim)
632            + (self.config.channel_mlp_dim * hidden_dim + hidden_dim);
633        let norm_params = 2 * hidden_dim; // gamma and beta
634        let block_params = 2 * norm_params + token_mlp_params + channel_mlp_params;
635        let all_blocks_params = self.config.num_blocks * block_params;
636
637        // Head
638        let head_params = hidden_dim * self.config.num_classes + self.config.num_classes;
639
640        // Final norm
641        let final_norm_params = 2 * hidden_dim;
642
643        patch_embed_params + all_blocks_params + head_params + final_norm_params
644    }
645}
646
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array4;
    use scirs2_core::random::rngs::SmallRng;
    use scirs2_core::random::SeedableRng;

    #[test]
    fn test_mlp_mixer_config_default() {
        let config = MLPMixerConfig::default();
        assert_eq!(config.image_size, 224);
        assert_eq!(config.patch_size, 16);
        assert_eq!(config.num_patches(), 196); // 14 * 14
    }

    #[test]
    fn test_mlp_mixer_config_variants() {
        let s32 = MLPMixerConfig::mixer_s_32(10);
        assert_eq!(s32.patch_size, 32);
        assert_eq!(s32.hidden_dim, 512);
        assert_eq!(s32.num_patches(), 49); // 7 * 7

        let b16 = MLPMixerConfig::mixer_b_16(100);
        assert_eq!(b16.patch_size, 16);
        assert_eq!(b16.hidden_dim, 768);
        assert_eq!(b16.num_blocks, 12);
    }

    #[test]
    fn test_mixer_mlp() {
        let mut rng = SmallRng::seed_from_u64(42);
        let mlp = MixerMLP::<f32>::new(64, 128, 64, 0.0, &mut rng).expect("Operation failed");

        let input = Array2::<f32>::zeros((10, 64)).into_dyn();
        let output = mlp.forward(&input).expect("Operation failed");

        assert_eq!(output.shape(), &[10, 64]);
    }

    #[test]
    fn test_mixer_block() {
        let mut rng = SmallRng::seed_from_u64(42);
        let block = MixerBlock::<f32>::new(
            16,  // num_patches
            64,  // hidden_dim
            32,  // token_mlp_dim
            128, // channel_mlp_dim
            0.0, // dropout
            &mut rng,
        )
        .expect("Operation failed");

        let input = Array3::<f32>::zeros((2, 16, 64));
        let output = block.forward(&input).expect("Operation failed");

        assert_eq!(output.shape(), input.shape());
    }

    #[test]
    fn test_mlp_mixer_small() {
        let mut rng = SmallRng::seed_from_u64(42);

        // Small config for testing
        let config = MLPMixerConfig {
            image_size: 32,
            patch_size: 8,
            num_classes: 10,
            hidden_dim: 32,
            num_blocks: 2,
            token_mlp_dim: 16,
            channel_mlp_dim: 64,
            dropout_rate: 0.0,
            in_channels: 3,
        };

        let mixer = MLPMixer::<f32>::new(config, &mut rng).expect("Operation failed");

        // Test forward pass
        let images = Array4::<f32>::zeros((2, 3, 32, 32)).into_dyn();
        let output = mixer.forward(&images).expect("Operation failed");

        assert_eq!(output.shape(), &[2, 10]);
    }

    #[test]
    fn test_extract_patches() {
        let mut rng = SmallRng::seed_from_u64(42);

        let config = MLPMixerConfig {
            image_size: 8,
            patch_size: 4,
            num_classes: 2,
            hidden_dim: 16,
            num_blocks: 1,
            token_mlp_dim: 8,
            channel_mlp_dim: 32,
            dropout_rate: 0.0,
            in_channels: 1,
        };

        let mixer = MLPMixer::<f32>::new(config, &mut rng).expect("Operation failed");

        // Create test image: 1 batch, 1 channel, 8x8, pixel value = h * 8 + w
        let mut images = Array4::<f32>::zeros((1, 1, 8, 8));
        for h in 0..8 {
            for w in 0..8 {
                images[[0, 0, h, w]] = (h * 8 + w) as f32;
            }
        }

        let patches = mixer
            .extract_patches(&images.into_dyn())
            .expect("Operation failed");

        // Should have 4 patches (2x2 grid of 4x4 patches)
        assert_eq!(patches.shape(), &[1, 4, 16]);

        // The top-left patch covers rows 0-3 and columns 0-3, flattened row by row:
        // [0, 1, 2, 3, 8, 9, 10, 11, 16, ..., 27].
        let expected_first: Vec<f32> = (0..4)
            .flat_map(|h| (0..4).map(move |w| (h * 8 + w) as f32))
            .collect();
        let first: Vec<f32> = (0..16).map(|i| patches[[0, 0, i]]).collect();
        assert_eq!(first, expected_first);

        // Patches are numbered row-major over the grid: top-right starts at column 4,
        // bottom-left starts at row 4, and bottom-right ends at pixel (7, 7).
        assert_eq!(patches[[0, 1, 0]], 4.0);
        assert_eq!(patches[[0, 2, 0]], 32.0);
        assert_eq!(patches[[0, 3, 15]], 63.0);
    }

    #[test]
    fn test_num_parameters() {
        let config = MLPMixerConfig::mixer_s_16(1000);
        let mut rng = SmallRng::seed_from_u64(42);
        let mixer = MLPMixer::<f32>::new(config, &mut rng).expect("Operation failed");

        // Mixer-S/16 with a 1000-class head (~18.5M parameters):
        //   patch embedding   768 * 512 + 512                          =    393_728
        //   per block         2 norms (2_048) + token MLP (100_804)
        //                     + channel MLP (2_099_712)                =  2_202_564
        //   8 blocks                                                   = 17_620_512
        //   head              512 * 1000 + 1000                        =    513_000
        //   final norm        2 * 512                                  =      1_024
        assert_eq!(mixer.num_parameters(), 18_528_264);
    }
}