1use crate::error::{NeuralError, Result};
10use crate::layers::{Dense, Dropout, Layer, LayerNorm, MultiHeadAttention, PatchEmbedding};
11use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
12use scirs2_core::numeric::{Float, NumAssign};
13use scirs2_core::random::SeedableRng;
14use scirs2_core::simd_ops::SimdUnifiedOps;
15use serde::{Deserialize, Serialize};
16use std::fmt::Debug;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ViTConfig {
21 pub image_size: (usize, usize),
23 pub patch_size: (usize, usize),
25 pub in_channels: usize,
27 pub num_classes: usize,
29 pub embed_dim: usize,
31 pub num_layers: usize,
33 pub num_heads: usize,
35 pub mlp_dim: usize,
37 pub dropout_rate: f64,
39 pub attention_dropout_rate: f64,
41}
42
43impl ViTConfig {
44 pub fn vit_base(
46 image_size: (usize, usize),
47 patch_size: (usize, usize),
48 in_channels: usize,
49 num_classes: usize,
50 ) -> Self {
51 Self {
52 image_size,
53 patch_size,
54 in_channels,
55 num_classes,
56 embed_dim: 768,
57 num_layers: 12,
58 num_heads: 12,
59 mlp_dim: 3072,
60 dropout_rate: 0.1,
61 attention_dropout_rate: 0.0,
62 }
63 }
64
65 pub fn vit_large(
67 image_size: (usize, usize),
68 patch_size: (usize, usize),
69 in_channels: usize,
70 num_classes: usize,
71 ) -> Self {
72 Self {
73 image_size,
74 patch_size,
75 in_channels,
76 num_classes,
77 embed_dim: 1024,
78 num_layers: 24,
79 num_heads: 16,
80 mlp_dim: 4096,
81 dropout_rate: 0.1,
82 attention_dropout_rate: 0.0,
83 }
84 }
85
86 pub fn vit_huge(
88 image_size: (usize, usize),
89 patch_size: (usize, usize),
90 in_channels: usize,
91 num_classes: usize,
92 ) -> Self {
93 Self {
94 image_size,
95 patch_size,
96 in_channels,
97 num_classes,
98 embed_dim: 1280,
99 num_layers: 32,
100 num_heads: 16,
101 mlp_dim: 5120,
102 dropout_rate: 0.1,
103 attention_dropout_rate: 0.0,
104 }
105 }
106}
107
108#[derive(Clone, Debug)]
110struct TransformerMlp<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign> {
111 dense1: Dense<F>,
112 dense2: Dense<F>,
113}
114
115impl<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> Layer<F>
116 for TransformerMlp<F>
117{
118 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
119 let mut x = self.dense1.forward(input)?;
120 x = x.mapv(|v| {
122 let x3 = v * v * v;
124 v * F::from(0.5).expect("Failed to convert constant to float")
125 * (F::one()
126 + (v + F::from(0.044715).expect("Failed to convert constant to float") * x3)
127 .tanh())
128 });
129 x = self.dense2.forward(&x)?;
130 Ok(x)
131 }
132
133 fn backward(
134 &self,
135 _input: &Array<F, IxDyn>,
136 grad_output: &Array<F, IxDyn>,
137 ) -> Result<Array<F, IxDyn>> {
138 Ok(grad_output.clone())
139 }
140
141 fn update(&mut self, learning_rate: F) -> Result<()> {
142 self.dense1.update(learning_rate)?;
143 self.dense2.update(learning_rate)?;
144 Ok(())
145 }
146
147 fn as_any(&self) -> &dyn std::any::Any {
148 self
149 }
150
151 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
152 self
153 }
154}
155
156struct TransformerEncoderBlock<
158 F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign,
159> {
160 norm1: LayerNorm<F>,
162 attention: MultiHeadAttention<F>,
164 norm2: LayerNorm<F>,
166 mlp: TransformerMlp<F>,
168 attn_dropout: Dropout<F>,
170 mlp_dropout: Dropout<F>,
172}
173
174impl<
175 F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
176 > Clone for TransformerEncoderBlock<F>
177{
178 fn clone(&self) -> Self {
179 Self {
180 norm1: self.norm1.clone(),
181 attention: self.attention.clone(),
182 norm2: self.norm2.clone(),
183 mlp: self.mlp.clone(),
184 attn_dropout: self.attn_dropout.clone(),
185 mlp_dropout: self.mlp_dropout.clone(),
186 }
187 }
188}
189
190impl<
191 F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
192 > TransformerEncoderBlock<F>
193{
194 pub fn new(
196 dim: usize,
197 num_heads: usize,
198 mlp_dim: usize,
199 dropout_rate: F,
200 attention_dropout_rate: F,
201 ) -> Result<Self> {
202 let mut ln_rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
204 let norm1 = LayerNorm::new(dim, 1e-6, &mut ln_rng)?;
205
206 let attn_config = crate::layers::AttentionConfig {
208 num_heads,
209 head_dim: dim / num_heads,
210 dropout_prob: attention_dropout_rate.to_f64().expect("Operation failed"),
211 causal: false,
212 scale: None,
213 };
214 let mut attn_rng = scirs2_core::random::rngs::SmallRng::from_seed([43; 32]);
215 let attention = MultiHeadAttention::new(dim, attn_config, &mut attn_rng)?;
216
217 let mut ln2_rng = scirs2_core::random::rngs::SmallRng::from_seed([44; 32]);
219 let norm2 = LayerNorm::new(dim, 1e-6, &mut ln2_rng)?;
220
221 let mut mlp_rng1 = scirs2_core::random::rngs::SmallRng::from_seed([45; 32]);
223 let mut mlp_rng2 = scirs2_core::random::rngs::SmallRng::from_seed([46; 32]);
224 let mlp = TransformerMlp {
225 dense1: Dense::new(dim, mlp_dim, None, &mut mlp_rng1)?,
226 dense2: Dense::new(mlp_dim, dim, None, &mut mlp_rng2)?,
227 };
228
229 let dropout_rate_f64 = dropout_rate.to_f64().expect("Operation failed");
231 let mut dropout_rng1 = scirs2_core::random::rngs::SmallRng::from_seed([47; 32]);
232 let mut dropout_rng2 = scirs2_core::random::rngs::SmallRng::from_seed([48; 32]);
233 let attn_dropout = Dropout::new(dropout_rate_f64, &mut dropout_rng1)?;
234 let mlp_dropout = Dropout::new(dropout_rate_f64, &mut dropout_rng2)?;
235
236 Ok(Self {
237 norm1,
238 attention,
239 norm2,
240 mlp,
241 attn_dropout,
242 mlp_dropout,
243 })
244 }
245}
246
247impl<
248 F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
249 > Layer<F> for TransformerEncoderBlock<F>
250{
251 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
252 let norm1_out = self.norm1.forward(input)?;
254 let attn = self.attention.forward(&norm1_out)?;
255 let attn_drop = self.attn_dropout.forward(&attn)?;
256
257 let residual1 = input + &attn_drop;
259
260 let norm2_out = self.norm2.forward(&residual1)?;
262 let mlp_out = self.mlp.forward(&norm2_out)?;
263 let mlp_drop = self.mlp_dropout.forward(&mlp_out)?;
264
265 let residual2 = &residual1 + &mlp_drop;
267
268 Ok(residual2)
269 }
270
271 fn backward(
272 &self,
273 _input: &Array<F, IxDyn>,
274 grad_output: &Array<F, IxDyn>,
275 ) -> Result<Array<F, IxDyn>> {
276 Ok(grad_output.clone())
277 }
278
279 fn update(&mut self, learning_rate: F) -> Result<()> {
280 self.norm1.update(learning_rate)?;
281 self.attention.update(learning_rate)?;
282 self.norm2.update(learning_rate)?;
283 self.mlp.update(learning_rate)?;
284 Ok(())
285 }
286
287 fn as_any(&self) -> &dyn std::any::Any {
288 self
289 }
290
291 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
292 self
293 }
294}
295
296pub struct VisionTransformer<
302 F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign,
303> {
304 patch_embed: PatchEmbedding<F>,
306 cls_token: Array<F, IxDyn>,
308 pos_embed: Array<F, IxDyn>,
310 dropout: Dropout<F>,
312 encoder_blocks: Vec<TransformerEncoderBlock<F>>,
314 norm: LayerNorm<F>,
316 classifier: Dense<F>,
318 config: ViTConfig,
320}
321
322impl<
323 F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
324 > std::fmt::Debug for VisionTransformer<F>
325{
326 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
327 f.debug_struct("VisionTransformer")
328 .field("patch_embed", &self.patch_embed)
329 .field("cls_token", &self.cls_token)
330 .field("pos_embed", &self.pos_embed)
331 .field("dropout", &self.dropout)
332 .field(
333 "encoder_blocks",
334 &format!("<{} blocks>", self.encoder_blocks.len()),
335 )
336 .field("norm", &self.norm)
337 .field("classifier", &self.classifier)
338 .field("config", &self.config)
339 .finish()
340 }
341}
342
343impl<
344 F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
345 > Clone for VisionTransformer<F>
346{
347 fn clone(&self) -> Self {
348 Self {
349 patch_embed: self.patch_embed.clone(),
350 cls_token: self.cls_token.clone(),
351 pos_embed: self.pos_embed.clone(),
352 dropout: self.dropout.clone(),
353 encoder_blocks: self.encoder_blocks.clone(),
354 norm: self.norm.clone(),
355 classifier: self.classifier.clone(),
356 config: self.config.clone(),
357 }
358 }
359}
360
361impl<
362 F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
363 > VisionTransformer<F>
364{
365 pub fn new(config: ViTConfig) -> Result<Self> {
367 let h_patches = config.image_size.0 / config.patch_size.0;
369 let w_patches = config.image_size.1 / config.patch_size.1;
370 let num_patches = h_patches * w_patches;
371
372 let mut pe_rng = scirs2_core::random::rngs::SmallRng::from_seed([48; 32]);
374 let patch_embed = PatchEmbedding::new(
375 config.image_size,
376 config.patch_size,
377 config.in_channels,
378 config.embed_dim,
379 true,
380 &mut pe_rng,
381 )?;
382
383 let cls_token = Array::zeros(IxDyn(&[1, 1, config.embed_dim]));
385
386 let pos_embed = Array::zeros(IxDyn(&[1, num_patches + 1, config.embed_dim]));
388
389 let mut dropout_rng = scirs2_core::random::rngs::SmallRng::from_seed([49; 32]);
391 let dropout = Dropout::new(config.dropout_rate, &mut dropout_rng)?;
392
393 let mut encoder_blocks = Vec::with_capacity(config.num_layers);
395 for i in 0..config.num_layers {
396 let block = TransformerEncoderBlock::new(
397 config.embed_dim,
398 config.num_heads,
399 config.mlp_dim,
400 F::from(config.dropout_rate).ok_or_else(|| {
401 NeuralError::InvalidArchitecture(
402 "Failed to convert dropout_rate to float".to_string(),
403 )
404 })?,
405 F::from(config.attention_dropout_rate).ok_or_else(|| {
406 NeuralError::InvalidArchitecture(
407 "Failed to convert attention_dropout_rate to float".to_string(),
408 )
409 })?,
410 )?;
411 encoder_blocks.push(block);
412 let _ = i;
413 }
414
415 let mut norm_rng = scirs2_core::random::rngs::SmallRng::from_seed([50; 32]);
417 let norm = LayerNorm::new(config.embed_dim, 1e-6, &mut norm_rng)?;
418
419 let mut classifier_rng = scirs2_core::random::rngs::SmallRng::from_seed([51; 32]);
421 let classifier = Dense::new(
422 config.embed_dim,
423 config.num_classes,
424 None,
425 &mut classifier_rng,
426 )?;
427
428 Ok(Self {
429 patch_embed,
430 cls_token,
431 pos_embed,
432 dropout,
433 encoder_blocks,
434 norm,
435 classifier,
436 config,
437 })
438 }
439
440 pub fn vit_base(
442 image_size: (usize, usize),
443 patch_size: (usize, usize),
444 in_channels: usize,
445 num_classes: usize,
446 ) -> Result<Self> {
447 let config = ViTConfig::vit_base(image_size, patch_size, in_channels, num_classes);
448 Self::new(config)
449 }
450
451 pub fn vit_large(
453 image_size: (usize, usize),
454 patch_size: (usize, usize),
455 in_channels: usize,
456 num_classes: usize,
457 ) -> Result<Self> {
458 let config = ViTConfig::vit_large(image_size, patch_size, in_channels, num_classes);
459 Self::new(config)
460 }
461
462 pub fn vit_huge(
464 image_size: (usize, usize),
465 patch_size: (usize, usize),
466 in_channels: usize,
467 num_classes: usize,
468 ) -> Result<Self> {
469 let config = ViTConfig::vit_huge(image_size, patch_size, in_channels, num_classes);
470 Self::new(config)
471 }
472
473 pub fn config(&self) -> &ViTConfig {
475 &self.config
476 }
477}
478
479impl<
480 F: Float + Debug + ScalarOperand + Clone + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
481 > Layer<F> for VisionTransformer<F>
482{
483 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
484 let shape = input.shape();
485 if shape.len() != 4
486 || shape[1] != self.config.in_channels
487 || shape[2] != self.config.image_size.0
488 || shape[3] != self.config.image_size.1
489 {
490 return Err(NeuralError::InferenceError(format!(
491 "Expected input shape [batch_size, {}, {}, {}], got {:?}",
492 self.config.in_channels, self.config.image_size.0, self.config.image_size.1, shape
493 )));
494 }
495
496 let batch_size = shape[0];
497
498 let x = self.patch_embed.forward(input)?;
500
501 let h_patches = self.config.image_size.0 / self.config.patch_size.0;
502 let w_patches = self.config.image_size.1 / self.config.patch_size.1;
503 let num_patches = h_patches * w_patches;
504
505 let mut x_with_cls =
507 Array::zeros(IxDyn(&[batch_size, num_patches + 1, self.config.embed_dim]));
508
509 for b in 0..batch_size {
511 for i in 0..self.config.embed_dim {
512 x_with_cls[[b, 0, i]] = self.cls_token[[0, 0, i]];
513 }
514 }
515
516 for b in 0..batch_size {
518 for p in 0..num_patches {
519 for i in 0..self.config.embed_dim {
520 x_with_cls[[b, p + 1, i]] = x[[b, p, i]];
521 }
522 }
523 }
524
525 for b in 0..batch_size {
527 for p in 0..num_patches + 1 {
528 for i in 0..self.config.embed_dim {
529 x_with_cls[[b, p, i]] += self.pos_embed[[0, p, i]];
530 }
531 }
532 }
533
534 let mut x = self.dropout.forward(&x_with_cls)?;
536 for block in &self.encoder_blocks {
537 x = block.forward(&x)?;
538 }
539 x = self.norm.forward(&x)?;
540
541 let mut cls_token_final = Array::zeros(IxDyn(&[batch_size, self.config.embed_dim]));
543 for b in 0..batch_size {
544 for i in 0..self.config.embed_dim {
545 cls_token_final[[b, i]] = x[[b, 0, i]];
546 }
547 }
548
549 self.classifier.forward(&cls_token_final)
551 }
552
553 fn backward(
554 &self,
555 _input: &Array<F, IxDyn>,
556 grad_output: &Array<F, IxDyn>,
557 ) -> Result<Array<F, IxDyn>> {
558 Ok(grad_output.clone())
559 }
560
561 fn update(&mut self, learning_rate: F) -> Result<()> {
562 self.patch_embed.update(learning_rate)?;
563 for block in &mut self.encoder_blocks {
564 block.update(learning_rate)?;
565 }
566 self.norm.update(learning_rate)?;
567 self.classifier.update(learning_rate)?;
568 Ok(())
569 }
570
571 fn as_any(&self) -> &dyn std::any::Any {
572 self
573 }
574
575 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
576 self
577 }
578
579 fn layer_type(&self) -> &str {
580 "VisionTransformer"
581 }
582}
583
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the encoder dimensions the preset tests compare:
    /// `(embed_dim, num_layers, num_heads)`.
    fn encoder_dims(config: &ViTConfig) -> (usize, usize, usize) {
        (config.embed_dim, config.num_layers, config.num_heads)
    }

    #[test]
    fn test_vit_config_base() {
        let config = ViTConfig::vit_base((224, 224), (16, 16), 3, 1000);
        assert_eq!(encoder_dims(&config), (768, 12, 12));
    }

    #[test]
    fn test_vit_config_large() {
        let config = ViTConfig::vit_large((224, 224), (16, 16), 3, 1000);
        assert_eq!(encoder_dims(&config), (1024, 24, 16));
    }

    #[test]
    fn test_vit_config_huge() {
        let config = ViTConfig::vit_huge((224, 224), (16, 16), 3, 1000);
        assert_eq!(encoder_dims(&config), (1280, 32, 16));
    }
}