1use crate::error::{Result, TrustformerError};
53use crate::stacks::{EncoderStack, EncoderStackConfig};
54use tensorlogic_ir::{EinsumGraph, EinsumNode};
55
/// Configuration for converting a square input image into a sequence of
/// flattened patch embeddings.
#[derive(Debug, Clone)]
pub struct PatchEmbeddingConfig {
    /// Height/width of the (square) input image, in pixels.
    pub image_size: usize,
    /// Height/width of each (square) patch; must evenly divide `image_size`.
    pub patch_size: usize,
    /// Number of input channels (e.g. 3 for RGB images).
    pub in_channels: usize,
    /// Embedding dimension each flattened patch is projected to.
    pub d_model: usize,
}
68
69impl PatchEmbeddingConfig {
70 pub fn new(
72 image_size: usize,
73 patch_size: usize,
74 in_channels: usize,
75 d_model: usize,
76 ) -> Result<Self> {
77 if image_size == 0 {
78 return Err(TrustformerError::CompilationError(
79 "image_size must be > 0".into(),
80 ));
81 }
82 if patch_size == 0 {
83 return Err(TrustformerError::CompilationError(
84 "patch_size must be > 0".into(),
85 ));
86 }
87 if !image_size.is_multiple_of(patch_size) {
88 return Err(TrustformerError::CompilationError(format!(
89 "image_size ({}) must be divisible by patch_size ({})",
90 image_size, patch_size
91 )));
92 }
93 if in_channels == 0 {
94 return Err(TrustformerError::CompilationError(
95 "in_channels must be > 0".into(),
96 ));
97 }
98 if d_model == 0 {
99 return Err(TrustformerError::CompilationError(
100 "d_model must be > 0".into(),
101 ));
102 }
103
104 Ok(Self {
105 image_size,
106 patch_size,
107 in_channels,
108 d_model,
109 })
110 }
111
112 pub fn num_patches(&self) -> usize {
114 let patches_per_side = self.image_size / self.patch_size;
115 patches_per_side * patches_per_side
116 }
117
118 pub fn patch_dim(&self) -> usize {
120 self.patch_size * self.patch_size * self.in_channels
121 }
122
123 pub fn validate(&self) -> Result<()> {
125 if !self.image_size.is_multiple_of(self.patch_size) {
126 return Err(TrustformerError::CompilationError(
127 "image_size must be divisible by patch_size".into(),
128 ));
129 }
130 Ok(())
131 }
132}
133
/// Builds the patch-embedding portion of an einsum graph from a validated
/// [`PatchEmbeddingConfig`].
pub struct PatchEmbedding {
    // Invariant: validated in `PatchEmbedding::new` before being stored.
    config: PatchEmbeddingConfig,
}
138
139impl PatchEmbedding {
140 pub fn new(config: PatchEmbeddingConfig) -> Result<Self> {
142 config.validate()?;
143 Ok(Self { config })
144 }
145
146 pub fn build_patch_embed_graph(&self, graph: &mut EinsumGraph) -> Result<usize> {
157 let output_tensor = graph.add_tensor("patch_embeddings");
171 let node = EinsumNode::new("bnp,pd->bnd", vec![0, 1], vec![output_tensor]);
172 graph.add_node(node)?;
173
174 Ok(output_tensor)
175 }
176
177 pub fn config(&self) -> &PatchEmbeddingConfig {
179 &self.config
180 }
181}
182
/// Complete configuration for a Vision Transformer (ViT) model.
#[derive(Debug, Clone)]
pub struct VisionTransformerConfig {
    /// How the input image is split and projected into patch embeddings.
    pub patch_embed: PatchEmbeddingConfig,
    /// Configuration of the transformer encoder stack.
    pub encoder: EncoderStackConfig,
    /// Number of output classes for the classification head; must be > 0.
    pub num_classes: usize,
    /// Whether an extra class-token position is included in the sequence
    /// (adds 1 to `seq_length`).
    pub use_class_token: bool,
    /// Dropout rate for the classifier head (0.0 disables it).
    pub classifier_dropout: f64,
}
197
198impl VisionTransformerConfig {
199 #[allow(clippy::too_many_arguments)]
201 pub fn new(
202 image_size: usize,
203 patch_size: usize,
204 in_channels: usize,
205 d_model: usize,
206 n_heads: usize,
207 d_ff: usize,
208 n_layers: usize,
209 num_classes: usize,
210 ) -> Result<Self> {
211 let patch_embed = PatchEmbeddingConfig::new(image_size, patch_size, in_channels, d_model)?;
212
213 let max_seq_len = patch_embed.num_patches() + 1; let encoder = EncoderStackConfig::new(n_layers, d_model, n_heads, d_ff, max_seq_len)?
216 .with_learned_position_encoding();
217
218 Ok(Self {
219 patch_embed,
220 encoder,
221 num_classes,
222 use_class_token: true,
223 classifier_dropout: 0.0,
224 })
225 }
226
227 pub fn with_class_token(mut self, use_class_token: bool) -> Self {
229 self.use_class_token = use_class_token;
230 self
231 }
232
233 pub fn with_classifier_dropout(mut self, dropout: f64) -> Self {
235 self.classifier_dropout = dropout;
236 self
237 }
238
239 pub fn with_learned_position_encoding(mut self) -> Self {
241 self.encoder = self.encoder.with_learned_position_encoding();
242 self
243 }
244
245 pub fn with_pre_norm(mut self, pre_norm: bool) -> Self {
247 self.encoder.layer_config = self.encoder.layer_config.with_pre_norm(pre_norm);
248 self
249 }
250
251 pub fn with_dropout(mut self, dropout: f64) -> Self {
253 self.encoder = self.encoder.with_dropout(dropout);
254 self.classifier_dropout = dropout;
255 self
256 }
257
258 pub fn seq_length(&self) -> usize {
260 let base = self.patch_embed.num_patches();
261 if self.use_class_token {
262 base + 1
263 } else {
264 base
265 }
266 }
267
268 pub fn validate(&self) -> Result<()> {
270 self.patch_embed.validate()?;
271 self.encoder.validate()?;
272 if self.num_classes == 0 {
273 return Err(TrustformerError::CompilationError(
274 "num_classes must be > 0".into(),
275 ));
276 }
277 Ok(())
278 }
279}
280
/// A Vision Transformer model assembled from a patch embedder and a
/// transformer encoder stack.
pub struct VisionTransformer {
    // Validated at construction time in `VisionTransformer::new`.
    config: VisionTransformerConfig,
    // Front end that projects image patches into `d_model`-sized embeddings.
    patch_embed: PatchEmbedding,
    // Constructed but not yet referenced by graph building — see
    // `build_vit_graph`, which stops after adding position embeddings.
    #[allow(dead_code)] encoder: EncoderStack,
}
288
289impl VisionTransformer {
290 pub fn new(config: VisionTransformerConfig) -> Result<Self> {
292 config.validate()?;
293
294 let patch_embed = PatchEmbedding::new(config.patch_embed.clone())?;
295 let encoder = EncoderStack::new(config.encoder.clone())?;
296
297 Ok(Self {
298 config,
299 patch_embed,
300 encoder,
301 })
302 }
303
304 pub fn build_vit_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
317 let patches = self.patch_embed.build_patch_embed_graph(graph)?;
319
320 let positioned = graph.add_tensor("positioned_embeddings");
324 let pos_add_node = EinsumNode::elem_binary(
325 "add_pos_embed".to_string(),
326 patches,
327 2, positioned,
329 );
330 graph.add_node(pos_add_node)?;
331
332 Ok(vec![positioned])
346 }
347
348 pub fn config(&self) -> &VisionTransformerConfig {
350 &self.config
351 }
352
353 pub fn count_parameters(&self) -> usize {
355 let mut total = 0;
356
357 total += self.config.patch_embed.patch_dim() * self.config.patch_embed.d_model;
359
360 if self.config.use_class_token {
362 total += self.config.patch_embed.d_model;
363 }
364
365 total += self.config.seq_length() * self.config.patch_embed.d_model;
367
368 let params_per_layer =
370 crate::utils::count_encoder_layer_params(&self.config.encoder.layer_config);
371 total += params_per_layer * self.config.encoder.num_layers;
372
373 if self.config.encoder.final_layer_norm {
375 total +=
376 crate::utils::count_layernorm_params(&self.config.encoder.layer_config.layer_norm);
377 }
378
379 total +=
381 self.config.patch_embed.d_model * self.config.num_classes + self.config.num_classes;
382
383 total
384 }
385}
386
/// Named Vision Transformer size presets; the numeric suffix is the patch
/// size used by `config()`.
pub enum ViTPreset {
    /// ViT-Tiny with 16x16 patches.
    Tiny16,
    /// ViT-Small with 16x16 patches.
    Small16,
    /// ViT-Base with 16x16 patches.
    Base16,
    /// ViT-Large with 16x16 patches.
    Large16,
    /// ViT-Huge with 14x14 patches.
    Huge14,
}
400
401impl ViTPreset {
402 pub fn config(&self, num_classes: usize) -> Result<VisionTransformerConfig> {
404 let (image_size, patch_size, d_model, n_heads, d_ff, n_layers) = match self {
405 ViTPreset::Tiny16 => (224, 16, 192, 3, 768, 12),
406 ViTPreset::Small16 => (224, 16, 384, 6, 1536, 12),
407 ViTPreset::Base16 => (224, 16, 768, 12, 3072, 12),
408 ViTPreset::Large16 => (224, 16, 1024, 16, 4096, 24),
409 ViTPreset::Huge14 => (224, 14, 1280, 16, 5120, 32),
410 };
411
412 VisionTransformerConfig::new(
413 image_size,
414 patch_size,
415 3, d_model,
417 n_heads,
418 d_ff,
419 n_layers,
420 num_classes,
421 )
422 }
423
424 pub fn name(&self) -> &'static str {
426 match self {
427 ViTPreset::Tiny16 => "ViT-Tiny/16",
428 ViTPreset::Small16 => "ViT-Small/16",
429 ViTPreset::Base16 => "ViT-Base/16",
430 ViTPreset::Large16 => "ViT-Large/16",
431 ViTPreset::Huge14 => "ViT-Huge/14",
432 }
433 }
434}
435
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_patch_embedding_config() {
        let cfg = PatchEmbeddingConfig::new(224, 16, 3, 768).unwrap();
        assert_eq!(cfg.image_size, 224);
        assert_eq!(cfg.patch_size, 16);
        assert_eq!(cfg.in_channels, 3);
        assert_eq!(cfg.d_model, 768);
        // 224 / 16 = 14 patches per side -> 14 * 14 total.
        assert_eq!(cfg.num_patches(), 196);
        // 16 * 16 * 3 values per flattened patch.
        assert_eq!(cfg.patch_dim(), 768);
    }

    #[test]
    fn test_patch_embedding_invalid_size() {
        // 225 is not divisible by 16.
        assert!(PatchEmbeddingConfig::new(225, 16, 3, 768).is_err());
    }

    #[test]
    fn test_patch_embedding_graph() {
        let cfg = PatchEmbeddingConfig::new(224, 16, 3, 768).unwrap();
        let patch_embed = PatchEmbedding::new(cfg).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("image");
        graph.add_tensor("W_patch_embed");

        let output = patch_embed.build_patch_embed_graph(&mut graph).unwrap();
        assert!(output > 0);
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_vit_config_creation() {
        let cfg =
            VisionTransformerConfig::new(224, 16, 3, 768, 12, 3072, 12, 1000).unwrap();

        assert_eq!(cfg.num_classes, 1000);
        assert!(cfg.use_class_token);
        // 196 patches + 1 class token.
        assert_eq!(cfg.seq_length(), 197);
    }

    #[test]
    fn test_vit_config_without_class_token() {
        let cfg = VisionTransformerConfig::new(224, 16, 3, 768, 12, 3072, 12, 1000)
            .unwrap()
            .with_class_token(false);

        assert!(!cfg.use_class_token);
        // Patches only, no class-token slot.
        assert_eq!(cfg.seq_length(), 196);
    }

    #[test]
    fn test_vit_creation() {
        let cfg = VisionTransformerConfig::new(224, 16, 3, 768, 12, 3072, 12, 1000).unwrap();
        let vit = VisionTransformer::new(cfg).unwrap();

        assert!(vit.config().validate().is_ok());
    }

    #[test]
    fn test_vit_graph_building() {
        let cfg = VisionTransformerConfig::new(224, 16, 3, 384, 6, 1536, 2, 10).unwrap();
        let vit = VisionTransformer::new(cfg).unwrap();

        let mut graph = EinsumGraph::new();
        graph.add_tensor("patches");
        graph.add_tensor("W_patch_embed");
        graph.add_tensor("pos_embed");

        let result = vit.build_vit_graph(&mut graph);
        assert!(result.is_ok());
        assert!(!result.unwrap().is_empty());
    }

    #[test]
    fn test_vit_parameter_count() {
        let cfg = VisionTransformerConfig::new(224, 16, 3, 768, 12, 3072, 12, 1000).unwrap();
        let vit = VisionTransformer::new(cfg).unwrap();

        assert!(vit.count_parameters() > 0);
    }

    #[test]
    fn test_vit_presets() {
        let presets = [
            ViTPreset::Tiny16,
            ViTPreset::Small16,
            ViTPreset::Base16,
            ViTPreset::Large16,
            ViTPreset::Huge14,
        ];
        for preset in presets {
            let cfg = preset.config(1000).unwrap();
            assert!(cfg.validate().is_ok());
            assert_eq!(cfg.num_classes, 1000);

            let vit = VisionTransformer::new(cfg).unwrap();
            assert!(vit.count_parameters() > 0);
        }
    }

    #[test]
    fn test_vit_preset_names() {
        assert_eq!(ViTPreset::Tiny16.name(), "ViT-Tiny/16");
        assert_eq!(ViTPreset::Small16.name(), "ViT-Small/16");
        assert_eq!(ViTPreset::Base16.name(), "ViT-Base/16");
        assert_eq!(ViTPreset::Large16.name(), "ViT-Large/16");
        assert_eq!(ViTPreset::Huge14.name(), "ViT-Huge/14");
    }

    #[test]
    fn test_different_image_sizes() {
        for (image_size, patch_size) in [(224, 16), (384, 16), (512, 32)] {
            let cfg = PatchEmbeddingConfig::new(image_size, patch_size, 3, 768).unwrap();
            let per_side = image_size / patch_size;
            assert_eq!(cfg.num_patches(), per_side * per_side);
        }
    }

    #[test]
    fn test_vit_config_builder() {
        let cfg = VisionTransformerConfig::new(224, 16, 3, 768, 12, 3072, 12, 1000)
            .unwrap()
            .with_class_token(true)
            .with_classifier_dropout(0.1)
            .with_pre_norm(true)
            .with_dropout(0.1);

        assert!(cfg.use_class_token);
        assert!((cfg.classifier_dropout - 0.1).abs() < 1e-10);
        assert!(cfg.encoder.layer_config.pre_norm);
        assert!(cfg.validate().is_ok());
    }
}