Skip to main content

scirs2_neural/models/architectures/
vit.rs

1//! Vision Transformer (ViT) implementation
2//!
3//! Vision Transformer (ViT) is a transformer-based model for image classification
4//! that divides an image into fixed-size patches, linearly embeds them, adds position
5//! embeddings, and processes them using a standard Transformer encoder.
6//! Reference: "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", Dosovitskiy et al. (2020)
7//! <https://arxiv.org/abs/2010.11929>
8
9use crate::error::{NeuralError, Result};
10use crate::layers::{Dense, Dropout, Layer, LayerNorm, MultiHeadAttention, PatchEmbedding};
11use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
12use scirs2_core::numeric::{Float, NumAssign};
13use scirs2_core::random::SeedableRng;
14use scirs2_core::simd_ops::SimdUnifiedOps;
15use serde::{Deserialize, Serialize};
16use std::fmt::Debug;
17
18/// Configuration for a Vision Transformer model
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ViTConfig {
21    /// Image size (height, width)
22    pub image_size: (usize, usize),
23    /// Patch size (height, width)
24    pub patch_size: (usize, usize),
25    /// Number of input channels (e.g., 3 for RGB)
26    pub in_channels: usize,
27    /// Number of output classes
28    pub num_classes: usize,
29    /// Embedding dimension
30    pub embed_dim: usize,
31    /// Number of transformer layers
32    pub num_layers: usize,
33    /// Number of attention heads
34    pub num_heads: usize,
35    /// MLP hidden dimension
36    pub mlp_dim: usize,
37    /// Dropout rate
38    pub dropout_rate: f64,
39    /// Attention dropout rate
40    pub attention_dropout_rate: f64,
41}
42
43impl ViTConfig {
44    /// Create a ViT-Base configuration
45    pub fn vit_base(
46        image_size: (usize, usize),
47        patch_size: (usize, usize),
48        in_channels: usize,
49        num_classes: usize,
50    ) -> Self {
51        Self {
52            image_size,
53            patch_size,
54            in_channels,
55            num_classes,
56            embed_dim: 768,
57            num_layers: 12,
58            num_heads: 12,
59            mlp_dim: 3072,
60            dropout_rate: 0.1,
61            attention_dropout_rate: 0.0,
62        }
63    }
64
65    /// Create a ViT-Large configuration
66    pub fn vit_large(
67        image_size: (usize, usize),
68        patch_size: (usize, usize),
69        in_channels: usize,
70        num_classes: usize,
71    ) -> Self {
72        Self {
73            image_size,
74            patch_size,
75            in_channels,
76            num_classes,
77            embed_dim: 1024,
78            num_layers: 24,
79            num_heads: 16,
80            mlp_dim: 4096,
81            dropout_rate: 0.1,
82            attention_dropout_rate: 0.0,
83        }
84    }
85
86    /// Create a ViT-Huge configuration
87    pub fn vit_huge(
88        image_size: (usize, usize),
89        patch_size: (usize, usize),
90        in_channels: usize,
91        num_classes: usize,
92    ) -> Self {
93        Self {
94            image_size,
95            patch_size,
96            in_channels,
97            num_classes,
98            embed_dim: 1280,
99            num_layers: 32,
100            num_heads: 16,
101            mlp_dim: 5120,
102            dropout_rate: 0.1,
103            attention_dropout_rate: 0.0,
104        }
105    }
106}
107
108/// MLP with GELU activation for transformer blocks
109#[derive(Clone, Debug)]
110struct TransformerMlp<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
111    dense1: Dense<F>,
112    dense2: Dense<F>,
113}
114
115impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F>
116    for TransformerMlp<F>
117{
118    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
119        let mut x = self.dense1.forward(input)?;
120        // Apply GELU activation inline
121        x = x.mapv(|v| {
122            // GELU approximation: x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
123            let x3 = v * v * v;
124            v * F::from(0.5).expect("Failed to convert constant to float")
125                * (F::one()
126                    + (v + F::from(0.044715).expect("Failed to convert constant to float") * x3)
127                        .tanh())
128        });
129        x = self.dense2.forward(&x)?;
130        Ok(x)
131    }
132
133    fn backward(
134        &self,
135        _input: &Array<F, IxDyn>,
136        grad_output: &Array<F, IxDyn>,
137    ) -> Result<Array<F, IxDyn>> {
138        Ok(grad_output.clone())
139    }
140
141    fn update(&mut self, learning_rate: F) -> Result<()> {
142        self.dense1.update(learning_rate)?;
143        self.dense2.update(learning_rate)?;
144        Ok(())
145    }
146
147    fn as_any(&self) -> &dyn std::any::Any {
148        self
149    }
150
151    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
152        self
153    }
154}
155
156/// Transformer encoder block for ViT
157struct TransformerEncoderBlock<
158    F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign,
159> {
160    /// Layer normalization 1
161    norm1: LayerNorm<F>,
162    /// Multi-head attention
163    attention: MultiHeadAttention<F>,
164    /// Layer normalization 2
165    norm2: LayerNorm<F>,
166    /// MLP layers
167    mlp: TransformerMlp<F>,
168    /// Dropout for attention
169    attn_dropout: Dropout<F>,
170    /// Dropout for MLP
171    mlp_dropout: Dropout<F>,
172}
173
174impl<
175        F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
176    > Clone for TransformerEncoderBlock<F>
177{
178    fn clone(&self) -> Self {
179        Self {
180            norm1: self.norm1.clone(),
181            attention: self.attention.clone(),
182            norm2: self.norm2.clone(),
183            mlp: self.mlp.clone(),
184            attn_dropout: self.attn_dropout.clone(),
185            mlp_dropout: self.mlp_dropout.clone(),
186        }
187    }
188}
189
190impl<
191        F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
192    > TransformerEncoderBlock<F>
193{
194    /// Create a new transformer encoder block
195    pub fn new(
196        dim: usize,
197        num_heads: usize,
198        mlp_dim: usize,
199        dropout_rate: F,
200        attention_dropout_rate: F,
201    ) -> Result<Self> {
202        // Layer normalization for attention
203        let mut ln_rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
204        let norm1 = LayerNorm::new(dim, 1e-6, &mut ln_rng)?;
205
206        // Multi-head attention
207        let attn_config = crate::layers::AttentionConfig {
208            num_heads,
209            head_dim: dim / num_heads,
210            dropout_prob: attention_dropout_rate.to_f64().expect("Operation failed"),
211            causal: false,
212            scale: None,
213        };
214        let mut attn_rng = scirs2_core::random::rngs::SmallRng::from_seed([43; 32]);
215        let attention = MultiHeadAttention::new(dim, attn_config, &mut attn_rng)?;
216
217        // Layer normalization for MLP
218        let mut ln2_rng = scirs2_core::random::rngs::SmallRng::from_seed([44; 32]);
219        let norm2 = LayerNorm::new(dim, 1e-6, &mut ln2_rng)?;
220
221        // MLP
222        let mut mlp_rng1 = scirs2_core::random::rngs::SmallRng::from_seed([45; 32]);
223        let mut mlp_rng2 = scirs2_core::random::rngs::SmallRng::from_seed([46; 32]);
224        let mlp = TransformerMlp {
225            dense1: Dense::new(dim, mlp_dim, None, &mut mlp_rng1)?,
226            dense2: Dense::new(mlp_dim, dim, None, &mut mlp_rng2)?,
227        };
228
229        // Dropouts
230        let dropout_rate_f64 = dropout_rate.to_f64().expect("Operation failed");
231        let mut dropout_rng1 = scirs2_core::random::rngs::SmallRng::from_seed([47; 32]);
232        let mut dropout_rng2 = scirs2_core::random::rngs::SmallRng::from_seed([48; 32]);
233        let attn_dropout = Dropout::new(dropout_rate_f64, &mut dropout_rng1)?;
234        let mlp_dropout = Dropout::new(dropout_rate_f64, &mut dropout_rng2)?;
235
236        Ok(Self {
237            norm1,
238            attention,
239            norm2,
240            mlp,
241            attn_dropout,
242            mlp_dropout,
243        })
244    }
245}
246
247impl<
248        F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
249    > Layer<F> for TransformerEncoderBlock<F>
250{
251    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
252        // Norm -> Attention -> Dropout -> Add
253        let norm1_out = self.norm1.forward(input)?;
254        let attn = self.attention.forward(&norm1_out)?;
255        let attn_drop = self.attn_dropout.forward(&attn)?;
256
257        // Add residual connection
258        let residual1 = input + &attn_drop;
259
260        // Norm -> MLP -> Dropout -> Add
261        let norm2_out = self.norm2.forward(&residual1)?;
262        let mlp_out = self.mlp.forward(&norm2_out)?;
263        let mlp_drop = self.mlp_dropout.forward(&mlp_out)?;
264
265        // Add residual connection
266        let residual2 = &residual1 + &mlp_drop;
267
268        Ok(residual2)
269    }
270
271    fn backward(
272        &self,
273        _input: &Array<F, IxDyn>,
274        grad_output: &Array<F, IxDyn>,
275    ) -> Result<Array<F, IxDyn>> {
276        Ok(grad_output.clone())
277    }
278
279    fn update(&mut self, learning_rate: F) -> Result<()> {
280        self.norm1.update(learning_rate)?;
281        self.attention.update(learning_rate)?;
282        self.norm2.update(learning_rate)?;
283        self.mlp.update(learning_rate)?;
284        Ok(())
285    }
286
287    fn as_any(&self) -> &dyn std::any::Any {
288        self
289    }
290
291    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
292        self
293    }
294}
295
296/// Vision Transformer (ViT) model.
297///
298/// Divides images into patches, embeds them, prepends a learnable class token,
299/// adds position embeddings, and processes through a stack of Transformer encoder blocks.
300/// The class-token output is then fed to a linear classification head.
301pub struct VisionTransformer<
302    F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign,
303> {
304    /// Patch embedding layer
305    patch_embed: PatchEmbedding<F>,
306    /// Class token embedding
307    cls_token: Array<F, IxDyn>,
308    /// Position embedding
309    pos_embed: Array<F, IxDyn>,
310    /// Dropout layer
311    dropout: Dropout<F>,
312    /// Transformer encoder blocks
313    encoder_blocks: Vec<TransformerEncoderBlock<F>>,
314    /// Layer normalization
315    norm: LayerNorm<F>,
316    /// Final classification head
317    classifier: Dense<F>,
318    /// Model configuration
319    config: ViTConfig,
320}
321
322impl<
323        F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
324    > std::fmt::Debug for VisionTransformer<F>
325{
326    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
327        f.debug_struct("VisionTransformer")
328            .field("patch_embed", &self.patch_embed)
329            .field("cls_token", &self.cls_token)
330            .field("pos_embed", &self.pos_embed)
331            .field("dropout", &self.dropout)
332            .field(
333                "encoder_blocks",
334                &format!("<{} blocks>", self.encoder_blocks.len()),
335            )
336            .field("norm", &self.norm)
337            .field("classifier", &self.classifier)
338            .field("config", &self.config)
339            .finish()
340    }
341}
342
343impl<
344        F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
345    > Clone for VisionTransformer<F>
346{
347    fn clone(&self) -> Self {
348        Self {
349            patch_embed: self.patch_embed.clone(),
350            cls_token: self.cls_token.clone(),
351            pos_embed: self.pos_embed.clone(),
352            dropout: self.dropout.clone(),
353            encoder_blocks: self.encoder_blocks.clone(),
354            norm: self.norm.clone(),
355            classifier: self.classifier.clone(),
356            config: self.config.clone(),
357        }
358    }
359}
360
361impl<
362        F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
363    > VisionTransformer<F>
364{
365    /// Create a new Vision Transformer model
366    pub fn new(config: ViTConfig) -> Result<Self> {
367        // Calculate number of patches
368        let h_patches = config.image_size.0 / config.patch_size.0;
369        let w_patches = config.image_size.1 / config.patch_size.1;
370        let num_patches = h_patches * w_patches;
371
372        // Create patch embedding layer
373        let mut pe_rng = scirs2_core::random::rngs::SmallRng::from_seed([48; 32]);
374        let patch_embed = PatchEmbedding::new(
375            config.image_size,
376            config.patch_size,
377            config.in_channels,
378            config.embed_dim,
379            true,
380            &mut pe_rng,
381        )?;
382
383        // Create class token
384        let cls_token = Array::zeros(IxDyn(&[1, 1, config.embed_dim]));
385
386        // Create position embedding (include class token)
387        let pos_embed = Array::zeros(IxDyn(&[1, num_patches + 1, config.embed_dim]));
388
389        // Create dropout
390        let mut dropout_rng = scirs2_core::random::rngs::SmallRng::from_seed([49; 32]);
391        let dropout = Dropout::new(config.dropout_rate, &mut dropout_rng)?;
392
393        // Create transformer encoder blocks
394        let mut encoder_blocks = Vec::with_capacity(config.num_layers);
395        for i in 0..config.num_layers {
396            let block = TransformerEncoderBlock::new(
397                config.embed_dim,
398                config.num_heads,
399                config.mlp_dim,
400                F::from(config.dropout_rate).ok_or_else(|| {
401                    NeuralError::InvalidArchitecture(
402                        "Failed to convert dropout_rate to float".to_string(),
403                    )
404                })?,
405                F::from(config.attention_dropout_rate).ok_or_else(|| {
406                    NeuralError::InvalidArchitecture(
407                        "Failed to convert attention_dropout_rate to float".to_string(),
408                    )
409                })?,
410            )?;
411            encoder_blocks.push(block);
412            let _ = i;
413        }
414
415        // Layer normalization
416        let mut norm_rng = scirs2_core::random::rngs::SmallRng::from_seed([50; 32]);
417        let norm = LayerNorm::new(config.embed_dim, 1e-6, &mut norm_rng)?;
418
419        // Classification head
420        let mut classifier_rng = scirs2_core::random::rngs::SmallRng::from_seed([51; 32]);
421        let classifier = Dense::new(
422            config.embed_dim,
423            config.num_classes,
424            None,
425            &mut classifier_rng,
426        )?;
427
428        Ok(Self {
429            patch_embed,
430            cls_token,
431            pos_embed,
432            dropout,
433            encoder_blocks,
434            norm,
435            classifier,
436            config,
437        })
438    }
439
440    /// Create a ViT-Base model
441    pub fn vit_base(
442        image_size: (usize, usize),
443        patch_size: (usize, usize),
444        in_channels: usize,
445        num_classes: usize,
446    ) -> Result<Self> {
447        let config = ViTConfig::vit_base(image_size, patch_size, in_channels, num_classes);
448        Self::new(config)
449    }
450
451    /// Create a ViT-Large model
452    pub fn vit_large(
453        image_size: (usize, usize),
454        patch_size: (usize, usize),
455        in_channels: usize,
456        num_classes: usize,
457    ) -> Result<Self> {
458        let config = ViTConfig::vit_large(image_size, patch_size, in_channels, num_classes);
459        Self::new(config)
460    }
461
462    /// Create a ViT-Huge model
463    pub fn vit_huge(
464        image_size: (usize, usize),
465        patch_size: (usize, usize),
466        in_channels: usize,
467        num_classes: usize,
468    ) -> Result<Self> {
469        let config = ViTConfig::vit_huge(image_size, patch_size, in_channels, num_classes);
470        Self::new(config)
471    }
472
473    /// Get the model configuration
474    pub fn config(&self) -> &ViTConfig {
475        &self.config
476    }
477}
478
479impl<
480        F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
481    > Layer<F> for VisionTransformer<F>
482{
483    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
484        let shape = input.shape();
485        if shape.len() != 4
486            || shape[1] != self.config.in_channels
487            || shape[2] != self.config.image_size.0
488            || shape[3] != self.config.image_size.1
489        {
490            return Err(NeuralError::InferenceError(format!(
491                "Expected input shape [batch_size, {}, {}, {}], got {:?}",
492                self.config.in_channels, self.config.image_size.0, self.config.image_size.1, shape
493            )));
494        }
495
496        let batch_size = shape[0];
497
498        // Extract patch embeddings: [batch, num_patches, embed_dim]
499        let x = self.patch_embed.forward(input)?;
500
501        let h_patches = self.config.image_size.0 / self.config.patch_size.0;
502        let w_patches = self.config.image_size.1 / self.config.patch_size.1;
503        let num_patches = h_patches * w_patches;
504
505        // Concatenate class token + patch embeddings: [batch, num_patches+1, embed_dim]
506        let mut x_with_cls =
507            Array::zeros(IxDyn(&[batch_size, num_patches + 1, self.config.embed_dim]));
508
509        // Copy class token (broadcast from [1,1,embed_dim])
510        for b in 0..batch_size {
511            for i in 0..self.config.embed_dim {
512                x_with_cls[[b, 0, i]] = self.cls_token[[0, 0, i]];
513            }
514        }
515
516        // Copy patch embeddings
517        for b in 0..batch_size {
518            for p in 0..num_patches {
519                for i in 0..self.config.embed_dim {
520                    x_with_cls[[b, p + 1, i]] = x[[b, p, i]];
521                }
522            }
523        }
524
525        // Add position embeddings (broadcast from [1, num_patches+1, embed_dim])
526        for b in 0..batch_size {
527            for p in 0..num_patches + 1 {
528                for i in 0..self.config.embed_dim {
529                    x_with_cls[[b, p, i]] += self.pos_embed[[0, p, i]];
530                }
531            }
532        }
533
534        // Dropout → encoder blocks → layer norm
535        let mut x = self.dropout.forward(&x_with_cls)?;
536        for block in &self.encoder_blocks {
537            x = block.forward(&x)?;
538        }
539        x = self.norm.forward(&x)?;
540
541        // Extract class token: [batch, embed_dim]
542        let mut cls_token_final = Array::zeros(IxDyn(&[batch_size, self.config.embed_dim]));
543        for b in 0..batch_size {
544            for i in 0..self.config.embed_dim {
545                cls_token_final[[b, i]] = x[[b, 0, i]];
546            }
547        }
548
549        // Classification head
550        self.classifier.forward(&cls_token_final)
551    }
552
553    fn backward(
554        &self,
555        _input: &Array<F, IxDyn>,
556        grad_output: &Array<F, IxDyn>,
557    ) -> Result<Array<F, IxDyn>> {
558        Ok(grad_output.clone())
559    }
560
561    fn update(&mut self, learning_rate: F) -> Result<()> {
562        self.patch_embed.update(learning_rate)?;
563        for block in &mut self.encoder_blocks {
564            block.update(learning_rate)?;
565        }
566        self.norm.update(learning_rate)?;
567        self.classifier.update(learning_rate)?;
568        Ok(())
569    }
570
571    fn as_any(&self) -> &dyn std::any::Any {
572        self
573    }
574
575    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
576        self
577    }
578
579    fn layer_type(&self) -> &str {
580        "VisionTransformer"
581    }
582}
583
584#[cfg(test)]
585mod tests {
586    use super::*;
587
588    #[test]
589    fn test_vit_config_base() {
590        let config = ViTConfig::vit_base((224, 224), (16, 16), 3, 1000);
591        assert_eq!(config.embed_dim, 768);
592        assert_eq!(config.num_layers, 12);
593        assert_eq!(config.num_heads, 12);
594    }
595
596    #[test]
597    fn test_vit_config_large() {
598        let config = ViTConfig::vit_large((224, 224), (16, 16), 3, 1000);
599        assert_eq!(config.embed_dim, 1024);
600        assert_eq!(config.num_layers, 24);
601        assert_eq!(config.num_heads, 16);
602    }
603
604    #[test]
605    fn test_vit_config_huge() {
606        let config = ViTConfig::vit_huge((224, 224), (16, 16), 3, 1000);
607        assert_eq!(config.embed_dim, 1280);
608        assert_eq!(config.num_layers, 32);
609        assert_eq!(config.num_heads, 16);
610    }
611}