1#![allow(dead_code)]
5use crate::config::{
6 ModelConfig, ModelConfigs, NlpArchitecture, NlpModelConfig, VisionArchParams,
7 VisionArchitecture, VisionModelConfig,
8};
9use crate::vision::{ResNet, VisionTransformer};
10use crate::{ModelError, ModelType};
11use std::collections::HashMap;
12use torsh_nn::Module;
13
14pub type BuildResult<T> = Result<T, ModelError>;
16
17pub trait ModelBuilder<T: Module>: Send + Sync {
19 type Config: ModelConfig;
20
21 fn build(&self, config: &Self::Config) -> BuildResult<T>;
23
24 fn build_from_name(&self, name: &str) -> BuildResult<T>;
26
27 fn available_models(&self) -> Vec<String>;
29
30 fn get_config(&self, name: &str) -> Option<Self::Config>;
32}
33
34pub struct VisionModelBuilder {
36 configs: HashMap<String, VisionModelConfig>,
37}
38
39impl VisionModelBuilder {
40 pub fn new() -> Self {
42 let mut configs = HashMap::new();
43
44 configs.extend(ModelConfigs::resnet_configs());
46 configs.extend(ModelConfigs::efficientnet_configs());
47 configs.extend(ModelConfigs::vit_configs());
48
49 Self { configs }
50 }
51
52 pub fn add_config(&mut self, name: String, config: VisionModelConfig) {
54 self.configs.insert(name, config);
55 }
56
57 fn build_resnet(&self, config: &VisionModelConfig) -> BuildResult<ModelType> {
59 if let VisionArchParams::ResNet(resnet_config) = &config.arch_params {
60 let model = if resnet_config.layers == [2, 2, 2, 2] {
61 ResNet::resnet18(config.num_classes)
62 } else if resnet_config.layers == [3, 4, 6, 3] && !resnet_config.bottleneck {
63 ResNet::resnet34(config.num_classes)
64 } else if resnet_config.layers == [3, 4, 6, 3] && resnet_config.bottleneck {
65 ResNet::resnet50(config.num_classes)
66 } else {
67 return Err(ModelError::LoadingError {
68 reason: format!(
69 "Unsupported ResNet configuration: {:?}",
70 resnet_config.layers
71 ),
72 });
73 };
74
75 Ok(ModelType::ResNet(model?))
76 } else {
77 Err(ModelError::LoadingError {
78 reason: "Invalid ResNet configuration".to_string(),
79 })
80 }
81 }
82
83 fn build_efficientnet(&self, _config: &VisionModelConfig) -> BuildResult<ModelType> {
85 Err(ModelError::LoadingError {
88 reason: "EfficientNet not yet implemented".to_string(),
89 })
90 }
91
92 fn build_vit(&self, config: &VisionModelConfig) -> BuildResult<ModelType> {
94 if let VisionArchParams::VisionTransformer(vit_config) = &config.arch_params {
95 let vit_config_obj = crate::vision::vit::ViTConfig {
97 variant: crate::vision::vit::ViTVariant::Base,
98 img_size: config.input_size.0,
99 patch_size: vit_config.patch_size,
100 in_channels: 3, num_classes: config.num_classes,
102 embed_dim: vit_config.embed_dim,
103 depth: vit_config.depth,
104 num_heads: vit_config.num_heads,
105 mlp_ratio: vit_config.mlp_ratio,
106 qkv_bias: true,
107 representation_size: None,
108 attn_dropout: vit_config.attn_dropout_rate,
109 proj_dropout: vit_config.dropout_rate,
110 path_dropout: 0.0,
111 norm_eps: 1e-5,
112 global_pool: false,
113 patch_embed_strategy: crate::vision::vit::PatchEmbedStrategy::Convolution,
114 };
115 let model = VisionTransformer::new(vit_config_obj);
116 Ok(ModelType::VisionTransformer(model?))
117 } else {
118 Err(ModelError::LoadingError {
119 reason: "Invalid Vision Transformer configuration".to_string(),
120 })
121 }
122 }
123}
124
125impl Default for VisionModelBuilder {
126 fn default() -> Self {
127 Self::new()
128 }
129}
130
131impl ModelBuilder<ModelType> for VisionModelBuilder {
132 type Config = VisionModelConfig;
133
134 fn build(&self, config: &Self::Config) -> BuildResult<ModelType> {
135 config
137 .validate()
138 .map_err(|e| ModelError::ValidationError { reason: e })?;
139
140 match config.architecture {
141 VisionArchitecture::ResNet => self.build_resnet(config),
142 VisionArchitecture::EfficientNet => self.build_efficientnet(config),
143 VisionArchitecture::VisionTransformer => self.build_vit(config),
144 _ => Err(ModelError::LoadingError {
145 reason: format!("Architecture {:?} not yet implemented", config.architecture),
146 }),
147 }
148 }
149
150 fn build_from_name(&self, name: &str) -> BuildResult<ModelType> {
151 let config = self
152 .get_config(name)
153 .ok_or_else(|| ModelError::ModelNotFound {
154 name: name.to_string(),
155 })?;
156 self.build(&config)
157 }
158
159 fn available_models(&self) -> Vec<String> {
160 self.configs.keys().cloned().collect()
161 }
162
163 fn get_config(&self, name: &str) -> Option<Self::Config> {
164 self.configs.get(name).cloned()
165 }
166}
167
168pub struct NlpModelBuilder {
170 configs: HashMap<String, NlpModelConfig>,
171}
172
173impl NlpModelBuilder {
174 pub fn new() -> Self {
176 let mut configs = HashMap::new();
177
178 configs.extend(ModelConfigs::bert_configs());
180
181 Self { configs }
182 }
183
184 pub fn add_config(&mut self, name: String, config: NlpModelConfig) {
186 self.configs.insert(name, config);
187 }
188
189 fn build_bert(&self, _config: &NlpModelConfig) -> BuildResult<ModelType> {
191 Err(ModelError::LoadingError {
193 reason: "BERT implementation not yet available".to_string(),
194 })
195 }
196}
197
198impl Default for NlpModelBuilder {
199 fn default() -> Self {
200 Self::new()
201 }
202}
203
204impl ModelBuilder<ModelType> for NlpModelBuilder {
205 type Config = NlpModelConfig;
206
207 fn build(&self, config: &Self::Config) -> BuildResult<ModelType> {
208 config
210 .validate()
211 .map_err(|e| ModelError::ValidationError { reason: e })?;
212
213 match config.architecture {
214 NlpArchitecture::BERT => self.build_bert(config),
215 _ => Err(ModelError::LoadingError {
216 reason: format!("Architecture {:?} not yet implemented", config.architecture),
217 }),
218 }
219 }
220
221 fn build_from_name(&self, name: &str) -> BuildResult<ModelType> {
222 let config = self
223 .get_config(name)
224 .ok_or_else(|| ModelError::ModelNotFound {
225 name: name.to_string(),
226 })?;
227 self.build(&config)
228 }
229
230 fn available_models(&self) -> Vec<String> {
231 self.configs.keys().cloned().collect()
232 }
233
234 fn get_config(&self, name: &str) -> Option<Self::Config> {
235 self.configs.get(name).cloned()
236 }
237}
238
239pub struct ModelFactory {
241 vision_builder: VisionModelBuilder,
242 nlp_builder: NlpModelBuilder,
243}
244
245impl ModelFactory {
246 pub fn new() -> Self {
248 Self {
249 vision_builder: VisionModelBuilder::new(),
250 nlp_builder: NlpModelBuilder::new(),
251 }
252 }
253
254 pub fn build_vision_model(&self, name: &str) -> BuildResult<ModelType> {
256 self.vision_builder.build_from_name(name)
257 }
258
259 pub fn build_vision_model_from_config(
261 &self,
262 config: &VisionModelConfig,
263 ) -> BuildResult<ModelType> {
264 self.vision_builder.build(config)
265 }
266
267 pub fn build_nlp_model(&self, name: &str) -> BuildResult<ModelType> {
269 self.nlp_builder.build_from_name(name)
270 }
271
272 pub fn build_nlp_model_from_config(&self, config: &NlpModelConfig) -> BuildResult<ModelType> {
274 self.nlp_builder.build(config)
275 }
276
277 pub fn list_all_models(&self) -> HashMap<String, Vec<String>> {
279 let mut models = HashMap::new();
280 models.insert("vision".to_string(), self.vision_builder.available_models());
281 models.insert("nlp".to_string(), self.nlp_builder.available_models());
282 models
283 }
284
285 pub fn get_model_info(&self, domain: &str, name: &str) -> Option<ModelInfo> {
287 match domain {
288 "vision" => {
289 if let Some(config) = self.vision_builder.get_config(name) {
290 Some(ModelInfo {
291 name: name.to_string(),
292 architecture: config.model_name(),
293 variant: config.variant(),
294 parameters: config.estimated_parameters(),
295 description: format!(
296 "{} model for computer vision tasks",
297 config.model_name()
298 ),
299 input_spec: format!("RGB image {:?}", config.input_size),
300 output_spec: format!("{} class probabilities", config.num_classes),
301 domain: "vision".to_string(),
302 })
303 } else {
304 None
305 }
306 }
307 "nlp" => {
308 if let Some(config) = self.nlp_builder.get_config(name) {
309 Some(ModelInfo {
310 name: name.to_string(),
311 architecture: config.model_name(),
312 variant: config.variant(),
313 parameters: config.estimated_parameters(),
314 description: format!(
315 "{} model for natural language processing",
316 config.model_name()
317 ),
318 input_spec: format!("Tokenized text [seq_len <= {}]", config.max_length),
319 output_spec: "Hidden states or logits".to_string(),
320 domain: "nlp".to_string(),
321 })
322 } else {
323 None
324 }
325 }
326 _ => None,
327 }
328 }
329
330 pub fn create_custom_vision_model(
332 &mut self,
333 base_name: &str,
334 custom_name: &str,
335 modifications: VisionModelModifications,
336 ) -> BuildResult<()> {
337 let mut base_config =
338 self.vision_builder
339 .get_config(base_name)
340 .ok_or_else(|| ModelError::ModelNotFound {
341 name: base_name.to_string(),
342 })?;
343
344 if let Some(num_classes) = modifications.num_classes {
346 base_config.num_classes = num_classes;
347 }
348 if let Some(input_size) = modifications.input_size {
349 base_config.input_size = input_size;
350 }
351 if let Some(dropout_rate) = modifications.dropout_rate {
352 match &mut base_config.arch_params {
354 VisionArchParams::VisionTransformer(ref mut vit_config) => {
355 vit_config.dropout_rate = dropout_rate;
356 }
357 VisionArchParams::EfficientNet(ref mut eff_config) => {
358 eff_config.dropout_rate = dropout_rate;
359 }
360 _ => {}
361 }
362 }
363
364 base_config
366 .validate()
367 .map_err(|e| ModelError::ValidationError { reason: e })?;
368
369 self.vision_builder
371 .add_config(custom_name.to_string(), base_config);
372
373 Ok(())
374 }
375}
376
377impl Default for ModelFactory {
378 fn default() -> Self {
379 Self::new()
380 }
381}
382
/// Optional overrides applied on top of a base vision configuration.
///
/// A field left as `None` keeps the value of the base configuration.
#[derive(Clone, Debug, Default)]
pub struct VisionModelModifications {
    /// Replacement for the number of output classes.
    pub num_classes: Option<usize>,
    /// Replacement for the input size pair.
    pub input_size: Option<(usize, usize)>,
    /// Replacement for the dropout rate of the architecture parameters.
    pub dropout_rate: Option<f32>,
}
390
/// Human-readable summary of a registered model configuration.
#[derive(Clone, Debug)]
pub struct ModelInfo {
    /// Name the configuration was looked up by.
    pub name: String,
    /// Architecture name reported by the configuration.
    pub architecture: String,
    /// Variant reported by the configuration.
    pub variant: String,
    /// Parameter-count estimate reported by the configuration.
    pub parameters: u64,
    /// One-line description of the model.
    pub description: String,
    /// Textual description of the expected input.
    pub input_spec: String,
    /// Textual description of the produced output.
    pub output_spec: String,
    /// Domain the model belongs to, e.g. `"vision"` or `"nlp"`.
    pub domain: String,
}
403
404lazy_static::lazy_static! {
405 static ref GLOBAL_FACTORY: ModelFactory = ModelFactory::new();
407}
408
409pub fn get_global_factory() -> &'static ModelFactory {
411 &GLOBAL_FACTORY
412}
413
414pub mod quick {
416 use super::*;
417
418 pub fn resnet18(num_classes: usize) -> BuildResult<ModelType> {
420 let factory = get_global_factory();
421 let mut config = factory
422 .vision_builder
423 .get_config("resnet18")
424 .ok_or_else(|| ModelError::ModelNotFound {
425 name: "resnet18".to_string(),
426 })?;
427 config.num_classes = num_classes;
428 factory.build_vision_model_from_config(&config)
429 }
430
431 pub fn resnet50(num_classes: usize) -> BuildResult<ModelType> {
433 let factory = get_global_factory();
434 let mut config = factory
435 .vision_builder
436 .get_config("resnet50")
437 .ok_or_else(|| ModelError::ModelNotFound {
438 name: "resnet50".to_string(),
439 })?;
440 config.num_classes = num_classes;
441 factory.build_vision_model_from_config(&config)
442 }
443
444 pub fn efficientnet_b0(num_classes: usize) -> BuildResult<ModelType> {
446 let factory = get_global_factory();
447 let mut config = factory
448 .vision_builder
449 .get_config("efficientnet_b0")
450 .ok_or_else(|| ModelError::ModelNotFound {
451 name: "efficientnet_b0".to_string(),
452 })?;
453 config.num_classes = num_classes;
454 factory.build_vision_model_from_config(&config)
455 }
456
457 pub fn vit_base(num_classes: usize) -> BuildResult<ModelType> {
459 let factory = get_global_factory();
460 let mut config = factory
461 .vision_builder
462 .get_config("vit_base_patch16_224")
463 .ok_or_else(|| ModelError::ModelNotFound {
464 name: "vit_base_patch16_224".to_string(),
465 })?;
466 config.num_classes = num_classes;
467 factory.build_vision_model_from_config(&config)
468 }
469
470 pub fn custom_vision_model(
472 base_model: &str,
473 num_classes: usize,
474 input_size: Option<(usize, usize)>,
475 dropout_rate: Option<f32>,
476 ) -> BuildResult<ModelType> {
477 let factory = get_global_factory();
478 let mut config = factory
479 .vision_builder
480 .get_config(base_model)
481 .ok_or_else(|| ModelError::ModelNotFound {
482 name: base_model.to_string(),
483 })?;
484
485 config.num_classes = num_classes;
486 if let Some(size) = input_size {
487 config.input_size = size;
488 }
489 if let Some(dropout) = dropout_rate {
490 match &mut config.arch_params {
491 VisionArchParams::VisionTransformer(ref mut vit_config) => {
492 vit_config.dropout_rate = dropout;
493 }
494 VisionArchParams::EfficientNet(ref mut eff_config) => {
495 eff_config.dropout_rate = dropout;
496 }
497 _ => {}
498 }
499 }
500
501 factory.build_vision_model_from_config(&config)
502 }
503}
504
#[cfg(test)]
mod tests {
    use super::*;

