Skip to main content

torsh_models/
builder.rs

1//! Model builders and factories for easy instantiation
2
3// Framework infrastructure - components designed for future use
4#![allow(dead_code)]
5use crate::config::{
6    ModelConfig, ModelConfigs, NlpArchitecture, NlpModelConfig, VisionArchParams,
7    VisionArchitecture, VisionModelConfig,
8};
9use crate::vision::{ResNet, VisionTransformer};
10use crate::{ModelError, ModelType};
11use std::collections::HashMap;
12use torsh_nn::Module;
13
14/// Result type for model builders
15pub type BuildResult<T> = Result<T, ModelError>;
16
17/// Generic model builder trait
18pub trait ModelBuilder<T: Module>: Send + Sync {
19    type Config: ModelConfig;
20
21    /// Build model from configuration
22    fn build(&self, config: &Self::Config) -> BuildResult<T>;
23
24    /// Build model from configuration name
25    fn build_from_name(&self, name: &str) -> BuildResult<T>;
26
27    /// List available model configurations
28    fn available_models(&self) -> Vec<String>;
29
30    /// Get model configuration by name
31    fn get_config(&self, name: &str) -> Option<Self::Config>;
32}
33
34/// Vision model builder
35pub struct VisionModelBuilder {
36    configs: HashMap<String, VisionModelConfig>,
37}
38
39impl VisionModelBuilder {
40    /// Create new vision model builder
41    pub fn new() -> Self {
42        let mut configs = HashMap::new();
43
44        // Add predefined configurations
45        configs.extend(ModelConfigs::resnet_configs());
46        configs.extend(ModelConfigs::efficientnet_configs());
47        configs.extend(ModelConfigs::vit_configs());
48
49        Self { configs }
50    }
51
52    /// Add custom configuration
53    pub fn add_config(&mut self, name: String, config: VisionModelConfig) {
54        self.configs.insert(name, config);
55    }
56
57    /// Build ResNet model
58    fn build_resnet(&self, config: &VisionModelConfig) -> BuildResult<ModelType> {
59        if let VisionArchParams::ResNet(resnet_config) = &config.arch_params {
60            let model = if resnet_config.layers == [2, 2, 2, 2] {
61                ResNet::resnet18(config.num_classes)
62            } else if resnet_config.layers == [3, 4, 6, 3] && !resnet_config.bottleneck {
63                ResNet::resnet34(config.num_classes)
64            } else if resnet_config.layers == [3, 4, 6, 3] && resnet_config.bottleneck {
65                ResNet::resnet50(config.num_classes)
66            } else {
67                return Err(ModelError::LoadingError {
68                    reason: format!(
69                        "Unsupported ResNet configuration: {:?}",
70                        resnet_config.layers
71                    ),
72                });
73            };
74
75            Ok(ModelType::ResNet(model?))
76        } else {
77            Err(ModelError::LoadingError {
78                reason: "Invalid ResNet configuration".to_string(),
79            })
80        }
81    }
82
83    /// Build EfficientNet model
84    fn build_efficientnet(&self, _config: &VisionModelConfig) -> BuildResult<ModelType> {
85        // EfficientNet implementation exists but requires torsh-nn v0.2 API compatibility
86        // Will be enabled in next major release
87        Err(ModelError::LoadingError {
88            reason: "EfficientNet not yet implemented".to_string(),
89        })
90    }
91
92    /// Build Vision Transformer model
93    fn build_vit(&self, config: &VisionModelConfig) -> BuildResult<ModelType> {
94        if let VisionArchParams::VisionTransformer(vit_config) = &config.arch_params {
95            // Create a ViTConfig from the provided parameters
96            let vit_config_obj = crate::vision::vit::ViTConfig {
97                variant: crate::vision::vit::ViTVariant::Base,
98                img_size: config.input_size.0,
99                patch_size: vit_config.patch_size,
100                in_channels: 3, // Default to RGB channels
101                num_classes: config.num_classes,
102                embed_dim: vit_config.embed_dim,
103                depth: vit_config.depth,
104                num_heads: vit_config.num_heads,
105                mlp_ratio: vit_config.mlp_ratio,
106                qkv_bias: true,
107                representation_size: None,
108                attn_dropout: vit_config.attn_dropout_rate,
109                proj_dropout: vit_config.dropout_rate,
110                path_dropout: 0.0,
111                norm_eps: 1e-5,
112                global_pool: false,
113                patch_embed_strategy: crate::vision::vit::PatchEmbedStrategy::Convolution,
114            };
115            let model = VisionTransformer::new(vit_config_obj);
116            Ok(ModelType::VisionTransformer(model?))
117        } else {
118            Err(ModelError::LoadingError {
119                reason: "Invalid Vision Transformer configuration".to_string(),
120            })
121        }
122    }
123}
124
125impl Default for VisionModelBuilder {
126    fn default() -> Self {
127        Self::new()
128    }
129}
130
131impl ModelBuilder<ModelType> for VisionModelBuilder {
132    type Config = VisionModelConfig;
133
134    fn build(&self, config: &Self::Config) -> BuildResult<ModelType> {
135        // Validate configuration
136        config
137            .validate()
138            .map_err(|e| ModelError::ValidationError { reason: e })?;
139
140        match config.architecture {
141            VisionArchitecture::ResNet => self.build_resnet(config),
142            VisionArchitecture::EfficientNet => self.build_efficientnet(config),
143            VisionArchitecture::VisionTransformer => self.build_vit(config),
144            _ => Err(ModelError::LoadingError {
145                reason: format!("Architecture {:?} not yet implemented", config.architecture),
146            }),
147        }
148    }
149
150    fn build_from_name(&self, name: &str) -> BuildResult<ModelType> {
151        let config = self
152            .get_config(name)
153            .ok_or_else(|| ModelError::ModelNotFound {
154                name: name.to_string(),
155            })?;
156        self.build(&config)
157    }
158
159    fn available_models(&self) -> Vec<String> {
160        self.configs.keys().cloned().collect()
161    }
162
163    fn get_config(&self, name: &str) -> Option<Self::Config> {
164        self.configs.get(name).cloned()
165    }
166}
167
168/// NLP model builder
169pub struct NlpModelBuilder {
170    configs: HashMap<String, NlpModelConfig>,
171}
172
173impl NlpModelBuilder {
174    /// Create new NLP model builder
175    pub fn new() -> Self {
176        let mut configs = HashMap::new();
177
178        // Add predefined configurations
179        configs.extend(ModelConfigs::bert_configs());
180
181        Self { configs }
182    }
183
184    /// Add custom configuration
185    pub fn add_config(&mut self, name: String, config: NlpModelConfig) {
186        self.configs.insert(name, config);
187    }
188
189    /// Build BERT model (placeholder - actual implementation would be in nlp module)
190    fn build_bert(&self, _config: &NlpModelConfig) -> BuildResult<ModelType> {
191        // This is a placeholder - actual BERT implementation would go here
192        Err(ModelError::LoadingError {
193            reason: "BERT implementation not yet available".to_string(),
194        })
195    }
196}
197
198impl Default for NlpModelBuilder {
199    fn default() -> Self {
200        Self::new()
201    }
202}
203
204impl ModelBuilder<ModelType> for NlpModelBuilder {
205    type Config = NlpModelConfig;
206
207    fn build(&self, config: &Self::Config) -> BuildResult<ModelType> {
208        // Validate configuration
209        config
210            .validate()
211            .map_err(|e| ModelError::ValidationError { reason: e })?;
212
213        match config.architecture {
214            NlpArchitecture::BERT => self.build_bert(config),
215            _ => Err(ModelError::LoadingError {
216                reason: format!("Architecture {:?} not yet implemented", config.architecture),
217            }),
218        }
219    }
220
221    fn build_from_name(&self, name: &str) -> BuildResult<ModelType> {
222        let config = self
223            .get_config(name)
224            .ok_or_else(|| ModelError::ModelNotFound {
225                name: name.to_string(),
226            })?;
227        self.build(&config)
228    }
229
230    fn available_models(&self) -> Vec<String> {
231        self.configs.keys().cloned().collect()
232    }
233
234    fn get_config(&self, name: &str) -> Option<Self::Config> {
235        self.configs.get(name).cloned()
236    }
237}
238
239/// Universal model factory
240pub struct ModelFactory {
241    vision_builder: VisionModelBuilder,
242    nlp_builder: NlpModelBuilder,
243}
244
245impl ModelFactory {
246    /// Create new model factory
247    pub fn new() -> Self {
248        Self {
249            vision_builder: VisionModelBuilder::new(),
250            nlp_builder: NlpModelBuilder::new(),
251        }
252    }
253
254    /// Build vision model
255    pub fn build_vision_model(&self, name: &str) -> BuildResult<ModelType> {
256        self.vision_builder.build_from_name(name)
257    }
258
259    /// Build vision model from config
260    pub fn build_vision_model_from_config(
261        &self,
262        config: &VisionModelConfig,
263    ) -> BuildResult<ModelType> {
264        self.vision_builder.build(config)
265    }
266
267    /// Build NLP model
268    pub fn build_nlp_model(&self, name: &str) -> BuildResult<ModelType> {
269        self.nlp_builder.build_from_name(name)
270    }
271
272    /// Build NLP model from config
273    pub fn build_nlp_model_from_config(&self, config: &NlpModelConfig) -> BuildResult<ModelType> {
274        self.nlp_builder.build(config)
275    }
276
277    /// List all available models
278    pub fn list_all_models(&self) -> HashMap<String, Vec<String>> {
279        let mut models = HashMap::new();
280        models.insert("vision".to_string(), self.vision_builder.available_models());
281        models.insert("nlp".to_string(), self.nlp_builder.available_models());
282        models
283    }
284
285    /// Get model information
286    pub fn get_model_info(&self, domain: &str, name: &str) -> Option<ModelInfo> {
287        match domain {
288            "vision" => {
289                if let Some(config) = self.vision_builder.get_config(name) {
290                    Some(ModelInfo {
291                        name: name.to_string(),
292                        architecture: config.model_name(),
293                        variant: config.variant(),
294                        parameters: config.estimated_parameters(),
295                        description: format!(
296                            "{} model for computer vision tasks",
297                            config.model_name()
298                        ),
299                        input_spec: format!("RGB image {:?}", config.input_size),
300                        output_spec: format!("{} class probabilities", config.num_classes),
301                        domain: "vision".to_string(),
302                    })
303                } else {
304                    None
305                }
306            }
307            "nlp" => {
308                if let Some(config) = self.nlp_builder.get_config(name) {
309                    Some(ModelInfo {
310                        name: name.to_string(),
311                        architecture: config.model_name(),
312                        variant: config.variant(),
313                        parameters: config.estimated_parameters(),
314                        description: format!(
315                            "{} model for natural language processing",
316                            config.model_name()
317                        ),
318                        input_spec: format!("Tokenized text [seq_len <= {}]", config.max_length),
319                        output_spec: "Hidden states or logits".to_string(),
320                        domain: "nlp".to_string(),
321                    })
322                } else {
323                    None
324                }
325            }
326            _ => None,
327        }
328    }
329
330    /// Create custom vision model with modifications
331    pub fn create_custom_vision_model(
332        &mut self,
333        base_name: &str,
334        custom_name: &str,
335        modifications: VisionModelModifications,
336    ) -> BuildResult<()> {
337        let mut base_config =
338            self.vision_builder
339                .get_config(base_name)
340                .ok_or_else(|| ModelError::ModelNotFound {
341                    name: base_name.to_string(),
342                })?;
343
344        // Apply modifications
345        if let Some(num_classes) = modifications.num_classes {
346            base_config.num_classes = num_classes;
347        }
348        if let Some(input_size) = modifications.input_size {
349            base_config.input_size = input_size;
350        }
351        if let Some(dropout_rate) = modifications.dropout_rate {
352            // Apply dropout rate based on architecture
353            match &mut base_config.arch_params {
354                VisionArchParams::VisionTransformer(ref mut vit_config) => {
355                    vit_config.dropout_rate = dropout_rate;
356                }
357                VisionArchParams::EfficientNet(ref mut eff_config) => {
358                    eff_config.dropout_rate = dropout_rate;
359                }
360                _ => {}
361            }
362        }
363
364        // Validate modified configuration
365        base_config
366            .validate()
367            .map_err(|e| ModelError::ValidationError { reason: e })?;
368
369        // Add to builder
370        self.vision_builder
371            .add_config(custom_name.to_string(), base_config);
372
373        Ok(())
374    }
375}
376
377impl Default for ModelFactory {
378    fn default() -> Self {
379        Self::new()
380    }
381}
382
/// Optional overrides applied to a base vision configuration when deriving a
/// custom variant. A field left as `None` keeps the base configuration's value.
#[derive(Debug, Clone, Default)]
pub struct VisionModelModifications {
    /// Replacement for the configuration's `num_classes`.
    pub num_classes: Option<usize>,
    /// Replacement for the configuration's `input_size`.
    pub input_size: Option<(usize, usize)>,
    /// Replacement dropout rate; the factory applies it only to Vision
    /// Transformer and EfficientNet architecture parameters.
    pub dropout_rate: Option<f32>,
}
390
/// Lightweight model summary returned by the factory (kept local to this
/// module to avoid circular dependencies).
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Name the model is registered under.
    pub name: String,
    /// Architecture name, taken from the configuration's `model_name()`.
    pub architecture: String,
    /// Variant label, taken from the configuration's `variant()`.
    pub variant: String,
    /// Parameter count, taken from the configuration's
    /// `estimated_parameters()`.
    pub parameters: u64,
    /// Short human-readable description of the model.
    pub description: String,
    /// Human-readable description of the expected input.
    pub input_spec: String,
    /// Human-readable description of the produced output.
    pub output_spec: String,
    /// Domain the model belongs to (`"vision"` or `"nlp"`).
    pub domain: String,
}
403
404lazy_static::lazy_static! {
405    /// Global model factory instance using lazy_static for safe static initialization
406    static ref GLOBAL_FACTORY: ModelFactory = ModelFactory::new();
407}
408
409/// Get the global model factory
410pub fn get_global_factory() -> &'static ModelFactory {
411    &GLOBAL_FACTORY
412}
413
414/// Convenience functions for easy model creation
415pub mod quick {
416    use super::*;
417
418    /// Create ResNet-18 model
419    pub fn resnet18(num_classes: usize) -> BuildResult<ModelType> {
420        let factory = get_global_factory();
421        let mut config = factory
422            .vision_builder
423            .get_config("resnet18")
424            .ok_or_else(|| ModelError::ModelNotFound {
425                name: "resnet18".to_string(),
426            })?;
427        config.num_classes = num_classes;
428        factory.build_vision_model_from_config(&config)
429    }
430
431    /// Create ResNet-50 model
432    pub fn resnet50(num_classes: usize) -> BuildResult<ModelType> {
433        let factory = get_global_factory();
434        let mut config = factory
435            .vision_builder
436            .get_config("resnet50")
437            .ok_or_else(|| ModelError::ModelNotFound {
438                name: "resnet50".to_string(),
439            })?;
440        config.num_classes = num_classes;
441        factory.build_vision_model_from_config(&config)
442    }
443
444    /// Create EfficientNet-B0 model
445    pub fn efficientnet_b0(num_classes: usize) -> BuildResult<ModelType> {
446        let factory = get_global_factory();
447        let mut config = factory
448            .vision_builder
449            .get_config("efficientnet_b0")
450            .ok_or_else(|| ModelError::ModelNotFound {
451                name: "efficientnet_b0".to_string(),
452            })?;
453        config.num_classes = num_classes;
454        factory.build_vision_model_from_config(&config)
455    }
456
457    /// Create ViT-Base model
458    pub fn vit_base(num_classes: usize) -> BuildResult<ModelType> {
459        let factory = get_global_factory();
460        let mut config = factory
461            .vision_builder
462            .get_config("vit_base_patch16_224")
463            .ok_or_else(|| ModelError::ModelNotFound {
464                name: "vit_base_patch16_224".to_string(),
465            })?;
466        config.num_classes = num_classes;
467        factory.build_vision_model_from_config(&config)
468    }
469
470    /// Create custom model with specific configuration
471    pub fn custom_vision_model(
472        base_model: &str,
473        num_classes: usize,
474        input_size: Option<(usize, usize)>,
475        dropout_rate: Option<f32>,
476    ) -> BuildResult<ModelType> {
477        let factory = get_global_factory();
478        let mut config = factory
479            .vision_builder
480            .get_config(base_model)
481            .ok_or_else(|| ModelError::ModelNotFound {
482                name: base_model.to_string(),
483            })?;
484
485        config.num_classes = num_classes;
486        if let Some(size) = input_size {
487            config.input_size = size;
488        }
489        if let Some(dropout) = dropout_rate {
490            match &mut config.arch_params {
491                VisionArchParams::VisionTransformer(ref mut vit_config) => {
492                    vit_config.dropout_rate = dropout;
493                }
494                VisionArchParams::EfficientNet(ref mut eff_config) => {
495                    eff_config.dropout_rate = dropout;
496                }
497                _ => {}
498            }
499        }
500
501        factory.build_vision_model_from_config(&config)
502    }
503}
504
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns true when `models` contains a model called `name`.
    fn lists_model(models: &[String], name: &str) -> bool {
        models.iter().any(|model| model == name)
    }

    #[test]
    fn test_vision_model_builder() {
        let builder = VisionModelBuilder::new();
        let available = builder.available_models();

        assert!(lists_model(&available, "resnet18"));
        assert!(lists_model(&available, "efficientnet_b0"));
        assert!(lists_model(&available, "vit_base_patch16_224"));
    }

    #[test]
    fn test_model_factory() {
        let factory = ModelFactory::new();
        let all_models = factory.list_all_models();

        assert!(all_models.contains_key("vision"));
        assert!(all_models.contains_key("nlp"));

        let vision_models = &all_models["vision"];
        assert!(!vision_models.is_empty());
    }

    #[test]
    fn test_model_info() {
        let factory = ModelFactory::new();
        let info = factory
            .get_model_info("vision", "resnet18")
            .expect("resnet18 should be a registered vision model");

        assert_eq!(info.name, "resnet18");
        assert_eq!(info.domain, "vision");
        assert!(info.parameters > 10_000_000);
    }

    #[test]
    fn test_custom_model_creation() {
        let mut factory = ModelFactory::new();

        let modifications = VisionModelModifications {
            num_classes: Some(10),
            input_size: Some((32, 32)),
            dropout_rate: None,
        };

        factory
            .create_custom_vision_model("resnet18", "resnet18_cifar10", modifications)
            .expect("deriving a custom config from resnet18 should succeed");

        let available = factory.vision_builder.available_models();
        assert!(lists_model(&available, "resnet18_cifar10"));

        let config = factory
            .vision_builder
            .get_config("resnet18_cifar10")
            .expect("the custom config should have been registered");
        assert_eq!(config.num_classes, 10);
        assert_eq!(config.input_size, (32, 32));
    }

    #[test]
    fn test_quick_builders() {
        // Building may still fail while model implementations are incomplete,
        // but the predefined "resnet18" configuration must at least be found
        // by the global factory.
        let result = quick::resnet18(1000);
        assert!(
            !matches!(result, Err(ModelError::ModelNotFound { .. })),
            "resnet18 should be a registered configuration"
        );
    }
}