// torsh_models/config.rs
//! Model configuration system for parameterizing architectures
2
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
5
/// Base model configuration trait
///
/// Implemented by each architecture family (vision, NLP) so generic tooling
/// can name, validate, and size a model without knowing the concrete type.
pub trait ModelConfig: Clone + Default {
    /// Get the model name (architecture family, e.g. "ResNet", "BERT")
    fn model_name(&self) -> String;

    /// Get model variant/size description (e.g. "resnet18", "bert_base")
    fn variant(&self) -> String;

    /// Validate configuration parameters; `Err` carries a human-readable message
    fn validate(&self) -> Result<(), String>;

    /// Get estimated parameter count (rough heuristic, not an exact sum)
    fn estimated_parameters(&self) -> u64;
}
20
/// Vision model configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisionModelConfig {
    /// Model architecture type
    pub architecture: VisionArchitecture,
    /// Input image size (height, width)
    pub input_size: (usize, usize),
    /// Number of input channels (usually 3 for RGB)
    pub in_channels: usize,
    /// Number of output classes
    pub num_classes: usize,
    /// Architecture-specific parameters (should match `architecture`;
    /// the pairing is not enforced by `validate`)
    pub arch_params: VisionArchParams,
    /// Training hyperparameters
    pub training: TrainingConfig,
}
37
/// Vision architecture types
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum VisionArchitecture {
    ResNet,
    EfficientNet,
    VisionTransformer,
    MobileNet,
    DenseNet,
    ConvNeXt,
    Swin,
}
49
/// Architecture-specific parameters for vision models
///
/// NOTE(review): `VisionArchitecture` also lists `ConvNeXt` and `Swin`, but
/// there is no matching variant here yet — confirm whether those architectures
/// are intentionally unparameterized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum VisionArchParams {
    ResNet(ResNetConfig),
    EfficientNet(EfficientNetConfig),
    VisionTransformer(ViTConfig),
    MobileNet(MobileNetConfig),
    DenseNet(DenseNetConfig),
}
59
/// ResNet configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResNetConfig {
    /// Number of blocks in each stage [stage1, stage2, stage3, stage4]
    pub layers: [usize; 4],
    /// Use bottleneck blocks (3 convs/block, as in ResNet-50+)
    pub bottleneck: bool,
    /// Width multiplier for channels
    pub width_mult: f32,
    /// Stem convolution type
    pub stem_type: StemType,
}
72
/// EfficientNet configuration (compound-scaling multipliers; B0 = all 1.0)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EfficientNetConfig {
    /// Width scaling factor
    pub width_mult: f32,
    /// Depth scaling factor
    pub depth_mult: f32,
    /// Resolution scaling factor
    pub resolution_mult: f32,
    /// Dropout rate
    pub dropout_rate: f32,
    /// Squeeze-and-Excitation ratio
    pub se_ratio: f32,
}
87
/// Vision Transformer configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ViTConfig {
    /// Patch size (assumed square; input size must be divisible by it)
    pub patch_size: usize,
    /// Embedding dimension (must be divisible by `num_heads`)
    pub embed_dim: usize,
    /// Number of transformer layers
    pub depth: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// MLP expansion ratio (hidden width = embed_dim * mlp_ratio)
    pub mlp_ratio: f32,
    /// Dropout rate
    pub dropout_rate: f32,
    /// Attention dropout rate
    pub attn_dropout_rate: f32,
    /// Position embedding type
    pub pos_embed_type: PositionEmbedType,
}
108
/// MobileNet configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MobileNetConfig {
    /// MobileNet version
    pub version: MobileNetVersion,
    /// Width multiplier
    pub width_mult: f32,
    /// Minimum channel divisor (channel counts are rounded to multiples of this)
    pub min_ch: usize,
    /// Dropout rate
    pub dropout_rate: f32,
}
121
/// DenseNet configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DenseNetConfig {
    /// Growth rate (number of channels added per dense layer)
    pub growth_rate: usize,
    /// Number of layers in each dense block (e.g. [6, 12, 24, 16] = DenseNet-121)
    pub block_config: Vec<usize>,
    /// Number of initial features
    pub num_init_features: usize,
    /// Bottleneck width multiplier
    pub bn_size: usize,
    /// Compression factor in transition layers (0..1)
    pub compression: f32,
}
136
/// NLP model configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NlpModelConfig {
    /// Model architecture type
    pub architecture: NlpArchitecture,
    /// Vocabulary size
    pub vocab_size: usize,
    /// Maximum sequence length
    pub max_length: usize,
    /// Architecture-specific parameters (should match `architecture`;
    /// the pairing is not enforced by `validate`)
    pub arch_params: NlpArchParams,
    /// Training hyperparameters
    pub training: TrainingConfig,
}
151
/// NLP architecture types
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum NlpArchitecture {
    BERT,
    GPT,
    T5,
    RoBERTa,
    BART,
    ELECTRA,
    DeBERTa,
}
163
/// Architecture-specific parameters for NLP models
///
/// NOTE(review): `NlpArchitecture` also lists `ELECTRA` and `DeBERTa`, but no
/// matching variant exists here yet — confirm whether that is intentional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NlpArchParams {
    BERT(BERTConfig),
    GPT(GPTConfig),
    T5(T5Config),
    RoBERTa(RoBERTaConfig),
    BART(BARTConfig),
}
173
/// BERT configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BERTConfig {
    /// Hidden size (must be divisible by `num_attention_heads`)
    pub hidden_size: usize,
    /// Number of hidden layers
    pub num_hidden_layers: usize,
    /// Number of attention heads
    pub num_attention_heads: usize,
    /// Intermediate size in feed-forward layers
    pub intermediate_size: usize,
    /// Hidden activation function name (e.g. "gelu")
    pub hidden_act: String,
    /// Hidden dropout probability
    pub hidden_dropout_prob: f32,
    /// Attention dropout probability
    pub attention_probs_dropout_prob: f32,
    /// Layer norm epsilon
    pub layer_norm_eps: f32,
    /// Use absolute position embeddings
    pub use_absolute_pos: bool,
}
196
/// GPT configuration (field names follow the GPT-2 convention)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GPTConfig {
    /// Embedding/hidden size (must be divisible by `n_head`)
    pub n_embd: usize,
    /// Number of layers
    pub n_layer: usize,
    /// Number of attention heads
    pub n_head: usize,
    /// Context length
    pub n_ctx: usize,
    /// Residual dropout
    pub resid_pdrop: f32,
    /// Embedding dropout
    pub embd_pdrop: f32,
    /// Attention dropout
    pub attn_pdrop: f32,
    /// Use bias in linear layers
    pub use_bias: bool,
}
217
/// T5 configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct T5Config {
    /// Model dimension
    pub d_model: usize,
    /// Key/value dimension per head
    pub d_kv: usize,
    /// Feed-forward dimension
    pub d_ff: usize,
    /// Number of layers (applied to both encoder and decoder in estimates)
    pub num_layers: usize,
    /// Number of heads
    pub num_heads: usize,
    /// Relative attention bucket size
    pub relative_attention_num_buckets: usize,
    /// Dropout rate
    pub dropout_rate: f32,
    /// Layer norm epsilon
    pub layer_norm_epsilon: f32,
}
238
/// RoBERTa configuration (extends BERT)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoBERTaConfig {
    /// Base BERT configuration
    pub bert_config: BERTConfig,
    /// Use different layer norm
    pub use_alternate_layernorm: bool,
}
247
/// BART configuration (encoder-decoder; carries its own vocab size,
/// unlike the other NLP params which rely on `NlpModelConfig::vocab_size`)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BARTConfig {
    /// Vocabulary size
    pub vocab_size: usize,
    /// Model dimension
    pub d_model: usize,
    /// Encoder layers
    pub encoder_layers: usize,
    /// Decoder layers
    pub decoder_layers: usize,
    /// Encoder attention heads
    pub encoder_attention_heads: usize,
    /// Decoder attention heads
    pub decoder_attention_heads: usize,
    /// Encoder feed-forward dimension
    pub encoder_ffn_dim: usize,
    /// Decoder feed-forward dimension
    pub decoder_ffn_dim: usize,
    /// Dropout
    pub dropout: f32,
    /// Attention dropout
    pub attention_dropout: f32,
    /// Activation function name (e.g. "gelu")
    pub activation_function: String,
}
273
/// Training configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Learning rate
    pub learning_rate: f32,
    /// Batch size
    pub batch_size: usize,
    /// Number of epochs
    pub epochs: usize,
    /// Weight decay
    pub weight_decay: f32,
    /// Optimizer type
    pub optimizer: OptimizerType,
    /// Learning rate scheduler
    pub lr_scheduler: LRSchedulerType,
    /// Gradient clipping threshold (`None` disables clipping)
    pub max_grad_norm: Option<f32>,
    /// Warmup steps (`None` disables warmup)
    pub warmup_steps: Option<usize>,
}
294
/// Optimizer types, each carrying its algorithm-specific hyperparameters
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OptimizerType {
    SGD { momentum: f32, nesterov: bool },
    Adam { beta1: f32, beta2: f32, eps: f32 },
    AdamW { beta1: f32, beta2: f32, eps: f32 },
    RMSprop { alpha: f32, eps: f32 },
}
303
/// Learning rate scheduler types
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LRSchedulerType {
    Constant,
    Linear { total_steps: usize },
    Cosine { total_steps: usize, min_lr: f32 },
    StepLR { step_size: usize, gamma: f32 },
    ExponentialLR { gamma: f32 },
}
313
/// Stem convolution types for ResNet
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StemType {
    /// Standard 7x7 conv with stride 2
    Standard,
    /// Deep stem with multiple 3x3 convs
    Deep,
    /// Patch-like stem for better fine-grained features
    Patch,
}
324
/// Position embedding types
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PositionEmbedType {
    Learned,
    Sinusoidal,
    Relative,
    Rotary,
}
333
/// MobileNet version
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MobileNetVersion {
    V1,
    V2,
    V3Small,
    V3Large,
}
342
343impl Default for VisionModelConfig {
344    fn default() -> Self {
345        Self {
346            architecture: VisionArchitecture::ResNet,
347            input_size: (224, 224),
348            in_channels: 3,
349            num_classes: 1000,
350            arch_params: VisionArchParams::ResNet(ResNetConfig::default()),
351            training: TrainingConfig::default(),
352        }
353    }
354}
355
356impl Default for ResNetConfig {
357    fn default() -> Self {
358        Self {
359            layers: [2, 2, 2, 2], // ResNet-18
360            bottleneck: false,
361            width_mult: 1.0,
362            stem_type: StemType::Standard,
363        }
364    }
365}
366
367impl Default for EfficientNetConfig {
368    fn default() -> Self {
369        Self {
370            width_mult: 1.0,
371            depth_mult: 1.0,
372            resolution_mult: 1.0,
373            dropout_rate: 0.2,
374            se_ratio: 0.25,
375        }
376    }
377}
378
379impl Default for ViTConfig {
380    fn default() -> Self {
381        Self {
382            patch_size: 16,
383            embed_dim: 768,
384            depth: 12,
385            num_heads: 12,
386            mlp_ratio: 4.0,
387            dropout_rate: 0.1,
388            attn_dropout_rate: 0.0,
389            pos_embed_type: PositionEmbedType::Learned,
390        }
391    }
392}
393
394impl Default for MobileNetConfig {
395    fn default() -> Self {
396        Self {
397            version: MobileNetVersion::V2,
398            width_mult: 1.0,
399            min_ch: 8,
400            dropout_rate: 0.2,
401        }
402    }
403}
404
405impl Default for DenseNetConfig {
406    fn default() -> Self {
407        Self {
408            growth_rate: 32,
409            block_config: vec![6, 12, 24, 16], // DenseNet-121
410            num_init_features: 64,
411            bn_size: 4,
412            compression: 0.5,
413        }
414    }
415}
416
417impl Default for NlpModelConfig {
418    fn default() -> Self {
419        Self {
420            architecture: NlpArchitecture::BERT,
421            vocab_size: 30522,
422            max_length: 512,
423            arch_params: NlpArchParams::BERT(BERTConfig::default()),
424            training: TrainingConfig::default(),
425        }
426    }
427}
428
429impl Default for BERTConfig {
430    fn default() -> Self {
431        Self {
432            hidden_size: 768,
433            num_hidden_layers: 12,
434            num_attention_heads: 12,
435            intermediate_size: 3072,
436            hidden_act: "gelu".to_string(),
437            hidden_dropout_prob: 0.1,
438            attention_probs_dropout_prob: 0.1,
439            layer_norm_eps: 1e-12,
440            use_absolute_pos: true,
441        }
442    }
443}
444
445impl Default for GPTConfig {
446    fn default() -> Self {
447        Self {
448            n_embd: 768,
449            n_layer: 12,
450            n_head: 12,
451            n_ctx: 1024,
452            resid_pdrop: 0.1,
453            embd_pdrop: 0.1,
454            attn_pdrop: 0.1,
455            use_bias: true,
456        }
457    }
458}
459
460impl Default for T5Config {
461    fn default() -> Self {
462        Self {
463            d_model: 512,
464            d_kv: 64,
465            d_ff: 2048,
466            num_layers: 6,
467            num_heads: 8,
468            relative_attention_num_buckets: 32,
469            dropout_rate: 0.1,
470            layer_norm_epsilon: 1e-6,
471        }
472    }
473}
474
475impl Default for RoBERTaConfig {
476    fn default() -> Self {
477        Self {
478            bert_config: BERTConfig::default(),
479            use_alternate_layernorm: false,
480        }
481    }
482}
483
484impl Default for BARTConfig {
485    fn default() -> Self {
486        Self {
487            vocab_size: 50265,
488            d_model: 768,
489            encoder_layers: 6,
490            decoder_layers: 6,
491            encoder_attention_heads: 12,
492            decoder_attention_heads: 12,
493            encoder_ffn_dim: 3072,
494            decoder_ffn_dim: 3072,
495            dropout: 0.1,
496            attention_dropout: 0.0,
497            activation_function: "gelu".to_string(),
498        }
499    }
500}
501
502impl Default for TrainingConfig {
503    fn default() -> Self {
504        Self {
505            learning_rate: 1e-4,
506            batch_size: 32,
507            epochs: 10,
508            weight_decay: 0.01,
509            optimizer: OptimizerType::AdamW {
510                beta1: 0.9,
511                beta2: 0.999,
512                eps: 1e-8,
513            },
514            lr_scheduler: LRSchedulerType::Constant,
515            max_grad_norm: Some(1.0),
516            warmup_steps: None,
517        }
518    }
519}
520
521impl ModelConfig for VisionModelConfig {
522    fn model_name(&self) -> String {
523        format!("{:?}", self.architecture)
524    }
525
526    fn variant(&self) -> String {
527        match &self.arch_params {
528            VisionArchParams::ResNet(config) => {
529                let total_layers: usize = config.layers.iter().sum::<usize>() * 2 + 2;
530                format!("resnet{}", total_layers)
531            }
532            VisionArchParams::EfficientNet(config) => {
533                if config.width_mult == 1.0 && config.depth_mult == 1.0 {
534                    "efficientnet_b0".to_string()
535                } else {
536                    format!(
537                        "efficientnet_w{:.1}_d{:.1}",
538                        config.width_mult, config.depth_mult
539                    )
540                }
541            }
542            VisionArchParams::VisionTransformer(config) => {
543                if config.embed_dim == 768 {
544                    format!("vit_base_patch{}", config.patch_size)
545                } else if config.embed_dim == 1024 {
546                    format!("vit_large_patch{}", config.patch_size)
547                } else {
548                    format!("vit_{}d_patch{}", config.embed_dim, config.patch_size)
549                }
550            }
551            VisionArchParams::MobileNet(config) => {
552                format!("mobilenet_{:?}_w{:.1}", config.version, config.width_mult)
553            }
554            VisionArchParams::DenseNet(config) => {
555                let total_layers: usize = config.block_config.iter().sum::<usize>() * 2 + 1;
556                format!("densenet{}", total_layers)
557            }
558        }
559    }
560
561    fn validate(&self) -> Result<(), String> {
562        if self.input_size.0 == 0 || self.input_size.1 == 0 {
563            return Err("Input size must be positive".to_string());
564        }
565        if self.in_channels == 0 {
566            return Err("Input channels must be positive".to_string());
567        }
568        if self.num_classes == 0 {
569            return Err("Number of classes must be positive".to_string());
570        }
571
572        match &self.arch_params {
573            VisionArchParams::VisionTransformer(config) => {
574                if self.input_size.0 % config.patch_size != 0
575                    || self.input_size.1 % config.patch_size != 0
576                {
577                    return Err("Image size must be divisible by patch size".to_string());
578                }
579                if config.embed_dim % config.num_heads != 0 {
580                    return Err(
581                        "Embedding dimension must be divisible by number of heads".to_string()
582                    );
583                }
584            }
585            _ => {}
586        }
587
588        Ok(())
589    }
590
591    fn estimated_parameters(&self) -> u64 {
592        match &self.arch_params {
593            VisionArchParams::ResNet(config) => {
594                let base_params = if config.bottleneck {
595                    25_000_000
596                } else {
597                    11_000_000
598                };
599                (base_params as f32 * config.width_mult * config.width_mult) as u64
600            }
601            VisionArchParams::EfficientNet(config) => {
602                let base_params = 5_300_000u64;
603                (base_params as f32 * config.width_mult * config.width_mult * config.depth_mult)
604                    as u64
605            }
606            VisionArchParams::VisionTransformer(config) => {
607                let patch_embed_params =
608                    self.in_channels * config.embed_dim * config.patch_size * config.patch_size;
609                let transformer_params = config.depth
610                    * (4 * config.embed_dim * config.embed_dim
611                        + 4 * config.embed_dim * (config.embed_dim * config.mlp_ratio as usize));
612                let head_params = config.embed_dim * self.num_classes;
613                (patch_embed_params + transformer_params + head_params) as u64
614            }
615            VisionArchParams::MobileNet(_) => 3_500_000,
616            VisionArchParams::DenseNet(config) => {
617                let base_params = config.growth_rate as u64
618                    * config.block_config.iter().sum::<usize>() as u64
619                    * 1000;
620                base_params
621            }
622        }
623    }
624}
625
626impl ModelConfig for NlpModelConfig {
627    fn model_name(&self) -> String {
628        format!("{:?}", self.architecture)
629    }
630
631    fn variant(&self) -> String {
632        match &self.arch_params {
633            NlpArchParams::BERT(config) => {
634                if config.hidden_size == 768 {
635                    "bert_base".to_string()
636                } else if config.hidden_size == 1024 {
637                    "bert_large".to_string()
638                } else {
639                    format!("bert_{}h_{}l", config.hidden_size, config.num_hidden_layers)
640                }
641            }
642            NlpArchParams::GPT(config) => {
643                if config.n_embd == 768 {
644                    "gpt_base".to_string()
645                } else {
646                    format!("gpt_{}d_{}l", config.n_embd, config.n_layer)
647                }
648            }
649            NlpArchParams::T5(config) => {
650                if config.d_model == 512 {
651                    "t5_small".to_string()
652                } else if config.d_model == 768 {
653                    "t5_base".to_string()
654                } else {
655                    format!("t5_{}d_{}l", config.d_model, config.num_layers)
656                }
657            }
658            NlpArchParams::RoBERTa(config) => {
659                if config.bert_config.hidden_size == 768 {
660                    "roberta_base".to_string()
661                } else {
662                    "roberta_large".to_string()
663                }
664            }
665            NlpArchParams::BART(config) => {
666                if config.d_model == 768 {
667                    "bart_base".to_string()
668                } else {
669                    format!("bart_{}d", config.d_model)
670                }
671            }
672        }
673    }
674
675    fn validate(&self) -> Result<(), String> {
676        if self.vocab_size == 0 {
677            return Err("Vocabulary size must be positive".to_string());
678        }
679        if self.max_length == 0 {
680            return Err("Maximum length must be positive".to_string());
681        }
682
683        match &self.arch_params {
684            NlpArchParams::BERT(config) => {
685                if config.hidden_size % config.num_attention_heads != 0 {
686                    return Err(
687                        "Hidden size must be divisible by number of attention heads".to_string()
688                    );
689                }
690            }
691            NlpArchParams::GPT(config) => {
692                if config.n_embd % config.n_head != 0 {
693                    return Err("Embedding size must be divisible by number of heads".to_string());
694                }
695            }
696            _ => {}
697        }
698
699        Ok(())
700    }
701
702    fn estimated_parameters(&self) -> u64 {
703        match &self.arch_params {
704            NlpArchParams::BERT(config) => {
705                let embedding_params =
706                    self.vocab_size * config.hidden_size + self.max_length * config.hidden_size;
707                let transformer_params = config.num_hidden_layers
708                    * (4 * config.hidden_size * config.hidden_size
709                        + config.intermediate_size * config.hidden_size);
710                (embedding_params + transformer_params) as u64
711            }
712            NlpArchParams::GPT(config) => {
713                let embedding_params =
714                    self.vocab_size * config.n_embd + config.n_ctx * config.n_embd;
715                let transformer_params = config.n_layer * (4 * config.n_embd * config.n_embd);
716                (embedding_params + transformer_params) as u64
717            }
718            NlpArchParams::T5(config) => {
719                let embedding_params = self.vocab_size * config.d_model;
720                let encoder_params = config.num_layers
721                    * (4 * config.d_model * config.d_model + config.d_ff * config.d_model);
722                let decoder_params = config.num_layers
723                    * (4 * config.d_model * config.d_model + config.d_ff * config.d_model);
724                (embedding_params + encoder_params + decoder_params) as u64
725            }
726            NlpArchParams::RoBERTa(config) => {
727                let embedding_params = self.vocab_size * config.bert_config.hidden_size
728                    + self.max_length * config.bert_config.hidden_size;
729                let transformer_params = config.bert_config.num_hidden_layers
730                    * (4 * config.bert_config.hidden_size * config.bert_config.hidden_size
731                        + config.bert_config.intermediate_size * config.bert_config.hidden_size);
732                (embedding_params + transformer_params) as u64
733            }
734            NlpArchParams::BART(config) => {
735                let embedding_params = config.vocab_size * config.d_model;
736                let encoder_params = config.encoder_layers
737                    * (4 * config.d_model * config.d_model
738                        + config.encoder_ffn_dim * config.d_model);
739                let decoder_params = config.decoder_layers
740                    * (4 * config.d_model * config.d_model
741                        + config.decoder_ffn_dim * config.d_model);
742                (embedding_params + encoder_params + decoder_params) as u64
743            }
744        }
745    }
746}
747
/// Namespace for predefined, ready-to-use model configurations
/// (unit struct; all constructors are associated functions)
pub struct ModelConfigs;
750
751impl ModelConfigs {
752    /// Get ResNet configurations
753    pub fn resnet_configs() -> HashMap<String, VisionModelConfig> {
754        let mut configs = HashMap::new();
755
756        // ResNet-18
757        configs.insert(
758            "resnet18".to_string(),
759            VisionModelConfig {
760                architecture: VisionArchitecture::ResNet,
761                arch_params: VisionArchParams::ResNet(ResNetConfig {
762                    layers: [2, 2, 2, 2],
763                    bottleneck: false,
764                    width_mult: 1.0,
765                    stem_type: StemType::Standard,
766                }),
767                ..Default::default()
768            },
769        );
770
771        // ResNet-50
772        configs.insert(
773            "resnet50".to_string(),
774            VisionModelConfig {
775                architecture: VisionArchitecture::ResNet,
776                arch_params: VisionArchParams::ResNet(ResNetConfig {
777                    layers: [3, 4, 6, 3],
778                    bottleneck: true,
779                    width_mult: 1.0,
780                    stem_type: StemType::Standard,
781                }),
782                ..Default::default()
783            },
784        );
785
786        configs
787    }
788
789    /// Get EfficientNet configurations
790    pub fn efficientnet_configs() -> HashMap<String, VisionModelConfig> {
791        let mut configs = HashMap::new();
792
793        let variants = [
794            ("efficientnet_b0", 1.0, 1.0, 1.0, 0.2),
795            ("efficientnet_b1", 1.0, 1.1, 1.15, 0.2),
796            ("efficientnet_b2", 1.1, 1.2, 1.3, 0.3),
797            ("efficientnet_b3", 1.2, 1.4, 1.5, 0.3),
798            ("efficientnet_b4", 1.4, 1.8, 1.8, 0.4),
799        ];
800
801        for (name, width, depth, res, dropout) in variants {
802            configs.insert(
803                name.to_string(),
804                VisionModelConfig {
805                    architecture: VisionArchitecture::EfficientNet,
806                    input_size: ((224.0 * res) as usize, (224.0 * res) as usize),
807                    arch_params: VisionArchParams::EfficientNet(EfficientNetConfig {
808                        width_mult: width,
809                        depth_mult: depth,
810                        resolution_mult: res,
811                        dropout_rate: dropout,
812                        se_ratio: 0.25,
813                    }),
814                    ..Default::default()
815                },
816            );
817        }
818
819        configs
820    }
821
822    /// Get Vision Transformer configurations
823    pub fn vit_configs() -> HashMap<String, VisionModelConfig> {
824        let mut configs = HashMap::new();
825
826        // ViT-Base/16
827        configs.insert(
828            "vit_base_patch16_224".to_string(),
829            VisionModelConfig {
830                architecture: VisionArchitecture::VisionTransformer,
831                arch_params: VisionArchParams::VisionTransformer(ViTConfig {
832                    patch_size: 16,
833                    embed_dim: 768,
834                    depth: 12,
835                    num_heads: 12,
836                    mlp_ratio: 4.0,
837                    dropout_rate: 0.1,
838                    attn_dropout_rate: 0.0,
839                    pos_embed_type: PositionEmbedType::Learned,
840                }),
841                ..Default::default()
842            },
843        );
844
845        // ViT-Large/16
846        configs.insert(
847            "vit_large_patch16_224".to_string(),
848            VisionModelConfig {
849                architecture: VisionArchitecture::VisionTransformer,
850                arch_params: VisionArchParams::VisionTransformer(ViTConfig {
851                    patch_size: 16,
852                    embed_dim: 1024,
853                    depth: 24,
854                    num_heads: 16,
855                    mlp_ratio: 4.0,
856                    dropout_rate: 0.1,
857                    attn_dropout_rate: 0.0,
858                    pos_embed_type: PositionEmbedType::Learned,
859                }),
860                ..Default::default()
861            },
862        );
863
864        configs
865    }
866
867    /// Get BERT configurations
868    pub fn bert_configs() -> HashMap<String, NlpModelConfig> {
869        let mut configs = HashMap::new();
870
871        // BERT-Base
872        configs.insert(
873            "bert_base_uncased".to_string(),
874            NlpModelConfig {
875                architecture: NlpArchitecture::BERT,
876                vocab_size: 30522,
877                max_length: 512,
878                arch_params: NlpArchParams::BERT(BERTConfig {
879                    hidden_size: 768,
880                    num_hidden_layers: 12,
881                    num_attention_heads: 12,
882                    intermediate_size: 3072,
883                    hidden_act: "gelu".to_string(),
884                    hidden_dropout_prob: 0.1,
885                    attention_probs_dropout_prob: 0.1,
886                    layer_norm_eps: 1e-12,
887                    use_absolute_pos: true,
888                }),
889                ..Default::default()
890            },
891        );
892
893        // BERT-Large
894        configs.insert(
895            "bert_large_uncased".to_string(),
896            NlpModelConfig {
897                architecture: NlpArchitecture::BERT,
898                vocab_size: 30522,
899                max_length: 512,
900                arch_params: NlpArchParams::BERT(BERTConfig {
901                    hidden_size: 1024,
902                    num_hidden_layers: 24,
903                    num_attention_heads: 16,
904                    intermediate_size: 4096,
905                    hidden_act: "gelu".to_string(),
906                    hidden_dropout_prob: 0.1,
907                    attention_probs_dropout_prob: 0.1,
908                    layer_norm_eps: 1e-12,
909                    use_absolute_pos: true,
910                }),
911                ..Default::default()
912            },
913        );
914
915        configs
916    }
917}
918
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_vision_config_validation() {
        let mut cfg = VisionModelConfig::default();
        assert!(cfg.validate().is_ok());

        // Zero height is rejected.
        cfg.input_size = (0, 224);
        assert!(cfg.validate().is_err());

        // Zero classes are rejected even with a valid input size.
        cfg.input_size = (224, 224);
        cfg.num_classes = 0;
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_vit_config_validation() {
        let vit = ViTConfig {
            patch_size: 16,
            embed_dim: 768,
            num_heads: 12,
            ..Default::default()
        };
        let mut cfg = VisionModelConfig {
            input_size: (224, 224),
            arch_params: VisionArchParams::VisionTransformer(vit),
            ..Default::default()
        };
        assert!(cfg.validate().is_ok());

        // 225 is not divisible by the 16-pixel patch size.
        cfg.input_size = (225, 225);
        assert!(cfg.validate().is_err());

        // Restore the size, then break the head-divisibility constraint.
        cfg.input_size = (224, 224);
        if let VisionArchParams::VisionTransformer(ref mut vit_cfg) = cfg.arch_params {
            vit_cfg.embed_dim = 770; // Not divisible by 12 heads
        }
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_model_variants() {
        let resnet = VisionModelConfig {
            arch_params: VisionArchParams::ResNet(ResNetConfig {
                layers: [2, 2, 2, 2],
                ..Default::default()
            }),
            ..Default::default()
        };
        assert_eq!(resnet.variant(), "resnet18");

        let vit = VisionModelConfig {
            arch_params: VisionArchParams::VisionTransformer(ViTConfig {
                embed_dim: 768,
                patch_size: 16,
                ..Default::default()
            }),
            ..Default::default()
        };
        assert_eq!(vit.variant(), "vit_base_patch16");
    }

    #[test]
    fn test_predefined_configs() {
        let resnets = ModelConfigs::resnet_configs();
        for key in ["resnet18", "resnet50"] {
            assert!(resnets.contains_key(key));
        }
        assert!(ModelConfigs::efficientnet_configs().contains_key("efficientnet_b0"));
        assert!(ModelConfigs::vit_configs().contains_key("vit_base_patch16_224"));
        assert!(ModelConfigs::bert_configs().contains_key("bert_base_uncased"));
    }

    #[test]
    fn test_parameter_estimation() {
        let resnets = ModelConfigs::resnet_configs();
        let params = resnets.get("resnet18").unwrap().estimated_parameters();
        assert!(params > 10_000_000 && params < 15_000_000); // ResNet-18 has ~11M params

        let berts = ModelConfigs::bert_configs();
        let bert_params = berts.get("bert_base_uncased").unwrap().estimated_parameters();
        // BERT-base has ~110M params; the wide range tolerates the rough estimate.
        assert!(bert_params > 50_000_000 && bert_params < 200_000_000);
    }
}