1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
/// Common interface implemented by all model configurations in this module
/// (vision and NLP alike).
pub trait ModelConfig: Clone + Default {
    /// Architecture family name (the impls here return the `Debug` form of
    /// the architecture enum, e.g. "ResNet" or "BERT").
    fn model_name(&self) -> String;

    /// Short variant identifier derived from the hyperparameters,
    /// e.g. "resnet18" or "bert_base".
    fn variant(&self) -> String;

    /// Checks internal consistency; `Err` carries a human-readable message.
    fn validate(&self) -> Result<(), String>;

    /// Rough, heuristic estimate of the trainable parameter count.
    fn estimated_parameters(&self) -> u64;
}
20
/// Configuration for an image-classification model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisionModelConfig {
    /// Which architecture family this configuration describes.
    pub architecture: VisionArchitecture,
    /// Input resolution; only used symmetrically in this file, so the
    /// (height, width) vs (width, height) order is unconfirmed — TODO confirm.
    pub input_size: (usize, usize),
    /// Number of input channels (3 in the default, i.e. RGB).
    pub in_channels: usize,
    /// Size of the classification output.
    pub num_classes: usize,
    /// Architecture-specific hyperparameters.
    pub arch_params: VisionArchParams,
    /// Shared training hyperparameters.
    pub training: TrainingConfig,
}
37
/// Supported vision architecture families.
///
/// NOTE(review): `ConvNeXt` and `Swin` have no corresponding variant in
/// `VisionArchParams`, so they currently cannot carry hyperparameters.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum VisionArchitecture {
    ResNet,
    EfficientNet,
    VisionTransformer,
    MobileNet,
    DenseNet,
    ConvNeXt,
    Swin,
}
49
/// Architecture-specific hyperparameter bundle for vision models.
/// Covers only a subset of the `VisionArchitecture` variants.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum VisionArchParams {
    ResNet(ResNetConfig),
    EfficientNet(EfficientNetConfig),
    VisionTransformer(ViTConfig),
    MobileNet(MobileNetConfig),
    DenseNet(DenseNetConfig),
}
59
/// ResNet hyperparameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResNetConfig {
    /// Number of residual blocks in each of the four stages.
    pub layers: [usize; 4],
    /// Use bottleneck blocks (as in ResNet-50+) instead of basic blocks.
    pub bottleneck: bool,
    /// Channel width multiplier.
    pub width_mult: f32,
    /// Stem (input convolution) style.
    pub stem_type: StemType,
}
72
/// EfficientNet compound-scaling hyperparameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EfficientNetConfig {
    /// Channel width multiplier.
    pub width_mult: f32,
    /// Layer depth multiplier.
    pub depth_mult: f32,
    /// Input resolution multiplier (applied to a 224px base in
    /// `ModelConfigs::efficientnet_configs`).
    pub resolution_mult: f32,
    /// Dropout rate.
    pub dropout_rate: f32,
    /// Squeeze-and-excitation ratio.
    pub se_ratio: f32,
}
87
/// Vision Transformer (ViT) hyperparameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ViTConfig {
    /// Side length of the square image patches; the input size must be
    /// divisible by this (enforced in `VisionModelConfig::validate`).
    pub patch_size: usize,
    /// Token embedding dimension; must be divisible by `num_heads`.
    pub embed_dim: usize,
    /// Number of transformer encoder blocks.
    pub depth: usize,
    /// Number of attention heads per block.
    pub num_heads: usize,
    /// MLP hidden width as a multiple of `embed_dim`.
    pub mlp_ratio: f32,
    /// Dropout rate.
    pub dropout_rate: f32,
    /// Attention dropout rate.
    pub attn_dropout_rate: f32,
    /// Positional-embedding scheme.
    pub pos_embed_type: PositionEmbedType,
}
108
/// MobileNet hyperparameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MobileNetConfig {
    /// MobileNet generation (V1/V2/V3).
    pub version: MobileNetVersion,
    /// Channel width multiplier.
    pub width_mult: f32,
    /// Minimum channel count — presumably a floor applied after width
    /// scaling; not used elsewhere in this file, so verify against the
    /// model builder.
    pub min_ch: usize,
    /// Dropout rate.
    pub dropout_rate: f32,
}
121
/// DenseNet hyperparameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DenseNetConfig {
    /// Channels added per dense layer.
    pub growth_rate: usize,
    /// Number of dense layers in each dense block.
    pub block_config: Vec<usize>,
    /// Channels produced by the initial convolution.
    pub num_init_features: usize,
    /// Bottleneck width factor.
    pub bn_size: usize,
    /// Transition-layer compression factor.
    pub compression: f32,
}
136
/// Configuration for an NLP (text) model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NlpModelConfig {
    /// Which architecture family this configuration describes.
    pub architecture: NlpArchitecture,
    /// Tokenizer vocabulary size.
    pub vocab_size: usize,
    /// Maximum sequence length (also sized into the position-embedding
    /// estimate in `estimated_parameters`).
    pub max_length: usize,
    /// Architecture-specific hyperparameters.
    pub arch_params: NlpArchParams,
    /// Shared training hyperparameters.
    pub training: TrainingConfig,
}
151
/// Supported NLP architecture families.
///
/// NOTE(review): `ELECTRA` and `DeBERTa` have no corresponding variant in
/// `NlpArchParams`, so they currently cannot carry hyperparameters.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum NlpArchitecture {
    BERT,
    GPT,
    T5,
    RoBERTa,
    BART,
    ELECTRA,
    DeBERTa,
}
163
/// Architecture-specific hyperparameter bundle for NLP models.
/// Covers only a subset of the `NlpArchitecture` variants.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NlpArchParams {
    BERT(BERTConfig),
    GPT(GPTConfig),
    T5(T5Config),
    RoBERTa(RoBERTaConfig),
    BART(BARTConfig),
}
173
/// BERT hyperparameters (field names follow the Hugging Face convention).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BERTConfig {
    /// Hidden (embedding) dimension; must be divisible by
    /// `num_attention_heads` (enforced in `NlpModelConfig::validate`).
    pub hidden_size: usize,
    /// Number of transformer encoder layers.
    pub num_hidden_layers: usize,
    /// Attention heads per layer.
    pub num_attention_heads: usize,
    /// Feed-forward (intermediate) width.
    pub intermediate_size: usize,
    /// Activation function name, e.g. "gelu".
    pub hidden_act: String,
    /// Dropout rate on hidden states.
    pub hidden_dropout_prob: f32,
    /// Dropout rate on attention probabilities.
    pub attention_probs_dropout_prob: f32,
    /// LayerNorm epsilon.
    pub layer_norm_eps: f32,
    /// Use absolute position embeddings.
    pub use_absolute_pos: bool,
}
196
/// GPT hyperparameters (field names follow the GPT-2 convention).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GPTConfig {
    /// Embedding dimension; must be divisible by `n_head`
    /// (enforced in `NlpModelConfig::validate`).
    pub n_embd: usize,
    /// Number of transformer layers.
    pub n_layer: usize,
    /// Attention heads per layer.
    pub n_head: usize,
    /// Context window length.
    pub n_ctx: usize,
    /// Residual dropout rate.
    pub resid_pdrop: f32,
    /// Embedding dropout rate.
    pub embd_pdrop: f32,
    /// Attention dropout rate.
    pub attn_pdrop: f32,
    /// Include bias terms.
    pub use_bias: bool,
}
217
/// T5 hyperparameters (field names follow the Hugging Face convention).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct T5Config {
    /// Model (embedding) dimension.
    pub d_model: usize,
    /// Per-head key/value dimension.
    pub d_kv: usize,
    /// Feed-forward width.
    pub d_ff: usize,
    /// Layers per stack (applied to both encoder and decoder in
    /// `estimated_parameters`).
    pub num_layers: usize,
    /// Attention heads per layer.
    pub num_heads: usize,
    /// Bucket count for T5's relative position biases.
    pub relative_attention_num_buckets: usize,
    /// Dropout rate.
    pub dropout_rate: f32,
    /// LayerNorm epsilon.
    pub layer_norm_epsilon: f32,
}
238
/// RoBERTa hyperparameters: a BERT configuration plus RoBERTa-specific flags.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoBERTaConfig {
    /// Underlying BERT-style transformer configuration.
    pub bert_config: BERTConfig,
    /// Switch to an alternate LayerNorm placement — semantics not visible
    /// in this file; verify against the model builder.
    pub use_alternate_layernorm: bool,
}
247
/// BART encoder-decoder hyperparameters.
///
/// NOTE(review): unlike the other NLP configs, BART carries its own
/// `vocab_size`, which `estimated_parameters` uses instead of
/// `NlpModelConfig::vocab_size` — the two can disagree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BARTConfig {
    /// Tokenizer vocabulary size (BART's default 50265 differs from BERT's).
    pub vocab_size: usize,
    /// Model (embedding) dimension.
    pub d_model: usize,
    /// Encoder layer count.
    pub encoder_layers: usize,
    /// Decoder layer count.
    pub decoder_layers: usize,
    /// Attention heads per encoder layer.
    pub encoder_attention_heads: usize,
    /// Attention heads per decoder layer.
    pub decoder_attention_heads: usize,
    /// Encoder feed-forward width.
    pub encoder_ffn_dim: usize,
    /// Decoder feed-forward width.
    pub decoder_ffn_dim: usize,
    /// Dropout rate.
    pub dropout: f32,
    /// Attention dropout rate.
    pub attention_dropout: f32,
    /// Activation function name, e.g. "gelu".
    pub activation_function: String,
}
273
/// Training hyperparameters shared by vision and NLP configurations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Initial learning rate.
    pub learning_rate: f32,
    /// Mini-batch size.
    pub batch_size: usize,
    /// Number of training epochs.
    pub epochs: usize,
    /// Weight-decay coefficient.
    pub weight_decay: f32,
    /// Optimizer and its hyperparameters.
    pub optimizer: OptimizerType,
    /// Learning-rate schedule.
    pub lr_scheduler: LRSchedulerType,
    /// Gradient-clipping threshold; `None` disables clipping.
    pub max_grad_norm: Option<f32>,
    /// Warmup step count; `None` disables warmup.
    pub warmup_steps: Option<usize>,
}
294
/// Optimizer choice, each variant carrying its own hyperparameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OptimizerType {
    SGD { momentum: f32, nesterov: bool },
    Adam { beta1: f32, beta2: f32, eps: f32 },
    AdamW { beta1: f32, beta2: f32, eps: f32 },
    RMSprop { alpha: f32, eps: f32 },
}
303
/// Learning-rate schedule choice, each variant carrying its own parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LRSchedulerType {
    /// Fixed learning rate.
    Constant,
    /// Linear decay over `total_steps`.
    Linear { total_steps: usize },
    /// Cosine decay over `total_steps` down to `min_lr`.
    Cosine { total_steps: usize, min_lr: f32 },
    /// Multiply by `gamma` every `step_size` steps.
    StepLR { step_size: usize, gamma: f32 },
    /// Multiply by `gamma` every step.
    ExponentialLR { gamma: f32 },
}
313
/// Input-stem style for convolutional networks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StemType {
    Standard,
    Deep,
    Patch,
}
324
/// Positional-embedding scheme for transformer models.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PositionEmbedType {
    Learned,
    Sinusoidal,
    Relative,
    Rotary,
}
333
/// MobileNet generation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MobileNetVersion {
    V1,
    V2,
    V3Small,
    V3Large,
}
342
343impl Default for VisionModelConfig {
344 fn default() -> Self {
345 Self {
346 architecture: VisionArchitecture::ResNet,
347 input_size: (224, 224),
348 in_channels: 3,
349 num_classes: 1000,
350 arch_params: VisionArchParams::ResNet(ResNetConfig::default()),
351 training: TrainingConfig::default(),
352 }
353 }
354}
355
356impl Default for ResNetConfig {
357 fn default() -> Self {
358 Self {
359 layers: [2, 2, 2, 2], bottleneck: false,
361 width_mult: 1.0,
362 stem_type: StemType::Standard,
363 }
364 }
365}
366
367impl Default for EfficientNetConfig {
368 fn default() -> Self {
369 Self {
370 width_mult: 1.0,
371 depth_mult: 1.0,
372 resolution_mult: 1.0,
373 dropout_rate: 0.2,
374 se_ratio: 0.25,
375 }
376 }
377}
378
379impl Default for ViTConfig {
380 fn default() -> Self {
381 Self {
382 patch_size: 16,
383 embed_dim: 768,
384 depth: 12,
385 num_heads: 12,
386 mlp_ratio: 4.0,
387 dropout_rate: 0.1,
388 attn_dropout_rate: 0.0,
389 pos_embed_type: PositionEmbedType::Learned,
390 }
391 }
392}
393
394impl Default for MobileNetConfig {
395 fn default() -> Self {
396 Self {
397 version: MobileNetVersion::V2,
398 width_mult: 1.0,
399 min_ch: 8,
400 dropout_rate: 0.2,
401 }
402 }
403}
404
405impl Default for DenseNetConfig {
406 fn default() -> Self {
407 Self {
408 growth_rate: 32,
409 block_config: vec![6, 12, 24, 16], num_init_features: 64,
411 bn_size: 4,
412 compression: 0.5,
413 }
414 }
415}
416
417impl Default for NlpModelConfig {
418 fn default() -> Self {
419 Self {
420 architecture: NlpArchitecture::BERT,
421 vocab_size: 30522,
422 max_length: 512,
423 arch_params: NlpArchParams::BERT(BERTConfig::default()),
424 training: TrainingConfig::default(),
425 }
426 }
427}
428
429impl Default for BERTConfig {
430 fn default() -> Self {
431 Self {
432 hidden_size: 768,
433 num_hidden_layers: 12,
434 num_attention_heads: 12,
435 intermediate_size: 3072,
436 hidden_act: "gelu".to_string(),
437 hidden_dropout_prob: 0.1,
438 attention_probs_dropout_prob: 0.1,
439 layer_norm_eps: 1e-12,
440 use_absolute_pos: true,
441 }
442 }
443}
444
445impl Default for GPTConfig {
446 fn default() -> Self {
447 Self {
448 n_embd: 768,
449 n_layer: 12,
450 n_head: 12,
451 n_ctx: 1024,
452 resid_pdrop: 0.1,
453 embd_pdrop: 0.1,
454 attn_pdrop: 0.1,
455 use_bias: true,
456 }
457 }
458}
459
460impl Default for T5Config {
461 fn default() -> Self {
462 Self {
463 d_model: 512,
464 d_kv: 64,
465 d_ff: 2048,
466 num_layers: 6,
467 num_heads: 8,
468 relative_attention_num_buckets: 32,
469 dropout_rate: 0.1,
470 layer_norm_epsilon: 1e-6,
471 }
472 }
473}
474
475impl Default for RoBERTaConfig {
476 fn default() -> Self {
477 Self {
478 bert_config: BERTConfig::default(),
479 use_alternate_layernorm: false,
480 }
481 }
482}
483
484impl Default for BARTConfig {
485 fn default() -> Self {
486 Self {
487 vocab_size: 50265,
488 d_model: 768,
489 encoder_layers: 6,
490 decoder_layers: 6,
491 encoder_attention_heads: 12,
492 decoder_attention_heads: 12,
493 encoder_ffn_dim: 3072,
494 decoder_ffn_dim: 3072,
495 dropout: 0.1,
496 attention_dropout: 0.0,
497 activation_function: "gelu".to_string(),
498 }
499 }
500}
501
502impl Default for TrainingConfig {
503 fn default() -> Self {
504 Self {
505 learning_rate: 1e-4,
506 batch_size: 32,
507 epochs: 10,
508 weight_decay: 0.01,
509 optimizer: OptimizerType::AdamW {
510 beta1: 0.9,
511 beta2: 0.999,
512 eps: 1e-8,
513 },
514 lr_scheduler: LRSchedulerType::Constant,
515 max_grad_norm: Some(1.0),
516 warmup_steps: None,
517 }
518 }
519}
520
521impl ModelConfig for VisionModelConfig {
522 fn model_name(&self) -> String {
523 format!("{:?}", self.architecture)
524 }
525
526 fn variant(&self) -> String {
527 match &self.arch_params {
528 VisionArchParams::ResNet(config) => {
529 let total_layers: usize = config.layers.iter().sum::<usize>() * 2 + 2;
530 format!("resnet{}", total_layers)
531 }
532 VisionArchParams::EfficientNet(config) => {
533 if config.width_mult == 1.0 && config.depth_mult == 1.0 {
534 "efficientnet_b0".to_string()
535 } else {
536 format!(
537 "efficientnet_w{:.1}_d{:.1}",
538 config.width_mult, config.depth_mult
539 )
540 }
541 }
542 VisionArchParams::VisionTransformer(config) => {
543 if config.embed_dim == 768 {
544 format!("vit_base_patch{}", config.patch_size)
545 } else if config.embed_dim == 1024 {
546 format!("vit_large_patch{}", config.patch_size)
547 } else {
548 format!("vit_{}d_patch{}", config.embed_dim, config.patch_size)
549 }
550 }
551 VisionArchParams::MobileNet(config) => {
552 format!("mobilenet_{:?}_w{:.1}", config.version, config.width_mult)
553 }
554 VisionArchParams::DenseNet(config) => {
555 let total_layers: usize = config.block_config.iter().sum::<usize>() * 2 + 1;
556 format!("densenet{}", total_layers)
557 }
558 }
559 }
560
561 fn validate(&self) -> Result<(), String> {
562 if self.input_size.0 == 0 || self.input_size.1 == 0 {
563 return Err("Input size must be positive".to_string());
564 }
565 if self.in_channels == 0 {
566 return Err("Input channels must be positive".to_string());
567 }
568 if self.num_classes == 0 {
569 return Err("Number of classes must be positive".to_string());
570 }
571
572 match &self.arch_params {
573 VisionArchParams::VisionTransformer(config) => {
574 if self.input_size.0 % config.patch_size != 0
575 || self.input_size.1 % config.patch_size != 0
576 {
577 return Err("Image size must be divisible by patch size".to_string());
578 }
579 if config.embed_dim % config.num_heads != 0 {
580 return Err(
581 "Embedding dimension must be divisible by number of heads".to_string()
582 );
583 }
584 }
585 _ => {}
586 }
587
588 Ok(())
589 }
590
591 fn estimated_parameters(&self) -> u64 {
592 match &self.arch_params {
593 VisionArchParams::ResNet(config) => {
594 let base_params = if config.bottleneck {
595 25_000_000
596 } else {
597 11_000_000
598 };
599 (base_params as f32 * config.width_mult * config.width_mult) as u64
600 }
601 VisionArchParams::EfficientNet(config) => {
602 let base_params = 5_300_000u64;
603 (base_params as f32 * config.width_mult * config.width_mult * config.depth_mult)
604 as u64
605 }
606 VisionArchParams::VisionTransformer(config) => {
607 let patch_embed_params =
608 self.in_channels * config.embed_dim * config.patch_size * config.patch_size;
609 let transformer_params = config.depth
610 * (4 * config.embed_dim * config.embed_dim
611 + 4 * config.embed_dim * (config.embed_dim * config.mlp_ratio as usize));
612 let head_params = config.embed_dim * self.num_classes;
613 (patch_embed_params + transformer_params + head_params) as u64
614 }
615 VisionArchParams::MobileNet(_) => 3_500_000,
616 VisionArchParams::DenseNet(config) => {
617 let base_params = config.growth_rate as u64
618 * config.block_config.iter().sum::<usize>() as u64
619 * 1000;
620 base_params
621 }
622 }
623 }
624}
625
626impl ModelConfig for NlpModelConfig {
627 fn model_name(&self) -> String {
628 format!("{:?}", self.architecture)
629 }
630
631 fn variant(&self) -> String {
632 match &self.arch_params {
633 NlpArchParams::BERT(config) => {
634 if config.hidden_size == 768 {
635 "bert_base".to_string()
636 } else if config.hidden_size == 1024 {
637 "bert_large".to_string()
638 } else {
639 format!("bert_{}h_{}l", config.hidden_size, config.num_hidden_layers)
640 }
641 }
642 NlpArchParams::GPT(config) => {
643 if config.n_embd == 768 {
644 "gpt_base".to_string()
645 } else {
646 format!("gpt_{}d_{}l", config.n_embd, config.n_layer)
647 }
648 }
649 NlpArchParams::T5(config) => {
650 if config.d_model == 512 {
651 "t5_small".to_string()
652 } else if config.d_model == 768 {
653 "t5_base".to_string()
654 } else {
655 format!("t5_{}d_{}l", config.d_model, config.num_layers)
656 }
657 }
658 NlpArchParams::RoBERTa(config) => {
659 if config.bert_config.hidden_size == 768 {
660 "roberta_base".to_string()
661 } else {
662 "roberta_large".to_string()
663 }
664 }
665 NlpArchParams::BART(config) => {
666 if config.d_model == 768 {
667 "bart_base".to_string()
668 } else {
669 format!("bart_{}d", config.d_model)
670 }
671 }
672 }
673 }
674
675 fn validate(&self) -> Result<(), String> {
676 if self.vocab_size == 0 {
677 return Err("Vocabulary size must be positive".to_string());
678 }
679 if self.max_length == 0 {
680 return Err("Maximum length must be positive".to_string());
681 }
682
683 match &self.arch_params {
684 NlpArchParams::BERT(config) => {
685 if config.hidden_size % config.num_attention_heads != 0 {
686 return Err(
687 "Hidden size must be divisible by number of attention heads".to_string()
688 );
689 }
690 }
691 NlpArchParams::GPT(config) => {
692 if config.n_embd % config.n_head != 0 {
693 return Err("Embedding size must be divisible by number of heads".to_string());
694 }
695 }
696 _ => {}
697 }
698
699 Ok(())
700 }
701
702 fn estimated_parameters(&self) -> u64 {
703 match &self.arch_params {
704 NlpArchParams::BERT(config) => {
705 let embedding_params =
706 self.vocab_size * config.hidden_size + self.max_length * config.hidden_size;
707 let transformer_params = config.num_hidden_layers
708 * (4 * config.hidden_size * config.hidden_size
709 + config.intermediate_size * config.hidden_size);
710 (embedding_params + transformer_params) as u64
711 }
712 NlpArchParams::GPT(config) => {
713 let embedding_params =
714 self.vocab_size * config.n_embd + config.n_ctx * config.n_embd;
715 let transformer_params = config.n_layer * (4 * config.n_embd * config.n_embd);
716 (embedding_params + transformer_params) as u64
717 }
718 NlpArchParams::T5(config) => {
719 let embedding_params = self.vocab_size * config.d_model;
720 let encoder_params = config.num_layers
721 * (4 * config.d_model * config.d_model + config.d_ff * config.d_model);
722 let decoder_params = config.num_layers
723 * (4 * config.d_model * config.d_model + config.d_ff * config.d_model);
724 (embedding_params + encoder_params + decoder_params) as u64
725 }
726 NlpArchParams::RoBERTa(config) => {
727 let embedding_params = self.vocab_size * config.bert_config.hidden_size
728 + self.max_length * config.bert_config.hidden_size;
729 let transformer_params = config.bert_config.num_hidden_layers
730 * (4 * config.bert_config.hidden_size * config.bert_config.hidden_size
731 + config.bert_config.intermediate_size * config.bert_config.hidden_size);
732 (embedding_params + transformer_params) as u64
733 }
734 NlpArchParams::BART(config) => {
735 let embedding_params = config.vocab_size * config.d_model;
736 let encoder_params = config.encoder_layers
737 * (4 * config.d_model * config.d_model
738 + config.encoder_ffn_dim * config.d_model);
739 let decoder_params = config.decoder_layers
740 * (4 * config.d_model * config.d_model
741 + config.decoder_ffn_dim * config.d_model);
742 (embedding_params + encoder_params + decoder_params) as u64
743 }
744 }
745 }
746}
747
/// Namespace-only unit struct for factory functions that return
/// predefined, named model configurations.
pub struct ModelConfigs;
750
751impl ModelConfigs {
752 pub fn resnet_configs() -> HashMap<String, VisionModelConfig> {
754 let mut configs = HashMap::new();
755
756 configs.insert(
758 "resnet18".to_string(),
759 VisionModelConfig {
760 architecture: VisionArchitecture::ResNet,
761 arch_params: VisionArchParams::ResNet(ResNetConfig {
762 layers: [2, 2, 2, 2],
763 bottleneck: false,
764 width_mult: 1.0,
765 stem_type: StemType::Standard,
766 }),
767 ..Default::default()
768 },
769 );
770
771 configs.insert(
773 "resnet50".to_string(),
774 VisionModelConfig {
775 architecture: VisionArchitecture::ResNet,
776 arch_params: VisionArchParams::ResNet(ResNetConfig {
777 layers: [3, 4, 6, 3],
778 bottleneck: true,
779 width_mult: 1.0,
780 stem_type: StemType::Standard,
781 }),
782 ..Default::default()
783 },
784 );
785
786 configs
787 }
788
789 pub fn efficientnet_configs() -> HashMap<String, VisionModelConfig> {
791 let mut configs = HashMap::new();
792
793 let variants = [
794 ("efficientnet_b0", 1.0, 1.0, 1.0, 0.2),
795 ("efficientnet_b1", 1.0, 1.1, 1.15, 0.2),
796 ("efficientnet_b2", 1.1, 1.2, 1.3, 0.3),
797 ("efficientnet_b3", 1.2, 1.4, 1.5, 0.3),
798 ("efficientnet_b4", 1.4, 1.8, 1.8, 0.4),
799 ];
800
801 for (name, width, depth, res, dropout) in variants {
802 configs.insert(
803 name.to_string(),
804 VisionModelConfig {
805 architecture: VisionArchitecture::EfficientNet,
806 input_size: ((224.0 * res) as usize, (224.0 * res) as usize),
807 arch_params: VisionArchParams::EfficientNet(EfficientNetConfig {
808 width_mult: width,
809 depth_mult: depth,
810 resolution_mult: res,
811 dropout_rate: dropout,
812 se_ratio: 0.25,
813 }),
814 ..Default::default()
815 },
816 );
817 }
818
819 configs
820 }
821
822 pub fn vit_configs() -> HashMap<String, VisionModelConfig> {
824 let mut configs = HashMap::new();
825
826 configs.insert(
828 "vit_base_patch16_224".to_string(),
829 VisionModelConfig {
830 architecture: VisionArchitecture::VisionTransformer,
831 arch_params: VisionArchParams::VisionTransformer(ViTConfig {
832 patch_size: 16,
833 embed_dim: 768,
834 depth: 12,
835 num_heads: 12,
836 mlp_ratio: 4.0,
837 dropout_rate: 0.1,
838 attn_dropout_rate: 0.0,
839 pos_embed_type: PositionEmbedType::Learned,
840 }),
841 ..Default::default()
842 },
843 );
844
845 configs.insert(
847 "vit_large_patch16_224".to_string(),
848 VisionModelConfig {
849 architecture: VisionArchitecture::VisionTransformer,
850 arch_params: VisionArchParams::VisionTransformer(ViTConfig {
851 patch_size: 16,
852 embed_dim: 1024,
853 depth: 24,
854 num_heads: 16,
855 mlp_ratio: 4.0,
856 dropout_rate: 0.1,
857 attn_dropout_rate: 0.0,
858 pos_embed_type: PositionEmbedType::Learned,
859 }),
860 ..Default::default()
861 },
862 );
863
864 configs
865 }
866
867 pub fn bert_configs() -> HashMap<String, NlpModelConfig> {
869 let mut configs = HashMap::new();
870
871 configs.insert(
873 "bert_base_uncased".to_string(),
874 NlpModelConfig {
875 architecture: NlpArchitecture::BERT,
876 vocab_size: 30522,
877 max_length: 512,
878 arch_params: NlpArchParams::BERT(BERTConfig {
879 hidden_size: 768,
880 num_hidden_layers: 12,
881 num_attention_heads: 12,
882 intermediate_size: 3072,
883 hidden_act: "gelu".to_string(),
884 hidden_dropout_prob: 0.1,
885 attention_probs_dropout_prob: 0.1,
886 layer_norm_eps: 1e-12,
887 use_absolute_pos: true,
888 }),
889 ..Default::default()
890 },
891 );
892
893 configs.insert(
895 "bert_large_uncased".to_string(),
896 NlpModelConfig {
897 architecture: NlpArchitecture::BERT,
898 vocab_size: 30522,
899 max_length: 512,
900 arch_params: NlpArchParams::BERT(BERTConfig {
901 hidden_size: 1024,
902 num_hidden_layers: 24,
903 num_attention_heads: 16,
904 intermediate_size: 4096,
905 hidden_act: "gelu".to_string(),
906 hidden_dropout_prob: 0.1,
907 attention_probs_dropout_prob: 0.1,
908 layer_norm_eps: 1e-12,
909 use_absolute_pos: true,
910 }),
911 ..Default::default()
912 },
913 );
914
915 configs
916 }
917}
918
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_vision_config_validation() {
        let mut cfg = VisionModelConfig::default();
        assert!(cfg.validate().is_ok());

        // Zero-sized input dimension is rejected.
        cfg.input_size = (0, 224);
        assert!(cfg.validate().is_err());

        // Zero classes is rejected.
        cfg.input_size = (224, 224);
        cfg.num_classes = 0;
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_vit_config_validation() {
        let mut cfg = VisionModelConfig {
            input_size: (224, 224),
            arch_params: VisionArchParams::VisionTransformer(ViTConfig {
                patch_size: 16,
                embed_dim: 768,
                num_heads: 12,
                ..Default::default()
            }),
            ..Default::default()
        };
        assert!(cfg.validate().is_ok());

        // 225 is not divisible by the 16-pixel patch size.
        cfg.input_size = (225, 225);
        assert!(cfg.validate().is_err());

        // 770 is not divisible by the 12 attention heads.
        cfg.input_size = (224, 224);
        if let VisionArchParams::VisionTransformer(ref mut vit) = cfg.arch_params {
            vit.embed_dim = 770;
        }
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_model_variants() {
        let resnet = VisionModelConfig {
            arch_params: VisionArchParams::ResNet(ResNetConfig {
                layers: [2, 2, 2, 2],
                ..Default::default()
            }),
            ..Default::default()
        };
        assert_eq!(resnet.variant(), "resnet18");

        let vit = VisionModelConfig {
            arch_params: VisionArchParams::VisionTransformer(ViTConfig {
                embed_dim: 768,
                patch_size: 16,
                ..Default::default()
            }),
            ..Default::default()
        };
        assert_eq!(vit.variant(), "vit_base_patch16");
    }

    #[test]
    fn test_predefined_configs() {
        assert!(ModelConfigs::resnet_configs().contains_key("resnet18"));
        assert!(ModelConfigs::resnet_configs().contains_key("resnet50"));
        assert!(ModelConfigs::efficientnet_configs().contains_key("efficientnet_b0"));
        assert!(ModelConfigs::vit_configs().contains_key("vit_base_patch16_224"));
        assert!(ModelConfigs::bert_configs().contains_key("bert_base_uncased"));
    }

    #[test]
    fn test_parameter_estimation() {
        // ResNet-18 estimate should land near its well-known ~11M count.
        let resnet18 = ModelConfigs::resnet_configs().remove("resnet18").unwrap();
        let params = resnet18.estimated_parameters();
        assert!(params > 10_000_000 && params < 15_000_000);

        // BERT-base estimate should be in the right order of magnitude.
        let bert_base = ModelConfigs::bert_configs()
            .remove("bert_base_uncased")
            .unwrap();
        let bert_params = bert_base.estimated_parameters();
        assert!(bert_params > 50_000_000 && bert_params < 200_000_000);
    }
}