scirs2_neural/models/architectures/vit.rs
1//! Vision Transformer (ViT) implementation
2//!
3//! Vision Transformer (ViT) is a transformer-based model for image classification
4//! that divides an image into fixed-size patches, linearly embeds them, adds position
5//! embeddings, and processes them using a standard Transformer encoder.
6//! Reference: "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", Dosovitskiy et al. (2020)
7//! <https://arxiv.org/abs/2010.11929>
8
9use crate::error::{NeuralError, Result};
10use crate::layers::{Dense, Dropout, Layer, LayerNorm, MultiHeadAttention};
11// Note: PatchEmbedding not yet implemented
12use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
13use scirs2_core::numeric::{Float, NumAssign};
14use scirs2_core::random::SeedableRng;
15use scirs2_core::simd_ops::SimdUnifiedOps;
16use serde::{Deserialize, Serialize};
17use std::fmt::Debug;
18
19/// Configuration for a Vision Transformer model
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct ViTConfig {
22 /// Image size (height, width)
23 pub image_size: (usize, usize),
24 /// Patch size (height, width)
25 pub patch_size: (usize, usize),
26 /// Number of input channels (e.g., 3 for RGB)
27 pub in_channels: usize,
28 /// Number of output classes
29 pub num_classes: usize,
30 /// Embedding dimension
31 pub embed_dim: usize,
32 /// Number of transformer layers
33 pub num_layers: usize,
34 /// Number of attention heads
35 pub num_heads: usize,
36 /// MLP hidden dimension
37 pub mlp_dim: usize,
38 /// Dropout rate
39 pub dropout_rate: f64,
40 /// Attention dropout rate
41 pub attention_dropout_rate: f64,
42}
43
44impl ViTConfig {
45 /// Create a ViT-Base configuration
46 pub fn vit_base(
47 image_size: (usize, usize),
48 patch_size: (usize, usize),
49 in_channels: usize,
50 num_classes: usize,
51 ) -> Self {
52 Self {
53 image_size,
54 patch_size,
55 in_channels,
56 num_classes,
57 embed_dim: 768,
58 num_layers: 12,
59 num_heads: 12,
60 mlp_dim: 3072,
61 dropout_rate: 0.1,
62 attention_dropout_rate: 0.0,
63 }
64 }
65
66 /// Create a ViT-Large configuration
67 pub fn vit_large(
68 image_size: (usize, usize),
69 patch_size: (usize, usize),
70 in_channels: usize,
71 num_classes: usize,
72 ) -> Self {
73 Self {
74 image_size,
75 patch_size,
76 in_channels,
77 num_classes,
78 embed_dim: 1024,
79 num_layers: 24,
80 num_heads: 16,
81 mlp_dim: 4096,
82 dropout_rate: 0.1,
83 attention_dropout_rate: 0.0,
84 }
85 }
86
87 /// Create a ViT-Huge configuration
88 pub fn vit_huge(
89 image_size: (usize, usize),
90 patch_size: (usize, usize),
91 in_channels: usize,
92 num_classes: usize,
93 ) -> Self {
94 Self {
95 image_size,
96 patch_size,
97 in_channels,
98 num_classes,
99 embed_dim: 1280,
100 num_layers: 32,
101 num_heads: 16,
102 mlp_dim: 5120,
103 dropout_rate: 0.1,
104 attention_dropout_rate: 0.0,
105 }
106 }
107}
108
109/// MLP with GELU activation for transformer blocks
110#[derive(Clone, Debug)]
111struct TransformerMlp<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
112 dense1: Dense<F>,
113 dense2: Dense<F>,
114}
115
116impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F>
117 for TransformerMlp<F>
118{
119 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
120 let mut x = self.dense1.forward(input)?;
121 // Apply GELU activation inline
122 x = x.mapv(|v| {
123 // GELU approximation: x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
124 let x3 = v * v * v;
125 v * F::from(0.5).expect("Failed to convert constant to float")
126 * (F::one()
127 + (v + F::from(0.044715).expect("Failed to convert constant to float") * x3)
128 .tanh())
129 });
130 x = self.dense2.forward(&x)?;
131 Ok(x)
132 }
133
134 fn backward(
135 &self,
136 _input: &Array<F, IxDyn>,
137 grad_output: &Array<F, IxDyn>,
138 ) -> Result<Array<F, IxDyn>> {
139 Ok(grad_output.clone())
140 }
141
142 fn update(&mut self, learning_rate: F) -> Result<()> {
143 self.dense1.update(learning_rate)?;
144 self.dense2.update(learning_rate)?;
145 Ok(())
146 }
147
148 fn as_any(&self) -> &dyn std::any::Any {
149 self
150 }
151
152 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
153 self
154 }
155}
156
157/// Transformer encoder block for ViT
158struct TransformerEncoderBlock<
159 F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign,
160> {
161 /// Layer normalization 1
162 norm1: LayerNorm<F>,
163 /// Multi-head attention
164 attention: MultiHeadAttention<F>,
165 /// Layer normalization 2
166 norm2: LayerNorm<F>,
167 /// MLP layers
168 mlp: TransformerMlp<F>,
169 /// Dropout for attention
170 attn_dropout: Dropout<F>,
171 /// Dropout for MLP
172 mlp_dropout: Dropout<F>,
173}
174
175impl<
176 F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
177 > Clone for TransformerEncoderBlock<F>
178{
179 fn clone(&self) -> Self {
180 Self {
181 norm1: self.norm1.clone(),
182 attention: self.attention.clone(),
183 norm2: self.norm2.clone(),
184 mlp: self.mlp.clone(),
185 attn_dropout: self.attn_dropout.clone(),
186 mlp_dropout: self.mlp_dropout.clone(),
187 }
188 }
189}
190
191impl<
192 F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
193 > TransformerEncoderBlock<F>
194{
195 /// Create a new transformer encoder block
196 pub fn new(
197 dim: usize,
198 num_heads: usize,
199 mlp_dim: usize,
200 dropout_rate: F,
201 attention_dropout_rate: F,
202 ) -> Result<Self> {
203 // Layer normalization for attention
204 let mut ln_rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
205 let norm1 = LayerNorm::new(dim, 1e-6, &mut ln_rng)?;
206
207 // Multi-head attention
208 let attn_config = crate::layers::AttentionConfig {
209 num_heads,
210 head_dim: dim / num_heads,
211 dropout_prob: attention_dropout_rate.to_f64().expect("Operation failed"),
212 causal: false,
213 scale: None,
214 };
215 let mut attn_rng = scirs2_core::random::rngs::SmallRng::from_seed([43; 32]);
216 let attention = MultiHeadAttention::new(dim, attn_config, &mut attn_rng)?;
217
218 // Layer normalization for MLP
219 let mut ln2_rng = scirs2_core::random::rngs::SmallRng::from_seed([44; 32]);
220 let norm2 = LayerNorm::new(dim, 1e-6, &mut ln2_rng)?;
221
222 // MLP
223 let mut mlp_rng1 = scirs2_core::random::rngs::SmallRng::from_seed([45; 32]);
224 let mut mlp_rng2 = scirs2_core::random::rngs::SmallRng::from_seed([46; 32]);
225 let mlp = TransformerMlp {
226 dense1: Dense::new(dim, mlp_dim, None, &mut mlp_rng1)?,
227 dense2: Dense::new(mlp_dim, dim, None, &mut mlp_rng2)?,
228 };
229
230 // Dropouts
231 let dropout_rate_f64 = dropout_rate.to_f64().expect("Operation failed");
232 let mut dropout_rng1 = scirs2_core::random::rngs::SmallRng::from_seed([47; 32]);
233 let mut dropout_rng2 = scirs2_core::random::rngs::SmallRng::from_seed([48; 32]);
234 let attn_dropout = Dropout::new(dropout_rate_f64, &mut dropout_rng1)?;
235 let mlp_dropout = Dropout::new(dropout_rate_f64, &mut dropout_rng2)?;
236
237 Ok(Self {
238 norm1,
239 attention,
240 norm2,
241 mlp,
242 attn_dropout,
243 mlp_dropout,
244 })
245 }
246}
247
248impl<
249 F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
250 > Layer<F> for TransformerEncoderBlock<F>
251{
252 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
253 // Norm -> Attention -> Dropout -> Add
254 let norm1_out = self.norm1.forward(input)?;
255 let attn = self.attention.forward(&norm1_out)?;
256 let attn_drop = self.attn_dropout.forward(&attn)?;
257
258 // Add residual connection
259 let residual1 = input + &attn_drop;
260
261 // Norm -> MLP -> Dropout -> Add
262 let norm2_out = self.norm2.forward(&residual1)?;
263 let mlp_out = self.mlp.forward(&norm2_out)?;
264 let mlp_drop = self.mlp_dropout.forward(&mlp_out)?;
265
266 // Add residual connection
267 let residual2 = &residual1 + &mlp_drop;
268
269 Ok(residual2)
270 }
271
272 fn backward(
273 &self,
274 _input: &Array<F, IxDyn>,
275 grad_output: &Array<F, IxDyn>,
276 ) -> Result<Array<F, IxDyn>> {
277 Ok(grad_output.clone())
278 }
279
280 fn update(&mut self, learning_rate: F) -> Result<()> {
281 self.norm1.update(learning_rate)?;
282 self.attention.update(learning_rate)?;
283 self.norm2.update(learning_rate)?;
284 self.mlp.update(learning_rate)?;
285 Ok(())
286 }
287
288 fn as_any(&self) -> &dyn std::any::Any {
289 self
290 }
291
292 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
293 self
294 }
295}
296
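// Vision Transformer implementation (currently disabled).
// The model below is commented out because it depends on a `PatchEmbedding`
// layer that has not been implemented yet.
// TODO: Re-enable once PatchEmbedding is implemented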
299// pub struct VisionTransformer<F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign>
300// {
301// /// Patch embedding layer
302// patch_embed: PatchEmbedding<F>,
303// /// Class token embedding
304// cls_token: Array<F, IxDyn>,
305// /// Position embedding
306// pos_embed: Array<F, IxDyn>,
307// /// Dropout layer
308// dropout: Dropout<F>,
309// /// Transformer encoder blocks
310// encoder_blocks: Vec<TransformerEncoderBlock<F>>,
311// /// Layer normalization
312// norm: LayerNorm<F>,
313// /// Final classification head
314// classifier: Dense<F>,
315// /// Model configuration
316// config: ViTConfig,
317// }
318
319// TODO: Re-enable once PatchEmbedding is implemented
320// impl<F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + 'static>
321// std::fmt::Debug for VisionTransformer<F>
322// {
323// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
324// f.debug_struct("VisionTransformer")
325// .field("patch_embed", &self.patch_embed)
326// .field("cls_token", &self.cls_token)
327// .field("pos_embed", &self.pos_embed)
328// .field("dropout", &self.dropout)
329// .field(
330// "encoder_blocks",
331// &format!("<{} blocks>", self.encoder_blocks.len()),
332// )
333// .field("norm", &self.norm)
334// .field("classifier", &self.classifier)
335// .field("config", &self.config)
336// .finish()
337// }
338// }
339//
340// impl<F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + 'static> Clone
341// for VisionTransformer<F>
342// {
343// fn clone(&self) -> Self {
344// Self {
345// patch_embed: self.patch_embed.clone(),
346// cls_token: self.cls_token.clone(),
347// pos_embed: self.pos_embed.clone(),
348// dropout: self.dropout.clone(),
349// encoder_blocks: self.encoder_blocks.clone(),
350// norm: self.norm.clone(),
351// classifier: self.classifier.clone(),
352// config: self.config.clone(),
353// }
354// }
355// }
356
357// TODO: Re-enable once PatchEmbedding is implemented
358// impl<F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + 'static>
359// VisionTransformer<F>
360// {
361// /// Create a new Vision Transformer model
362// pub fn new(config: ViTConfig) -> Result<Self> {
363// // Calculate number of patches
364// let h_patches = config.image_size.0 / config.patch_size.0;
365// let w_patches = config.image_size.1 / config.patch_size.1;
366// let num_patches = h_patches * w_patches;
367//
368// // Create patch embedding layer
369// let patch_embed = PatchEmbedding::new(
370// config.image_size,
371// config.patch_size,
372// config.in_channels,
373// config.embed_dim,
374// true,
375// )?;
376//
377// // Create class token
378// let cls_token = Array::zeros(IxDyn(&[1, 1, config.embed_dim]));
379//
380// // Create position embedding (include class token)
381// let pos_embed = Array::zeros(IxDyn(&[1, num_patches + 1, config.embed_dim]));
382//
383// // Create dropout
384// let mut dropout_rng = scirs2_core::random::rngs::SmallRng::from_seed([49; 32]);
385// let dropout = Dropout::new(config.dropout_rate, &mut dropout_rng)?;
386//
387// // Create transformer encoder blocks
388// let mut encoder_blocks = Vec::with_capacity(config.num_layers);
389// for i in 0..config.num_layers {
390// let block = TransformerEncoderBlock::new(
391// config.embed_dim,
392// config.num_heads,
393// config.mlp_dim,
394// F::from(config.dropout_rate).expect("Failed to convert to float"),
395// F::from(config.attention_dropout_rate).expect("Failed to convert to float"),
396// )?;
397// encoder_blocks.push(block);
398// let _ = i; // Suppress unused variable warning
399// }
400//
401// // Layer normalization
402// let mut norm_rng = scirs2_core::random::rngs::SmallRng::from_seed([50; 32]);
403// let norm = LayerNorm::new(config.embed_dim, 1e-6, &mut norm_rng)?;
404//
405// // Classification head
406// let mut classifier_rng = scirs2_core::random::rngs::SmallRng::from_seed([51; 32]);
407// let classifier = Dense::new(
408// config.embed_dim,
409// config.num_classes,
410// None,
411// &mut classifier_rng,
412// )?;
413//
414// Ok(Self {
415// patch_embed,
416// cls_token,
417// pos_embed,
418// dropout,
419// encoder_blocks,
420// norm,
421// classifier,
422// config,
423// })
424// }
425//
426// /// Create a ViT-Base model
427// pub fn vit_base(
428// image_size: (usize, usize),
429// patch_size: (usize, usize),
430// in_channels: usize,
431// num_classes: usize,
432// ) -> Result<Self> {
433// let config = ViTConfig::vit_base(image_size, patch_size, in_channels, num_classes);
434// Self::new(config)
435// }
436//
437// /// Create a ViT-Large model
438// pub fn vit_large(
439// image_size: (usize, usize),
440// patch_size: (usize, usize),
441// in_channels: usize,
442// num_classes: usize,
443// ) -> Result<Self> {
444// let config = ViTConfig::vit_large(image_size, patch_size, in_channels, num_classes);
445// Self::new(config)
446// }
447//
448// /// Create a ViT-Huge model
449// pub fn vit_huge(
450// image_size: (usize, usize),
451// patch_size: (usize, usize),
452// in_channels: usize,
453// num_classes: usize,
454// ) -> Result<Self> {
455// let config = ViTConfig::vit_huge(image_size, patch_size, in_channels, num_classes);
456// Self::new(config)
457// }
458//
459// /// Get the model configuration
460// pub fn config(&self) -> &ViTConfig {
461// &self.config
462// }
463// }
464
465// TODO: Re-enable once PatchEmbedding is implemented
466// impl<F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + 'static> Layer<F>
467// for VisionTransformer<F>
468// {
469// fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
470// // Check input shape
471// let shape = input.shape();
472// if shape.len() != 4
473// || shape[1] != self.config.in_channels
474// || shape[2] != self.config.image_size.0
475// || shape[3] != self.config.image_size.1
476// {
477// return Err(NeuralError::InferenceError(format!(
478// "Expected input shape [batch_size, {}, {}, {}], got {:?}",
479// self.config.in_channels, self.config.image_size.0, self.config.image_size.1, shape
480// )));
481// }
482//
483// let batch_size = shape[0];
484//
485// // Extract patch embeddings
486// let x = self.patch_embed.forward(input)?;
487//
488// // Calculate number of patches
489// let h_patches = self.config.image_size.0 / self.config.patch_size.0;
490// let w_patches = self.config.image_size.1 / self.config.patch_size.1;
491// let num_patches = h_patches * w_patches;
492//
493// // Prepend class token
494// let mut cls_tokens = Array::zeros(IxDyn(&[batch_size, 1, self.config.embed_dim]));
495// for b in 0..batch_size {
496// for i in 0..self.config.embed_dim {
497// cls_tokens[[b, 0, i]] = self.cls_token[[0, 0, i]];
498// }
499// }
500//
501// // Concatenate class token with patch embeddings
502// let mut x_with_cls =
503// Array::zeros(IxDyn(&[batch_size, num_patches + 1, self.config.embed_dim]));
504//
505// // Copy class token
506// for b in 0..batch_size {
507// for i in 0..self.config.embed_dim {
508// x_with_cls[[b, 0, i]] = cls_tokens[[b, 0, i]];
509// }
510// }
511//
512// // Copy patch embeddings
513// for b in 0..batch_size {
514// for p in 0..num_patches {
515// for i in 0..self.config.embed_dim {
516// x_with_cls[[b, p + 1, i]] = x[[b, p, i]];
517// }
518// }
519// }
520//
521// // Add position embeddings
522// for b in 0..batch_size {
523// for p in 0..num_patches + 1 {
524// for i in 0..self.config.embed_dim {
525// x_with_cls[[b, p, i]] = x_with_cls[[b, p, i]] + self.pos_embed[[0, p, i]];
526// }
527// }
528// }
529//
530// // Apply dropout
531// let mut x = self.dropout.forward(&x_with_cls)?;
532//
533// // Apply transformer encoder blocks
534// for block in &self.encoder_blocks {
535// x = block.forward(&x)?;
536// }
537//
538// // Apply layer normalization
539// x = self.norm.forward(&x)?;
540//
541// // Use only the class token for classification
542// let mut cls_token_final = Array::zeros(IxDyn(&[batch_size, self.config.embed_dim]));
543// for b in 0..batch_size {
544// for i in 0..self.config.embed_dim {
545// cls_token_final[[b, i]] = x[[b, 0, i]];
546// }
547// }
548//
549// // Apply classifier head
550// let logits = self.classifier.forward(&cls_token_final)?;
551//
552// Ok(logits)
553// }
554//
555// fn backward(
556// &self,
557// _input: &Array<F, IxDyn>,
558// grad_output: &Array<F, IxDyn>,
559// ) -> Result<Array<F, IxDyn>> {
560// Ok(grad_output.clone())
561// }
562//
563// fn update(&mut self, learning_rate: F) -> Result<()> {
564// self.patch_embed.update(learning_rate)?;
565// for block in &mut self.encoder_blocks {
566// block.update(learning_rate)?;
567// }
568// self.norm.update(learning_rate)?;
569// self.classifier.update(learning_rate)?;
570// Ok(())
571// }
572//
573// fn as_any(&self) -> &dyn std::any::Any {
574// self
575// }
576//
577// fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
578// self
579// }
580// }
581
#[cfg(test)]
mod tests {
    use super::*;

    /// The base preset must report 768-wide embeddings, 12 layers and 12 heads.
    #[test]
    fn test_vit_config_base() {
        let cfg = ViTConfig::vit_base((224, 224), (16, 16), 3, 1000);
        assert_eq!((cfg.embed_dim, cfg.num_layers, cfg.num_heads), (768, 12, 12));
    }

    /// The large preset must report 1024-wide embeddings, 24 layers and 16 heads.
    #[test]
    fn test_vit_config_large() {
        let cfg = ViTConfig::vit_large((224, 224), (16, 16), 3, 1000);
        assert_eq!((cfg.embed_dim, cfg.num_layers, cfg.num_heads), (1024, 24, 16));
    }

    /// The huge preset must report 1280-wide embeddings, 32 layers and 16 heads.
    #[test]
    fn test_vit_config_huge() {
        let cfg = ViTConfig::vit_huge((224, 224), (16, 16), 3, 1000);
        assert_eq!((cfg.embed_dim, cfg.num_layers, cfg.num_heads), (1280, 32, 16));
    }
}