tensorlogic_trustformers/vision.rs

//! Vision Transformer (ViT) components for image processing
//!
//! This module implements Vision Transformer architectures that process images
//! as sequences of patches, enabling transformer models to be applied to computer vision tasks.
//!
//! ## Architecture
//!
//! Vision Transformers follow this processing pipeline:
//!
//! 1. **Patch Embedding**: Split image into fixed-size patches and linearly embed them
//!    - Input: `[batch, channels, height, width]`
//!    - Patches: `[batch, num_patches, patch_dim]` where `patch_dim = patch_size² × channels`
//!    - Embedding: `einsum("bnp,pd->bnd", patches, W_embed)` → `[batch, num_patches, d_model]`
//!
//! 2. **Class Token**: Prepend learnable classification token
//!    - `[batch, 1, d_model]` concatenated with patch embeddings
//!    - Final: `[batch, num_patches + 1, d_model]`
//!
//! 3. **Position Embeddings**: Add 2D position encodings
//!    - Learnable or sinusoidal embeddings for each patch position
//!    - Shape: `[1, num_patches + 1, d_model]`
//!
//! 4. **Transformer Encoder**: Standard transformer encoder layers
//!
//! 5. **Classification Head**: MLP head applied to class token output
//!
//! ## Example
//!
//! ```rust
//! use tensorlogic_trustformers::vision::{VisionTransformerConfig, VisionTransformer};
//! use tensorlogic_ir::EinsumGraph;
//!
//! // Configure ViT-Base/16 (16x16 patches, 224x224 images)
//! let config = VisionTransformerConfig::new(
//!     224,  // image_size
//!     16,   // patch_size
//!     3,    // in_channels (RGB)
//!     768,  // d_model
//!     12,   // n_heads
//!     3072, // d_ff
//!     12,   // n_layers
//!     1000, // num_classes
//! ).unwrap();
//!
//! let vit = VisionTransformer::new(config).unwrap();
//!
//! let mut graph = EinsumGraph::new();
//! graph.add_tensor("patches");       // Tensor 0: image patches
//! graph.add_tensor("W_patch_embed"); // Tensor 1: patch embedding weights
//! graph.add_tensor("pos_embed");     // Tensor 2: position embeddings
//! let output = vit.build_vit_graph(&mut graph).unwrap();
//! ```

use crate::error::{Result, TrustformerError};
use crate::stacks::{EncoderStack, EncoderStackConfig};
use tensorlogic_ir::{EinsumGraph, EinsumNode};

/// Configuration for patch embedding layer
#[derive(Debug, Clone)]
pub struct PatchEmbeddingConfig {
    /// Image size (assumes square images)
    pub image_size: usize,
    /// Patch size (assumes square patches)
    pub patch_size: usize,
    /// Number of input channels (e.g., 3 for RGB)
    pub in_channels: usize,
    /// Embedding dimension
    pub d_model: usize,
}

impl PatchEmbeddingConfig {
    /// Create new patch embedding configuration
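    ///
    /// # Example
    ///
    /// A minimal construction/validation sketch (`image_size` must be divisible
    /// by `patch_size`), mirroring the unit tests in this module:
    ///
    /// ```rust
    /// use tensorlogic_trustformers::vision::PatchEmbeddingConfig;
    ///
    /// // ViT-Base/16 style patch embedding
    /// let config = PatchEmbeddingConfig::new(224, 16, 3, 768).unwrap();
    /// assert_eq!(config.image_size, 224);
    ///
    /// // 225 is not divisible by 16, so construction fails
    /// assert!(PatchEmbeddingConfig::new(225, 16, 3, 768).is_err());
    /// ```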
    pub fn new(
        image_size: usize,
        patch_size: usize,
        in_channels: usize,
        d_model: usize,
    ) -> Result<Self> {
        if image_size == 0 {
            return Err(TrustformerError::CompilationError(
                "image_size must be > 0".into(),
            ));
        }
        if patch_size == 0 {
            return Err(TrustformerError::CompilationError(
                "patch_size must be > 0".into(),
            ));
        }
        if !image_size.is_multiple_of(patch_size) {
            return Err(TrustformerError::CompilationError(format!(
                "image_size ({}) must be divisible by patch_size ({})",
                image_size, patch_size
            )));
        }
        if in_channels == 0 {
            return Err(TrustformerError::CompilationError(
                "in_channels must be > 0".into(),
            ));
        }
        if d_model == 0 {
            return Err(TrustformerError::CompilationError(
                "d_model must be > 0".into(),
            ));
        }

        Ok(Self {
            image_size,
            patch_size,
            in_channels,
            d_model,
        })
    }

    /// Get number of patches
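    ///
    /// For a 224×224 image with 16×16 patches this is `(224 / 16)² = 196`
    /// (values match the unit tests below):
    ///
    /// ```rust
    /// use tensorlogic_trustformers::vision::PatchEmbeddingConfig;
    ///
    /// let config = PatchEmbeddingConfig::new(224, 16, 3, 768).unwrap();
    /// assert_eq!(config.num_patches(), 196);
    /// ```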
    pub fn num_patches(&self) -> usize {
        let patches_per_side = self.image_size / self.patch_size;
        patches_per_side * patches_per_side
    }

    /// Get patch dimension (patch_size² × in_channels)
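    ///
    /// For 16×16 RGB patches this is `16 × 16 × 3 = 768`:
    ///
    /// ```rust
    /// use tensorlogic_trustformers::vision::PatchEmbeddingConfig;
    ///
    /// let config = PatchEmbeddingConfig::new(224, 16, 3, 768).unwrap();
    /// assert_eq!(config.patch_dim(), 768);
    /// ```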
    pub fn patch_dim(&self) -> usize {
        self.patch_size * self.patch_size * self.in_channels
    }

    /// Validate configuration
    pub fn validate(&self) -> Result<()> {
        if !self.image_size.is_multiple_of(self.patch_size) {
            return Err(TrustformerError::CompilationError(
                "image_size must be divisible by patch_size".into(),
            ));
        }
        Ok(())
    }
}

/// Patch embedding layer for Vision Transformers
pub struct PatchEmbedding {
    config: PatchEmbeddingConfig,
}

impl PatchEmbedding {
    /// Create new patch embedding layer
    pub fn new(config: PatchEmbeddingConfig) -> Result<Self> {
        config.validate()?;
        Ok(Self { config })
    }

    /// Build patch embedding einsum graph
    ///
    /// Converts image patches to token embeddings
    ///
    /// # Graph Inputs
    /// - Tensor 0: Pre-extracted patches `[batch, num_patches, patch_dim]` (patch extraction
    ///   from the raw image `[batch, in_channels, height, width]` is assumed to happen upstream)
    /// - Tensor 1: Patch embedding weights `[patch_dim, d_model]`
    ///
    /// # Graph Output
    /// - Embedded patches `[batch, num_patches, d_model]`
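    ///
    /// # Example
    ///
    /// A minimal usage sketch, mirroring `test_patch_embedding_graph` below:
    ///
    /// ```rust
    /// use tensorlogic_trustformers::vision::{PatchEmbedding, PatchEmbeddingConfig};
    /// use tensorlogic_ir::EinsumGraph;
    ///
    /// let config = PatchEmbeddingConfig::new(224, 16, 3, 768).unwrap();
    /// let patch_embed = PatchEmbedding::new(config).unwrap();
    ///
    /// let mut graph = EinsumGraph::new();
    /// graph.add_tensor("patches");       // Tensor 0: [batch, num_patches, patch_dim]
    /// graph.add_tensor("W_patch_embed"); // Tensor 1: [patch_dim, d_model]
    ///
    /// let embeddings = patch_embed.build_patch_embed_graph(&mut graph).unwrap();
    /// assert!(embeddings > 0);
    /// assert!(graph.validate().is_ok());
    /// ```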
    pub fn build_patch_embed_graph(&self, graph: &mut EinsumGraph) -> Result<usize> {
        // Expected inputs:
        // - Tensor 0: patches [batch, num_patches, patch_dim]
        // - Tensor 1: W_patch_embed [patch_dim, d_model]
        //
        // In practice, patching would involve:
        // 1. Unfold: [B, C, H, W] → [B, C, n_patches_h, n_patches_w, patch_h, patch_w]
        // 2. Reshape: [B, C, n_patches_h, n_patches_w, patch_h, patch_w] → [B, num_patches, patch_dim]
        // 3. Linear: einsum("bnp,pd->bnd", patches, W_patch_embed)
        //
        // For simplicity, we assume the patching is already done and input tensor 0 contains patches

        // Create einsum node for patch embedding
        // einsum("bnp,pd->bnd", patches, W_patch_embed)
        let output_tensor = graph.add_tensor("patch_embeddings");
        let node = EinsumNode::new("bnp,pd->bnd", vec![0, 1], vec![output_tensor]);
        graph.add_node(node)?;

        Ok(output_tensor)
    }

    /// Get configuration
    pub fn config(&self) -> &PatchEmbeddingConfig {
        &self.config
    }
}

/// Configuration for Vision Transformer
#[derive(Debug, Clone)]
pub struct VisionTransformerConfig {
    /// Patch embedding configuration
    pub patch_embed: PatchEmbeddingConfig,
    /// Transformer encoder stack configuration
    pub encoder: EncoderStackConfig,
    /// Number of output classes
    pub num_classes: usize,
    /// Whether to use class token
    pub use_class_token: bool,
    /// Dropout rate for classification head
    pub classifier_dropout: f64,
}

impl VisionTransformerConfig {
    /// Create new Vision Transformer configuration
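    ///
    /// # Example
    ///
    /// Construction plus the builder methods, mirroring the unit tests below:
    ///
    /// ```rust
    /// use tensorlogic_trustformers::vision::VisionTransformerConfig;
    ///
    /// let config = VisionTransformerConfig::new(224, 16, 3, 768, 12, 3072, 12, 1000)
    ///     .unwrap()
    ///     .with_classifier_dropout(0.1)
    ///     .with_pre_norm(true);
    ///
    /// assert_eq!(config.seq_length(), 197); // 196 patches + 1 class token
    /// assert!(config.validate().is_ok());
    /// ```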
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        image_size: usize,
        patch_size: usize,
        in_channels: usize,
        d_model: usize,
        n_heads: usize,
        d_ff: usize,
        n_layers: usize,
        num_classes: usize,
    ) -> Result<Self> {
        let patch_embed = PatchEmbeddingConfig::new(image_size, patch_size, in_channels, d_model)?;

        // For Vision Transformers, we use learned position encoding by default
        let max_seq_len = patch_embed.num_patches() + 1; // +1 for class token
        let encoder = EncoderStackConfig::new(n_layers, d_model, n_heads, d_ff, max_seq_len)?
            .with_learned_position_encoding();

        Ok(Self {
            patch_embed,
            encoder,
            num_classes,
            use_class_token: true,
            classifier_dropout: 0.0,
        })
    }

    /// Builder: Set whether to use class token
    pub fn with_class_token(mut self, use_class_token: bool) -> Self {
        self.use_class_token = use_class_token;
        self
    }

    /// Builder: Set classifier dropout
    pub fn with_classifier_dropout(mut self, dropout: f64) -> Self {
        self.classifier_dropout = dropout;
        self
    }

    /// Builder: Set learned position encoding
    pub fn with_learned_position_encoding(mut self) -> Self {
        self.encoder = self.encoder.with_learned_position_encoding();
        self
    }

    /// Builder: Set whether to use pre-norm
    pub fn with_pre_norm(mut self, pre_norm: bool) -> Self {
        self.encoder.layer_config = self.encoder.layer_config.with_pre_norm(pre_norm);
        self
    }

    /// Builder: Set dropout (applies to both the encoder and the classification head)
    pub fn with_dropout(mut self, dropout: f64) -> Self {
        self.encoder = self.encoder.with_dropout(dropout);
        self.classifier_dropout = dropout;
        self
    }

    /// Get sequence length (num_patches + optional class token)
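    ///
    /// For ViT-Base/16 at 224×224 this is `196 + 1 = 197` with the class token:
    ///
    /// ```rust
    /// use tensorlogic_trustformers::vision::VisionTransformerConfig;
    ///
    /// let config = VisionTransformerConfig::new(224, 16, 3, 768, 12, 3072, 12, 1000).unwrap();
    /// assert_eq!(config.seq_length(), 197);
    ///
    /// let no_cls = config.with_class_token(false);
    /// assert_eq!(no_cls.seq_length(), 196);
    /// ```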
    pub fn seq_length(&self) -> usize {
        let base = self.patch_embed.num_patches();
        if self.use_class_token {
            base + 1
        } else {
            base
        }
    }

    /// Validate configuration
    pub fn validate(&self) -> Result<()> {
        self.patch_embed.validate()?;
        self.encoder.validate()?;
        if self.num_classes == 0 {
            return Err(TrustformerError::CompilationError(
                "num_classes must be > 0".into(),
            ));
        }
        Ok(())
    }
}

/// Vision Transformer (ViT) model
pub struct VisionTransformer {
    config: VisionTransformerConfig,
    patch_embed: PatchEmbedding,
    #[allow(dead_code)] // Will be used in future complete implementation
    encoder: EncoderStack,
}

impl VisionTransformer {
    /// Create new Vision Transformer
    pub fn new(config: VisionTransformerConfig) -> Result<Self> {
        config.validate()?;

        let patch_embed = PatchEmbedding::new(config.patch_embed.clone())?;
        let encoder = EncoderStack::new(config.encoder.clone())?;

        Ok(Self {
            config,
            patch_embed,
            encoder,
        })
    }

    /// Build complete Vision Transformer einsum graph
    ///
    /// # Graph Inputs
    /// - Tensor 0: Input patches `[batch, num_patches, patch_dim]`
    /// - Tensor 1: Patch embedding weights `[patch_dim, d_model]`
    /// - Tensor 2: Position embeddings
    /// - Additional encoder tensors (weights for each layer)
    ///
    /// # Graph Outputs
    /// - Classification logits `[batch, num_classes]`
    ///
    /// Note: This is a simplified representation: the current graph stops after the
    /// position embeddings are added and returns them as a placeholder output. In a full
    /// implementation, you would provide all encoder layer weights and handle class token
    /// prepending/extraction properly.
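    ///
    /// # Example
    ///
    /// A minimal graph-construction sketch, mirroring `test_vit_graph_building` below:
    ///
    /// ```rust
    /// use tensorlogic_trustformers::vision::{VisionTransformer, VisionTransformerConfig};
    /// use tensorlogic_ir::EinsumGraph;
    ///
    /// let config = VisionTransformerConfig::new(224, 16, 3, 384, 6, 1536, 2, 10).unwrap();
    /// let vit = VisionTransformer::new(config).unwrap();
    ///
    /// let mut graph = EinsumGraph::new();
    /// graph.add_tensor("patches");       // Tensor 0
    /// graph.add_tensor("W_patch_embed"); // Tensor 1
    /// graph.add_tensor("pos_embed");     // Tensor 2
    ///
    /// let outputs = vit.build_vit_graph(&mut graph).unwrap();
    /// assert!(!outputs.is_empty());
    /// ```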
    pub fn build_vit_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        // 1. Patch embedding: convert patches to embeddings
        let patches = self.patch_embed.build_patch_embed_graph(graph)?;

        // 2. Add position embeddings
        // In a full implementation, this would add learned or sinusoidal position embeddings
        // For now, we represent this as an element-wise addition
        let positioned = graph.add_tensor("positioned_embeddings");
        let pos_add_node = EinsumNode::elem_binary(
            "add_pos_embed".to_string(),
            patches,
            2, // Position embedding tensor
            positioned,
        );
        graph.add_node(pos_add_node)?;

        // 3. Pass through transformer encoder
        // Note: The encoder.build_encoder_graph() expects tensor 0 as input
        // We would need to refactor this to properly handle the positioned embeddings
        // For now, this is a placeholder that shows the structure

        // 4. Classification head
        // In a full implementation, this would:
        // - Extract class token (first position) or pool all tokens
        // - Apply linear layer: einsum("bd,dc->bc", class_repr, W_classifier)
        // - Add bias

        // For this simplified version, we just return the positioned embeddings
        // as a placeholder output
        Ok(vec![positioned])
    }

    /// Get configuration
    pub fn config(&self) -> &VisionTransformerConfig {
        &self.config
    }

    /// Count total parameters
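    ///
    /// Sums patch embedding, class token, position embedding, encoder, and classification
    /// head parameters. A loose check only, since the exact total depends on the encoder
    /// layer configuration:
    ///
    /// ```rust
    /// use tensorlogic_trustformers::vision::{VisionTransformer, ViTPreset};
    ///
    /// let vit = VisionTransformer::new(ViTPreset::Base16.config(1000).unwrap()).unwrap();
    /// assert!(vit.count_parameters() > 0); // ViT-Base/16 is roughly 86M parameters
    /// ```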
    pub fn count_parameters(&self) -> usize {
        let mut total = 0;

        // Patch embedding: patch_dim × d_model
        total += self.config.patch_embed.patch_dim() * self.config.patch_embed.d_model;

        // Class token (if used): d_model
        if self.config.use_class_token {
            total += self.config.patch_embed.d_model;
        }

        // Position embeddings: seq_length × d_model
        total += self.config.seq_length() * self.config.patch_embed.d_model;

        // Encoder parameters (all layers)
        let params_per_layer =
            crate::utils::count_encoder_layer_params(&self.config.encoder.layer_config);
        total += params_per_layer * self.config.encoder.num_layers;

        // Final layer norm (if enabled)
        if self.config.encoder.final_layer_norm {
            total +=
                crate::utils::count_layernorm_params(&self.config.encoder.layer_config.layer_norm);
        }

        // Classification head: d_model × num_classes + num_classes (bias)
        total +=
            self.config.patch_embed.d_model * self.config.num_classes + self.config.num_classes;

        total
    }
}

/// Common Vision Transformer presets
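///
/// Preset dimensions (as configured by [`ViTPreset::config`]):
///
/// | Preset       | Patch | d_model | Heads | d_ff | Layers |
/// |--------------|-------|---------|-------|------|--------|
/// | ViT-Tiny/16  | 16    | 192     | 3     | 768  | 12     |
/// | ViT-Small/16 | 16    | 384     | 6     | 1536 | 12     |
/// | ViT-Base/16  | 16    | 768     | 12    | 3072 | 12     |
/// | ViT-Large/16 | 16    | 1024    | 16    | 4096 | 24     |
/// | ViT-Huge/14  | 14    | 1280    | 16    | 5120 | 32     |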
pub enum ViTPreset {
    /// ViT-Tiny/16: 5.7M parameters
    Tiny16,
    /// ViT-Small/16: 22M parameters
    Small16,
    /// ViT-Base/16: 86M parameters
    Base16,
    /// ViT-Large/16: 307M parameters
    Large16,
    /// ViT-Huge/14: 632M parameters
    Huge14,
}

impl ViTPreset {
    /// Create configuration from preset
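    ///
    /// # Example
    ///
    /// Mirroring `test_vit_presets` below:
    ///
    /// ```rust
    /// use tensorlogic_trustformers::vision::ViTPreset;
    ///
    /// let config = ViTPreset::Base16.config(1000).unwrap();
    /// assert_eq!(config.num_classes, 1000);
    /// assert!(config.validate().is_ok());
    /// ```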
    pub fn config(&self, num_classes: usize) -> Result<VisionTransformerConfig> {
        let (image_size, patch_size, d_model, n_heads, d_ff, n_layers) = match self {
            ViTPreset::Tiny16 => (224, 16, 192, 3, 768, 12),
            ViTPreset::Small16 => (224, 16, 384, 6, 1536, 12),
            ViTPreset::Base16 => (224, 16, 768, 12, 3072, 12),
            ViTPreset::Large16 => (224, 16, 1024, 16, 4096, 24),
            ViTPreset::Huge14 => (224, 14, 1280, 16, 5120, 32),
        };

        VisionTransformerConfig::new(
            image_size,
            patch_size,
            3, // RGB
            d_model,
            n_heads,
            d_ff,
            n_layers,
            num_classes,
        )
    }

    /// Get preset name
    pub fn name(&self) -> &'static str {
        match self {
            ViTPreset::Tiny16 => "ViT-Tiny/16",
            ViTPreset::Small16 => "ViT-Small/16",
            ViTPreset::Base16 => "ViT-Base/16",
            ViTPreset::Large16 => "ViT-Large/16",
            ViTPreset::Huge14 => "ViT-Huge/14",
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_patch_embedding_config() {
        let config = PatchEmbeddingConfig::new(224, 16, 3, 768).unwrap();
        assert_eq!(config.image_size, 224);
        assert_eq!(config.patch_size, 16);
        assert_eq!(config.in_channels, 3);
        assert_eq!(config.d_model, 768);
        assert_eq!(config.num_patches(), 196); // (224/16)^2
        assert_eq!(config.patch_dim(), 768); // 16*16*3
    }

    #[test]
    fn test_patch_embedding_invalid_size() {
        let result = PatchEmbeddingConfig::new(225, 16, 3, 768);
        assert!(result.is_err()); // 225 not divisible by 16
    }

    #[test]
    fn test_patch_embedding_graph() {
        let config = PatchEmbeddingConfig::new(224, 16, 3, 768).unwrap();
        let patch_embed = PatchEmbedding::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("image");
        graph.add_tensor("W_patch_embed");

        let output = patch_embed.build_patch_embed_graph(&mut graph).unwrap();
        assert!(output > 0);
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_vit_config_creation() {
        let config = VisionTransformerConfig::new(
            224,  // image_size
            16,   // patch_size
            3,    // in_channels
            768,  // d_model
            12,   // n_heads
            3072, // d_ff
            12,   // n_layers
            1000, // num_classes
        )
        .unwrap();

        assert_eq!(config.num_classes, 1000);
        assert!(config.use_class_token);
        assert_eq!(config.seq_length(), 197); // 196 patches + 1 class token
    }

    #[test]
    fn test_vit_config_without_class_token() {
        let config = VisionTransformerConfig::new(224, 16, 3, 768, 12, 3072, 12, 1000)
            .unwrap()
            .with_class_token(false);

        assert!(!config.use_class_token);
        assert_eq!(config.seq_length(), 196); // Only patches, no class token
    }

    #[test]
    fn test_vit_creation() {
        let config = VisionTransformerConfig::new(224, 16, 3, 768, 12, 3072, 12, 1000).unwrap();
        let vit = VisionTransformer::new(config).unwrap();

        assert!(vit.config().validate().is_ok());
    }

    #[test]
    fn test_vit_graph_building() {
        let config = VisionTransformerConfig::new(224, 16, 3, 384, 6, 1536, 2, 10).unwrap();
        let vit = VisionTransformer::new(config).unwrap();

        let mut graph = EinsumGraph::new();
        // Add required input tensors
        graph.add_tensor("patches"); // Tensor 0
        graph.add_tensor("W_patch_embed"); // Tensor 1
        graph.add_tensor("pos_embed"); // Tensor 2

        let result = vit.build_vit_graph(&mut graph);
        // The graph building should succeed
        assert!(result.is_ok());
        let outputs = result.unwrap();
        assert!(!outputs.is_empty());
    }

    #[test]
    fn test_vit_parameter_count() {
        let config = VisionTransformerConfig::new(224, 16, 3, 768, 12, 3072, 12, 1000).unwrap();
        let vit = VisionTransformer::new(config).unwrap();

        let params = vit.count_parameters();
        assert!(params > 0);
        // ViT-Base/16 should have ~86M parameters
        // We're not checking exact count due to implementation details
    }

    #[test]
    fn test_vit_presets() {
        for preset in [
            ViTPreset::Tiny16,
            ViTPreset::Small16,
            ViTPreset::Base16,
            ViTPreset::Large16,
            ViTPreset::Huge14,
        ] {
            let config = preset.config(1000).unwrap();
            assert!(config.validate().is_ok());
            assert_eq!(config.num_classes, 1000);

            let vit = VisionTransformer::new(config).unwrap();
            assert!(vit.count_parameters() > 0);
        }
    }

    #[test]
    fn test_vit_preset_names() {
        assert_eq!(ViTPreset::Tiny16.name(), "ViT-Tiny/16");
        assert_eq!(ViTPreset::Small16.name(), "ViT-Small/16");
        assert_eq!(ViTPreset::Base16.name(), "ViT-Base/16");
        assert_eq!(ViTPreset::Large16.name(), "ViT-Large/16");
        assert_eq!(ViTPreset::Huge14.name(), "ViT-Huge/14");
    }

    #[test]
    fn test_different_image_sizes() {
        for (image_size, patch_size) in [(224, 16), (384, 16), (512, 32)] {
            let config = PatchEmbeddingConfig::new(image_size, patch_size, 3, 768).unwrap();
            let expected_patches = (image_size / patch_size) * (image_size / patch_size);
            assert_eq!(config.num_patches(), expected_patches);
        }
    }

    #[test]
    fn test_vit_config_builder() {
        let config = VisionTransformerConfig::new(224, 16, 3, 768, 12, 3072, 12, 1000)
            .unwrap()
            .with_class_token(true)
            .with_classifier_dropout(0.1)
            .with_pre_norm(true)
            .with_dropout(0.1);

        assert!(config.use_class_token);
        assert!((config.classifier_dropout - 0.1).abs() < 1e-10);
        assert!(config.encoder.layer_config.pre_norm);
        assert!(config.validate().is_ok());
    }
}