Skip to main content

scirs2_neural/models/architectures/
vit.rs

1//! Vision Transformer (ViT) implementation
2//!
3//! Vision Transformer (ViT) is a transformer-based model for image classification
4//! that divides an image into fixed-size patches, linearly embeds them, adds position
5//! embeddings, and processes them using a standard Transformer encoder.
6//! Reference: "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", Dosovitskiy et al. (2020)
7//! <https://arxiv.org/abs/2010.11929>
8
9use crate::error::{NeuralError, Result};
10use crate::layers::{Dense, Dropout, Layer, LayerNorm, MultiHeadAttention};
11// Note: PatchEmbedding not yet implemented
12use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
13use scirs2_core::numeric::{Float, NumAssign};
14use scirs2_core::random::SeedableRng;
15use scirs2_core::simd_ops::SimdUnifiedOps;
16use serde::{Deserialize, Serialize};
17use std::fmt::Debug;
18
19/// Configuration for a Vision Transformer model
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct ViTConfig {
22    /// Image size (height, width)
23    pub image_size: (usize, usize),
24    /// Patch size (height, width)
25    pub patch_size: (usize, usize),
26    /// Number of input channels (e.g., 3 for RGB)
27    pub in_channels: usize,
28    /// Number of output classes
29    pub num_classes: usize,
30    /// Embedding dimension
31    pub embed_dim: usize,
32    /// Number of transformer layers
33    pub num_layers: usize,
34    /// Number of attention heads
35    pub num_heads: usize,
36    /// MLP hidden dimension
37    pub mlp_dim: usize,
38    /// Dropout rate
39    pub dropout_rate: f64,
40    /// Attention dropout rate
41    pub attention_dropout_rate: f64,
42}
43
44impl ViTConfig {
45    /// Create a ViT-Base configuration
46    pub fn vit_base(
47        image_size: (usize, usize),
48        patch_size: (usize, usize),
49        in_channels: usize,
50        num_classes: usize,
51    ) -> Self {
52        Self {
53            image_size,
54            patch_size,
55            in_channels,
56            num_classes,
57            embed_dim: 768,
58            num_layers: 12,
59            num_heads: 12,
60            mlp_dim: 3072,
61            dropout_rate: 0.1,
62            attention_dropout_rate: 0.0,
63        }
64    }
65
66    /// Create a ViT-Large configuration
67    pub fn vit_large(
68        image_size: (usize, usize),
69        patch_size: (usize, usize),
70        in_channels: usize,
71        num_classes: usize,
72    ) -> Self {
73        Self {
74            image_size,
75            patch_size,
76            in_channels,
77            num_classes,
78            embed_dim: 1024,
79            num_layers: 24,
80            num_heads: 16,
81            mlp_dim: 4096,
82            dropout_rate: 0.1,
83            attention_dropout_rate: 0.0,
84        }
85    }
86
87    /// Create a ViT-Huge configuration
88    pub fn vit_huge(
89        image_size: (usize, usize),
90        patch_size: (usize, usize),
91        in_channels: usize,
92        num_classes: usize,
93    ) -> Self {
94        Self {
95            image_size,
96            patch_size,
97            in_channels,
98            num_classes,
99            embed_dim: 1280,
100            num_layers: 32,
101            num_heads: 16,
102            mlp_dim: 5120,
103            dropout_rate: 0.1,
104            attention_dropout_rate: 0.0,
105        }
106    }
107}
108
109/// MLP with GELU activation for transformer blocks
110#[derive(Clone, Debug)]
111struct TransformerMlp<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
112    dense1: Dense<F>,
113    dense2: Dense<F>,
114}
115
116impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F>
117    for TransformerMlp<F>
118{
119    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
120        let mut x = self.dense1.forward(input)?;
121        // Apply GELU activation inline
122        x = x.mapv(|v| {
123            // GELU approximation: x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
124            let x3 = v * v * v;
125            v * F::from(0.5).expect("Failed to convert constant to float")
126                * (F::one()
127                    + (v + F::from(0.044715).expect("Failed to convert constant to float") * x3)
128                        .tanh())
129        });
130        x = self.dense2.forward(&x)?;
131        Ok(x)
132    }
133
134    fn backward(
135        &self,
136        _input: &Array<F, IxDyn>,
137        grad_output: &Array<F, IxDyn>,
138    ) -> Result<Array<F, IxDyn>> {
139        Ok(grad_output.clone())
140    }
141
142    fn update(&mut self, learning_rate: F) -> Result<()> {
143        self.dense1.update(learning_rate)?;
144        self.dense2.update(learning_rate)?;
145        Ok(())
146    }
147
148    fn as_any(&self) -> &dyn std::any::Any {
149        self
150    }
151
152    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
153        self
154    }
155}
156
157/// Transformer encoder block for ViT
158struct TransformerEncoderBlock<
159    F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign,
160> {
161    /// Layer normalization 1
162    norm1: LayerNorm<F>,
163    /// Multi-head attention
164    attention: MultiHeadAttention<F>,
165    /// Layer normalization 2
166    norm2: LayerNorm<F>,
167    /// MLP layers
168    mlp: TransformerMlp<F>,
169    /// Dropout for attention
170    attn_dropout: Dropout<F>,
171    /// Dropout for MLP
172    mlp_dropout: Dropout<F>,
173}
174
175impl<
176        F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
177    > Clone for TransformerEncoderBlock<F>
178{
179    fn clone(&self) -> Self {
180        Self {
181            norm1: self.norm1.clone(),
182            attention: self.attention.clone(),
183            norm2: self.norm2.clone(),
184            mlp: self.mlp.clone(),
185            attn_dropout: self.attn_dropout.clone(),
186            mlp_dropout: self.mlp_dropout.clone(),
187        }
188    }
189}
190
191impl<
192        F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
193    > TransformerEncoderBlock<F>
194{
195    /// Create a new transformer encoder block
196    pub fn new(
197        dim: usize,
198        num_heads: usize,
199        mlp_dim: usize,
200        dropout_rate: F,
201        attention_dropout_rate: F,
202    ) -> Result<Self> {
203        // Layer normalization for attention
204        let mut ln_rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
205        let norm1 = LayerNorm::new(dim, 1e-6, &mut ln_rng)?;
206
207        // Multi-head attention
208        let attn_config = crate::layers::AttentionConfig {
209            num_heads,
210            head_dim: dim / num_heads,
211            dropout_prob: attention_dropout_rate.to_f64().expect("Operation failed"),
212            causal: false,
213            scale: None,
214        };
215        let mut attn_rng = scirs2_core::random::rngs::SmallRng::from_seed([43; 32]);
216        let attention = MultiHeadAttention::new(dim, attn_config, &mut attn_rng)?;
217
218        // Layer normalization for MLP
219        let mut ln2_rng = scirs2_core::random::rngs::SmallRng::from_seed([44; 32]);
220        let norm2 = LayerNorm::new(dim, 1e-6, &mut ln2_rng)?;
221
222        // MLP
223        let mut mlp_rng1 = scirs2_core::random::rngs::SmallRng::from_seed([45; 32]);
224        let mut mlp_rng2 = scirs2_core::random::rngs::SmallRng::from_seed([46; 32]);
225        let mlp = TransformerMlp {
226            dense1: Dense::new(dim, mlp_dim, None, &mut mlp_rng1)?,
227            dense2: Dense::new(mlp_dim, dim, None, &mut mlp_rng2)?,
228        };
229
230        // Dropouts
231        let dropout_rate_f64 = dropout_rate.to_f64().expect("Operation failed");
232        let mut dropout_rng1 = scirs2_core::random::rngs::SmallRng::from_seed([47; 32]);
233        let mut dropout_rng2 = scirs2_core::random::rngs::SmallRng::from_seed([48; 32]);
234        let attn_dropout = Dropout::new(dropout_rate_f64, &mut dropout_rng1)?;
235        let mlp_dropout = Dropout::new(dropout_rate_f64, &mut dropout_rng2)?;
236
237        Ok(Self {
238            norm1,
239            attention,
240            norm2,
241            mlp,
242            attn_dropout,
243            mlp_dropout,
244        })
245    }
246}
247
248impl<
249        F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
250    > Layer<F> for TransformerEncoderBlock<F>
251{
252    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
253        // Norm -> Attention -> Dropout -> Add
254        let norm1_out = self.norm1.forward(input)?;
255        let attn = self.attention.forward(&norm1_out)?;
256        let attn_drop = self.attn_dropout.forward(&attn)?;
257
258        // Add residual connection
259        let residual1 = input + &attn_drop;
260
261        // Norm -> MLP -> Dropout -> Add
262        let norm2_out = self.norm2.forward(&residual1)?;
263        let mlp_out = self.mlp.forward(&norm2_out)?;
264        let mlp_drop = self.mlp_dropout.forward(&mlp_out)?;
265
266        // Add residual connection
267        let residual2 = &residual1 + &mlp_drop;
268
269        Ok(residual2)
270    }
271
272    fn backward(
273        &self,
274        _input: &Array<F, IxDyn>,
275        grad_output: &Array<F, IxDyn>,
276    ) -> Result<Array<F, IxDyn>> {
277        Ok(grad_output.clone())
278    }
279
280    fn update(&mut self, learning_rate: F) -> Result<()> {
281        self.norm1.update(learning_rate)?;
282        self.attention.update(learning_rate)?;
283        self.norm2.update(learning_rate)?;
284        self.mlp.update(learning_rate)?;
285        Ok(())
286    }
287
288    fn as_any(&self) -> &dyn std::any::Any {
289        self
290    }
291
292    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
293        self
294    }
295}
296
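// Vision Transformer implementation (currently disabled: it depends on the missing `PatchEmbedding` layer).
// TODO: Re-enable once PatchEmbedding is implemented. The commented-out `impl` blocks below omit the
// `NumAssign` bound that the struct and `TransformerEncoderBlock` require; add it when restoring them.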
299// pub struct VisionTransformer<F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign>
300// {
301//     /// Patch embedding layer
302//     patch_embed: PatchEmbedding<F>,
303//     /// Class token embedding
304//     cls_token: Array<F, IxDyn>,
305//     /// Position embedding
306//     pos_embed: Array<F, IxDyn>,
307//     /// Dropout layer
308//     dropout: Dropout<F>,
309//     /// Transformer encoder blocks
310//     encoder_blocks: Vec<TransformerEncoderBlock<F>>,
311//     /// Layer normalization
312//     norm: LayerNorm<F>,
313//     /// Final classification head
314//     classifier: Dense<F>,
315//     /// Model configuration
316//     config: ViTConfig,
317// }
318
319// TODO: Re-enable once PatchEmbedding is implemented
320// impl<F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + 'static>
321//     std::fmt::Debug for VisionTransformer<F>
322// {
323//     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
324//         f.debug_struct("VisionTransformer")
325//             .field("patch_embed", &self.patch_embed)
326//             .field("cls_token", &self.cls_token)
327//             .field("pos_embed", &self.pos_embed)
328//             .field("dropout", &self.dropout)
329//             .field(
330//                 "encoder_blocks",
331//                 &format!("<{} blocks>", self.encoder_blocks.len()),
332//             )
333//             .field("norm", &self.norm)
334//             .field("classifier", &self.classifier)
335//             .field("config", &self.config)
336//             .finish()
337//     }
338// }
339//
340// impl<F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + 'static> Clone
341//     for VisionTransformer<F>
342// {
343//     fn clone(&self) -> Self {
344//         Self {
345//             patch_embed: self.patch_embed.clone(),
346//             cls_token: self.cls_token.clone(),
347//             pos_embed: self.pos_embed.clone(),
348//             dropout: self.dropout.clone(),
349//             encoder_blocks: self.encoder_blocks.clone(),
350//             norm: self.norm.clone(),
351//             classifier: self.classifier.clone(),
352//             config: self.config.clone(),
353//         }
354//     }
355// }
356
357// TODO: Re-enable once PatchEmbedding is implemented
358// impl<F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + 'static>
359//     VisionTransformer<F>
360// {
361//     /// Create a new Vision Transformer model
362//     pub fn new(config: ViTConfig) -> Result<Self> {
363//         // Calculate number of patches
364//         let h_patches = config.image_size.0 / config.patch_size.0;
365//         let w_patches = config.image_size.1 / config.patch_size.1;
366//         let num_patches = h_patches * w_patches;
367//
368//         // Create patch embedding layer
369//         let patch_embed = PatchEmbedding::new(
370//             config.image_size,
371//             config.patch_size,
372//             config.in_channels,
373//             config.embed_dim,
374//             true,
375//         )?;
376//
377//         // Create class token
378//         let cls_token = Array::zeros(IxDyn(&[1, 1, config.embed_dim]));
379//
380//         // Create position embedding (include class token)
381//         let pos_embed = Array::zeros(IxDyn(&[1, num_patches + 1, config.embed_dim]));
382//
383//         // Create dropout
384//         let mut dropout_rng = scirs2_core::random::rngs::SmallRng::from_seed([49; 32]);
385//         let dropout = Dropout::new(config.dropout_rate, &mut dropout_rng)?;
386//
387//         // Create transformer encoder blocks
388//         let mut encoder_blocks = Vec::with_capacity(config.num_layers);
389//         for i in 0..config.num_layers {
390//             let block = TransformerEncoderBlock::new(
391//                 config.embed_dim,
392//                 config.num_heads,
393//                 config.mlp_dim,
394//                 F::from(config.dropout_rate).expect("Failed to convert to float"),
395//                 F::from(config.attention_dropout_rate).expect("Failed to convert to float"),
396//             )?;
397//             encoder_blocks.push(block);
398//             let _ = i; // Suppress unused variable warning
399//         }
400//
401//         // Layer normalization
402//         let mut norm_rng = scirs2_core::random::rngs::SmallRng::from_seed([50; 32]);
403//         let norm = LayerNorm::new(config.embed_dim, 1e-6, &mut norm_rng)?;
404//
405//         // Classification head
406//         let mut classifier_rng = scirs2_core::random::rngs::SmallRng::from_seed([51; 32]);
407//         let classifier = Dense::new(
408//             config.embed_dim,
409//             config.num_classes,
410//             None,
411//             &mut classifier_rng,
412//         )?;
413//
414//         Ok(Self {
415//             patch_embed,
416//             cls_token,
417//             pos_embed,
418//             dropout,
419//             encoder_blocks,
420//             norm,
421//             classifier,
422//             config,
423//         })
424//     }
425//
426//     /// Create a ViT-Base model
427//     pub fn vit_base(
428//         image_size: (usize, usize),
429//         patch_size: (usize, usize),
430//         in_channels: usize,
431//         num_classes: usize,
432//     ) -> Result<Self> {
433//         let config = ViTConfig::vit_base(image_size, patch_size, in_channels, num_classes);
434//         Self::new(config)
435//     }
436//
437//     /// Create a ViT-Large model
438//     pub fn vit_large(
439//         image_size: (usize, usize),
440//         patch_size: (usize, usize),
441//         in_channels: usize,
442//         num_classes: usize,
443//     ) -> Result<Self> {
444//         let config = ViTConfig::vit_large(image_size, patch_size, in_channels, num_classes);
445//         Self::new(config)
446//     }
447//
448//     /// Create a ViT-Huge model
449//     pub fn vit_huge(
450//         image_size: (usize, usize),
451//         patch_size: (usize, usize),
452//         in_channels: usize,
453//         num_classes: usize,
454//     ) -> Result<Self> {
455//         let config = ViTConfig::vit_huge(image_size, patch_size, in_channels, num_classes);
456//         Self::new(config)
457//     }
458//
459//     /// Get the model configuration
460//     pub fn config(&self) -> &ViTConfig {
461//         &self.config
462//     }
463// }
464
465// TODO: Re-enable once PatchEmbedding is implemented
466// impl<F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + 'static> Layer<F>
467//     for VisionTransformer<F>
468// {
469//     fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
470//         // Check input shape
471//         let shape = input.shape();
472//         if shape.len() != 4
473//             || shape[1] != self.config.in_channels
474//             || shape[2] != self.config.image_size.0
475//             || shape[3] != self.config.image_size.1
476//         {
477//             return Err(NeuralError::InferenceError(format!(
478//                 "Expected input shape [batch_size, {}, {}, {}], got {:?}",
479//                 self.config.in_channels, self.config.image_size.0, self.config.image_size.1, shape
480//             )));
481//         }
482//
483//         let batch_size = shape[0];
484//
485//         // Extract patch embeddings
486//         let x = self.patch_embed.forward(input)?;
487//
488//         // Calculate number of patches
489//         let h_patches = self.config.image_size.0 / self.config.patch_size.0;
490//         let w_patches = self.config.image_size.1 / self.config.patch_size.1;
491//         let num_patches = h_patches * w_patches;
492//
493//         // Prepend class token
494//         let mut cls_tokens = Array::zeros(IxDyn(&[batch_size, 1, self.config.embed_dim]));
495//         for b in 0..batch_size {
496//             for i in 0..self.config.embed_dim {
497//                 cls_tokens[[b, 0, i]] = self.cls_token[[0, 0, i]];
498//             }
499//         }
500//
501//         // Concatenate class token with patch embeddings
502//         let mut x_with_cls =
503//             Array::zeros(IxDyn(&[batch_size, num_patches + 1, self.config.embed_dim]));
504//
505//         // Copy class token
506//         for b in 0..batch_size {
507//             for i in 0..self.config.embed_dim {
508//                 x_with_cls[[b, 0, i]] = cls_tokens[[b, 0, i]];
509//             }
510//         }
511//
512//         // Copy patch embeddings
513//         for b in 0..batch_size {
514//             for p in 0..num_patches {
515//                 for i in 0..self.config.embed_dim {
516//                     x_with_cls[[b, p + 1, i]] = x[[b, p, i]];
517//                 }
518//             }
519//         }
520//
521//         // Add position embeddings
522//         for b in 0..batch_size {
523//             for p in 0..num_patches + 1 {
524//                 for i in 0..self.config.embed_dim {
525//                     x_with_cls[[b, p, i]] = x_with_cls[[b, p, i]] + self.pos_embed[[0, p, i]];
526//                 }
527//             }
528//         }
529//
530//         // Apply dropout
531//         let mut x = self.dropout.forward(&x_with_cls)?;
532//
533//         // Apply transformer encoder blocks
534//         for block in &self.encoder_blocks {
535//             x = block.forward(&x)?;
536//         }
537//
538//         // Apply layer normalization
539//         x = self.norm.forward(&x)?;
540//
541//         // Use only the class token for classification
542//         let mut cls_token_final = Array::zeros(IxDyn(&[batch_size, self.config.embed_dim]));
543//         for b in 0..batch_size {
544//             for i in 0..self.config.embed_dim {
545//                 cls_token_final[[b, i]] = x[[b, 0, i]];
546//             }
547//         }
548//
549//         // Apply classifier head
550//         let logits = self.classifier.forward(&cls_token_final)?;
551//
552//         Ok(logits)
553//     }
554//
555//     fn backward(
556//         &self,
557//         _input: &Array<F, IxDyn>,
558//         grad_output: &Array<F, IxDyn>,
559//     ) -> Result<Array<F, IxDyn>> {
560//         Ok(grad_output.clone())
561//     }
562//
563//     fn update(&mut self, learning_rate: F) -> Result<()> {
564//         self.patch_embed.update(learning_rate)?;
565//         for block in &mut self.encoder_blocks {
566//             block.update(learning_rate)?;
567//         }
568//         self.norm.update(learning_rate)?;
569//         self.classifier.update(learning_rate)?;
570//         Ok(())
571//     }
572//
573//     fn as_any(&self) -> &dyn std::any::Any {
574//         self
575//     }
576//
577//     fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
578//         self
579//     }
580// }
581
#[cfg(test)]
mod tests {
    use super::*;

    /// ViT-Base must report 768-wide embeddings, 12 layers and 12 heads.
    #[test]
    fn test_vit_config_base() {
        let ViTConfig {
            embed_dim,
            num_layers,
            num_heads,
            ..
        } = ViTConfig::vit_base((224, 224), (16, 16), 3, 1000);
        assert_eq!((embed_dim, num_layers, num_heads), (768, 12, 12));
    }

    /// ViT-Large must report 1024-wide embeddings, 24 layers and 16 heads.
    #[test]
    fn test_vit_config_large() {
        let ViTConfig {
            embed_dim,
            num_layers,
            num_heads,
            ..
        } = ViTConfig::vit_large((224, 224), (16, 16), 3, 1000);
        assert_eq!((embed_dim, num_layers, num_heads), (1024, 24, 16));
    }

    /// ViT-Huge must report 1280-wide embeddings, 32 layers and 16 heads.
    #[test]
    fn test_vit_config_huge() {
        let ViTConfig {
            embed_dim,
            num_layers,
            num_heads,
            ..
        } = ViTConfig::vit_huge((224, 224), (16, 16), 3, 1000);
        assert_eq!((embed_dim, num_layers, num_heads), (1280, 32, 16));
    }
}