    /// The default vision builder exposes one known preset per family.
    #[test]
    fn test_vision_model_builder() {
        let builder = VisionModelBuilder::new();
        let available = builder.available_models();

        assert!(available.contains(&"resnet18".to_string()));
        assert!(available.contains(&"efficientnet_b0".to_string()));
        assert!(available.contains(&"vit_base_patch16_224".to_string()));
    }

    /// The factory lists both domains and has at least one vision model.
    #[test]
    fn test_model_factory() {
        let factory = ModelFactory::new();
        let all_models = factory.list_all_models();

        assert!(all_models.contains_key("vision"));
        assert!(all_models.contains_key("nlp"));

        let vision_models = &all_models["vision"];
        assert!(!vision_models.is_empty());
    }

    /// Model info echoes the requested name and domain and reports a
    /// plausible parameter estimate for ResNet-18.
    #[test]
    fn test_model_info() {
        let factory = ModelFactory::new();
        let info = factory.get_model_info("vision", "resnet18").unwrap();

        assert_eq!(info.name, "resnet18");
        assert_eq!(info.domain, "vision");
        assert!(info.parameters > 10_000_000);
    }

    /// A customised configuration is registered under its new name with the
    /// overrides applied.
    #[test]
    fn test_custom_model_creation() {
        let mut factory = ModelFactory::new();

        let modifications = VisionModelModifications {
            num_classes: Some(10),
            input_size: Some((32, 32)),
            dropout_rate: None,
        };

        factory
            .create_custom_vision_model("resnet18", "resnet18_cifar10", modifications)
            .unwrap();

        let available = factory.vision_builder.available_models();
        assert!(available.contains(&"resnet18_cifar10".to_string()));

        let config = factory
            .vision_builder
            .get_config("resnet18_cifar10")
            .unwrap();
        assert_eq!(config.num_classes, 10);
        assert_eq!(config.input_size, (32, 32));
    }

    /// The quick ResNet-18 helper resolves the built-in preset and, when the
    /// build succeeds, yields a ResNet model.
    #[test]
    fn test_quick_builders() {
        // The previous assertion (`is_ok() || is_err()`) was always true and
        // could never fail. Whether the build itself succeeds depends on the
        // ResNet implementation, so other errors are still tolerated; only
        // the outcomes this module controls are checked.
        match quick::resnet18(1000) {
            Ok(model) => assert!(
                matches!(model, ModelType::ResNet(_)),
                "quick::resnet18 must produce a ResNet model"
            ),
            Err(ModelError::ModelNotFound { name }) => {
                panic!("built-in preset `{}` is not registered", name)
            }
            Err(_) => {}
        }
    }
}