Skip to main content

trustformers_models/
common_patterns.rs

1//! # Common Model Architecture Patterns and Traits
2//!
3//! This module provides common patterns, traits, and abstractions that are shared
4//! across different model implementations in the TrustformeRS Models crate.
5//!
6//! ## Features
7//!
8//! - **Common Model Traits**: Standardized interfaces for all models
9//! - **Architecture Patterns**: Reusable architectural components
10//! - **Weight Initialization**: Standardized weight initialization strategies
11//! - **Generation Interfaces**: Unified text generation APIs
12//! - **Evaluation Interfaces**: Common evaluation and testing interfaces
13//! - **Configuration Patterns**: Shared configuration management
14//!
15//! ## Common Traits
16//!
17//! ### ModelFamily
18//! Groups related model configurations and provides family-level operations
19//!
20//! ### GenerativeModel
21//! Unified interface for text generation across all model types
22//!
23//! ### EvaluableModel
24//! Common evaluation interface for model testing and validation
25//!
26//! ## Example Usage
27//!
//! ```rust
//! use anyhow::Result;
//! use trustformers_models::common_patterns::{ModelFamily, GenerativeModel, GenerationConfig};
//!
//! // Use any model through common interface
//! fn generate_text<M: GenerativeModel>(model: &M, prompt: &str) -> Result<String> {
//!     let config = GenerationConfig::default();
//!     model.generate(prompt, &config)
//! }
//! ```
37
38use anyhow::Result;
39use scirs2_core::random::*; // SciRS2 Integration Policy - includes Rng, Distribution, Normal, Uniform
40use serde::{Deserialize, Serialize};
41use std::any::Any;
42use std::collections::HashMap;
43use std::sync::{Mutex, OnceLock};
44use trustformers_core::errors::Result as CoreResult;
45use trustformers_core::tensor::Tensor;
46use trustformers_core::traits::Config;
47
/// Common generation configuration for all models
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationConfig {
    /// Maximum number of tokens to generate beyond the prompt.
    pub max_new_tokens: usize,
    /// Optional hard cap on total sequence length (prompt + generated).
    pub max_length: Option<usize>,
    /// Softmax temperature; higher values flatten the distribution.
    pub temperature: f32,
    /// Nucleus (top-p) sampling probability-mass cutoff.
    pub top_p: f32,
    /// Restrict sampling to the `k` most likely tokens; `None` disables top-k.
    pub top_k: Option<usize>,
    /// Penalty applied to already-generated tokens (1.0 = no penalty).
    pub repetition_penalty: f32,
    /// Length penalty for beam search scoring (1.0 = neutral).
    pub length_penalty: f32,
    /// Sample from the distribution instead of greedy decoding.
    pub do_sample: bool,
    /// Stop beam search early once enough finished hypotheses exist.
    pub early_stopping: bool,
    /// Number of beams for beam search; `None` = no beam search.
    pub num_beams: Option<usize>,
    /// Number of independent sequences to return per prompt.
    pub num_return_sequences: usize,
    /// Token id used for padding, if the tokenizer defines one.
    pub pad_token_id: Option<u32>,
    /// Token id that terminates generation, if defined.
    pub eos_token_id: Option<u32>,
    /// Cache intermediate decoding state between steps (typically the KV cache).
    pub use_cache: bool,
    /// Emit output incrementally rather than all at once.
    pub stream: bool,
}
67
68impl Default for GenerationConfig {
69    fn default() -> Self {
70        Self {
71            max_new_tokens: 100,
72            max_length: None,
73            temperature: 1.0,
74            top_p: 0.9,
75            top_k: None,
76            repetition_penalty: 1.0,
77            length_penalty: 1.0,
78            do_sample: true,
79            early_stopping: false,
80            num_beams: None,
81            num_return_sequences: 1,
82            pad_token_id: None,
83            eos_token_id: None,
84            use_cache: true,
85            stream: false,
86        }
87    }
88}
89
/// Weight initialization strategy
///
/// Consumed by `ModelUtils::initialize_weights`; only F32 tensors are
/// supported by the built-in initializers.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum InitializationStrategy {
    /// Normal distribution with mean=0, std=initializer_range
    Normal { std: f32 },
    /// Xavier/Glorot uniform initialization
    XavierUniform,
    /// Xavier/Glorot normal initialization
    XavierNormal,
    /// Kaiming/He uniform initialization
    KaimingUniform,
    /// Kaiming/He normal initialization
    KaimingNormal,
    /// Truncated normal initialization
    /// (`bounds` is the absolute cutoff: draws with |x| > bounds are rejected).
    TruncatedNormal { std: f32, bounds: f32 },
    /// Custom initialization function, looked up by name at init time
    /// (known names: "zero", "ones", "identity").
    Custom(String),
}
108
impl Default for InitializationStrategy {
    fn default() -> Self {
        // std = 0.02 matches the common transformer `initializer_range`
        // default (e.g. GPT-2/BERT-style configs).
        Self::Normal { std: 0.02 }
    }
}
114
/// A dyn-compatible version of Config trait for runtime usage
///
/// Mirrors the object-safe subset of `Config` so heterogeneous
/// configurations can be stored as `Box<dyn DynConfig>` (see `ModelRegistry`).
/// A blanket impl covers every `Config + 'static` type automatically.
pub trait DynConfig {
    /// Validates the configuration for correctness
    fn validate(&self) -> CoreResult<()>;

    /// Returns the architecture name for this configuration
    fn architecture(&self) -> &'static str;

    /// Returns the configuration as Any for downcasting back to the
    /// concrete config type.
    fn as_any(&self) -> &dyn Any;
}
126
127/// Blanket implementation for any type that implements Config
128impl<T: Config + 'static> DynConfig for T {
129    fn validate(&self) -> CoreResult<()> {
130        self.validate()
131    }
132
133    fn architecture(&self) -> &'static str {
134        self.architecture()
135    }
136
137    fn as_any(&self) -> &dyn Any {
138        self
139    }
140}
141
/// Model family trait for grouping related models
///
/// All methods are associated functions (`where Self: Sized`), so this trait
/// is not object-safe on its own; use `ModelFamilyWrapper` / `ModelFamilyImpl`
/// to store families behind a trait object (e.g. inside `ModelRegistry`).
pub trait ModelFamily: Send + Sync {
    /// Get the family name (e.g., "LLaMA", "BERT", "GPT")
    fn family_name() -> &'static str
    where
        Self: Sized;

    /// Get available model sizes
    fn available_sizes() -> Vec<&'static str>
    where
        Self: Sized;

    /// Get available variants (base, instruct, chat, etc.)
    fn available_variants() -> Vec<&'static str>
    where
        Self: Sized;

    /// Create configuration for a specific size and variant
    /// (returns an error for unknown size/variant combinations).
    fn create_config(size: &str, variant: Option<&str>) -> Result<Box<dyn DynConfig>>
    where
        Self: Sized;

    /// Get recommended use cases for this family
    fn use_cases() -> Vec<&'static str>
    where
        Self: Sized;

    /// Get model family metadata
    fn metadata() -> ModelFamilyMetadata
    where
        Self: Sized;
}
174
/// Metadata about a model family
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelFamilyMetadata {
    /// Human-readable family name (e.g. "LLaMA").
    pub family_name: String,
    /// Short description of the family.
    pub description: String,
    /// Citation or link for the introducing paper, if any.
    pub paper_reference: Option<String>,
    /// Releasing organization, if known.
    pub organization: Option<String>,
    /// License identifier or text, if known.
    pub license: Option<String>,
    /// Release date as a free-form string, if known.
    pub release_date: Option<String>,
    /// High-level architecture classification.
    pub architecture_type: ArchitectureType,
    /// Tasks this family is suited for.
    pub supported_tasks: Vec<TaskType>,
    /// Hardware needed to run models of this family.
    pub compute_requirements: ComputeRequirements,
}
188
/// Architecture type classification
///
/// Coarse structural taxonomy used by `ModelFamilyMetadata`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ArchitectureType {
    /// Encoder-only (BERT-style)
    EncoderOnly,
    /// Decoder-only (GPT-style)
    DecoderOnly,
    /// Encoder-decoder (T5-style)
    EncoderDecoder,
    /// State-space models (Mamba-style)
    StateSpace,
    /// Hybrid architectures (mixing the categories above)
    Hybrid,
    /// Multimodal architectures
    Multimodal,
}
205
/// Supported task types
///
/// Used both for metadata (`ModelFamilyMetadata::supported_tasks`) and to
/// pick task-specific generation presets
/// (`ModelUtils::generation_config_for_task`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum TaskType {
    /// Text generation
    TextGeneration,
    /// Text classification
    TextClassification,
    /// Question answering
    QuestionAnswering,
    /// Summarization
    Summarization,
    /// Translation
    Translation,
    /// Code generation
    CodeGeneration,
    /// Image understanding
    ImageUnderstanding,
    /// Multimodal understanding
    MultimodalUnderstanding,
    /// Custom task, identified by a free-form name
    Custom(String),
}
228
/// Compute requirements information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComputeRequirements {
    /// Minimum GPU memory needed to run at all, in GB.
    pub minimum_vram_gb: f32,
    /// GPU memory recommended for comfortable operation, in GB.
    pub recommended_vram_gb: f32,
    /// Minimum system RAM, in GB.
    pub minimum_ram_gb: f32,
    /// Free-form CPU requirement description.
    pub cpu_requirements: String,
    /// Free-form GPU requirement description, if a GPU is relevant.
    pub gpu_requirements: Option<String>,
    /// Whether inference on CPU alone is feasible.
    pub supports_cpu_inference: bool,
    /// Whether quantized variants are supported.
    pub supports_quantization: bool,
}
240
/// Common generative model interface
///
/// Unified text-generation API implemented by all generative models in the
/// crate so callers can be written generically over `M: GenerativeModel`.
pub trait GenerativeModel {
    /// Generate text from a prompt
    fn generate(&self, prompt: &str, config: &GenerationConfig) -> Result<String>;

    /// Generate multiple completions, one per prompt in `prompts`
    fn generate_batch(&self, prompts: &[&str], config: &GenerationConfig) -> Result<Vec<String>>;

    /// Stream generation token by token
    /// (each iterator item is one decoded chunk, or an error that ends the stream)
    fn generate_stream(
        &self,
        prompt: &str,
        config: &GenerationConfig,
    ) -> Result<Box<dyn Iterator<Item = Result<String>>>>;

    /// Get the maximum context length, in tokens
    fn max_context_length(&self) -> usize;

    /// Get the model configuration
    fn config(&self) -> &dyn DynConfig;

    /// Check if the model supports a specific task
    fn supports_task(&self, task: &TaskType) -> bool;
}
265
/// Model evaluation interface
///
/// Common evaluation surface for model testing and validation; complements
/// `GenerativeModel` for models that can also score text.
pub trait EvaluableModel {
    /// Compute perplexity on a text corpus (lower is better)
    fn compute_perplexity(&self, text: &str) -> Result<f32>;

    /// Compute log likelihood of text
    fn compute_log_likelihood(&self, text: &str) -> Result<f32>;

    /// Get model embeddings for text
    fn get_embeddings(&self, text: &str) -> Result<Tensor>;

    /// Run model-specific evaluation metrics over `evaluation_data`
    fn evaluate(&self, evaluation_data: &EvaluationData) -> Result<EvaluationResults>;
}
280
/// Evaluation data structure
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvaluationData {
    /// Input prompts to evaluate on.
    pub prompts: Vec<String>,
    /// Reference outputs aligned with `prompts`, for metrics that need them.
    pub expected_outputs: Option<Vec<String>>,
    /// Task being evaluated.
    pub task_type: TaskType,
    /// Metrics to compute for this run.
    pub metrics: Vec<EvaluationMetric>,
}
289
/// Evaluation metrics
///
/// `Eq + Hash` so metrics can key `EvaluationResults::metric_scores`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum EvaluationMetric {
    /// Perplexity
    Perplexity,
    /// BLEU score (for translation/generation)
    BLEU,
    /// ROUGE score (for summarization)
    ROUGE,
    /// Exact match accuracy
    ExactMatch,
    /// F1 score
    F1Score,
    /// Semantic similarity
    SemanticSimilarity,
    /// Custom metric, identified by a free-form name
    Custom(String),
}
308
/// Evaluation results
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvaluationResults {
    /// Aggregate score across metrics (aggregation defined by the evaluator).
    pub overall_score: f32,
    /// Score per requested metric.
    pub metric_scores: HashMap<EvaluationMetric, f32>,
    /// Optional per-sample scores, aligned with the evaluation inputs.
    pub per_sample_scores: Option<Vec<f32>>,
    /// Free-form additional information about the run.
    pub metadata: HashMap<String, String>,
}
317
/// Common model utilities
///
/// Stateless namespace for weight initialization, task-tuned generation
/// presets, configuration validation, and memory estimation helpers.
pub struct ModelUtils;
320
321impl ModelUtils {
322    /// Initialize weights using the specified strategy
323    pub fn initialize_weights(
324        tensor: &mut Tensor,
325        strategy: &InitializationStrategy,
326    ) -> Result<()> {
327        match strategy {
328            InitializationStrategy::Normal { std } => {
329                // Initialize with normal distribution
330                let mut rng = thread_rng();
331                let normal = Normal::new(0.0, *std)
332                    .map_err(|e| anyhow::anyhow!("Failed to create normal distribution: {}", e))?;
333
334                match tensor {
335                    Tensor::F32(data) => {
336                        for value in data.iter_mut() {
337                            *value = normal.sample(&mut rng);
338                        }
339                        Ok(())
340                    },
341                    _ => Err(anyhow::anyhow!(
342                        "Normal initialization only supports F32 tensors"
343                    )),
344                }
345            },
346            InitializationStrategy::XavierUniform => {
347                // Xavier uniform initialization
348                let mut rng = thread_rng();
349                let shape = tensor.shape();
350
351                match tensor {
352                    Tensor::F32(data) => {
353                        let (fan_in, fan_out) = if shape.len() >= 2 {
354                            (shape[shape.len() - 1], shape[shape.len() - 2])
355                        } else {
356                            (1, data.len())
357                        };
358
359                        let limit = (6.0 / (fan_in + fan_out) as f32).sqrt();
360                        let uniform = Uniform::new(-limit, limit).map_err(|e| {
361                            anyhow::anyhow!("Failed to create uniform distribution: {}", e)
362                        })?;
363
364                        for value in data.iter_mut() {
365                            *value = uniform.sample(&mut rng);
366                        }
367                        Ok(())
368                    },
369                    _ => Err(anyhow::anyhow!(
370                        "Xavier uniform initialization only supports F32 tensors"
371                    )),
372                }
373            },
374            InitializationStrategy::XavierNormal => {
375                // Xavier normal initialization
376                let mut rng = thread_rng();
377                let shape = tensor.shape();
378
379                match tensor {
380                    Tensor::F32(data) => {
381                        let (fan_in, fan_out) = if shape.len() >= 2 {
382                            (shape[shape.len() - 1], shape[shape.len() - 2])
383                        } else {
384                            (1, data.len())
385                        };
386
387                        let std = (2.0 / (fan_in + fan_out) as f32).sqrt();
388                        let normal = Normal::new(0.0, std).map_err(|e| {
389                            anyhow::anyhow!("Failed to create normal distribution: {}", e)
390                        })?;
391
392                        for value in data.iter_mut() {
393                            *value = normal.sample(&mut rng);
394                        }
395                        Ok(())
396                    },
397                    _ => Err(anyhow::anyhow!(
398                        "Xavier normal initialization only supports F32 tensors"
399                    )),
400                }
401            },
402            InitializationStrategy::KaimingUniform => {
403                // Kaiming uniform initialization
404                let mut rng = thread_rng();
405                let shape = tensor.shape();
406
407                match tensor {
408                    Tensor::F32(data) => {
409                        let fan_in = if shape.len() >= 2 { shape[shape.len() - 1] } else { 1 };
410
411                        let limit = (6.0 / fan_in as f32).sqrt();
412                        let uniform = Uniform::new(-limit, limit).map_err(|e| {
413                            anyhow::anyhow!("Failed to create uniform distribution: {}", e)
414                        })?;
415
416                        for value in data.iter_mut() {
417                            *value = uniform.sample(&mut rng);
418                        }
419                        Ok(())
420                    },
421                    _ => Err(anyhow::anyhow!(
422                        "Kaiming uniform initialization only supports F32 tensors"
423                    )),
424                }
425            },
426            InitializationStrategy::KaimingNormal => {
427                // Kaiming normal initialization
428                let mut rng = thread_rng();
429                let shape = tensor.shape();
430
431                match tensor {
432                    Tensor::F32(data) => {
433                        let fan_in = if shape.len() >= 2 { shape[shape.len() - 1] } else { 1 };
434
435                        let std = (2.0 / fan_in as f32).sqrt();
436                        let normal = Normal::new(0.0, std).map_err(|e| {
437                            anyhow::anyhow!("Failed to create normal distribution: {}", e)
438                        })?;
439
440                        for value in data.iter_mut() {
441                            *value = normal.sample(&mut rng);
442                        }
443                        Ok(())
444                    },
445                    _ => Err(anyhow::anyhow!(
446                        "Kaiming normal initialization only supports F32 tensors"
447                    )),
448                }
449            },
450            InitializationStrategy::TruncatedNormal { std, bounds } => {
451                // Truncated normal initialization
452                let mut rng = thread_rng();
453
454                match tensor {
455                    Tensor::F32(data) => {
456                        let normal = Normal::new(0.0, *std).map_err(|e| {
457                            anyhow::anyhow!("Failed to create normal distribution: {}", e)
458                        })?;
459
460                        for value in data.iter_mut() {
461                            loop {
462                                let sample = normal.sample(&mut rng);
463                                if sample.abs() <= *bounds {
464                                    *value = sample;
465                                    break;
466                                }
467                                // Resample if outside bounds
468                            }
469                        }
470                        Ok(())
471                    },
472                    _ => Err(anyhow::anyhow!(
473                        "Truncated normal initialization only supports F32 tensors"
474                    )),
475                }
476            },
477            InitializationStrategy::Custom(name) => {
478                // Custom initialization - lookup by name
479                match name.as_str() {
480                    "zero" => {
481                        // Initialize all weights to zero
482                        match tensor {
483                            Tensor::F32(data) => {
484                                data.fill(0.0);
485                                Ok(())
486                            },
487                            _ => Err(anyhow::anyhow!(
488                                "Custom zero initialization only supports F32 tensors"
489                            )),
490                        }
491                    },
492                    "ones" => {
493                        // Initialize all weights to one
494                        match tensor {
495                            Tensor::F32(data) => {
496                                data.fill(1.0);
497                                Ok(())
498                            },
499                            _ => Err(anyhow::anyhow!(
500                                "Custom ones initialization only supports F32 tensors"
501                            )),
502                        }
503                    },
504                    "identity" => {
505                        // Initialize as identity matrix (for square matrices)
506                        let shape = tensor.shape();
507                        match tensor {
508                            Tensor::F32(data) => {
509                                if shape.len() == 2 && shape[0] == shape[1] {
510                                    // Square matrix - initialize as identity
511                                    data.fill(0.0);
512                                    let dim = shape[0];
513                                    for i in 0..dim {
514                                        data[i * dim + i] = 1.0;
515                                    }
516                                    Ok(())
517                                } else {
518                                    Err(anyhow::anyhow!(
519                                        "Identity initialization requires square matrix"
520                                    ))
521                                }
522                            },
523                            _ => Err(anyhow::anyhow!(
524                                "Custom identity initialization only supports F32 tensors"
525                            )),
526                        }
527                    },
528                    _ => Err(anyhow::anyhow!(
529                        "Unknown custom initialization strategy: {}",
530                        name
531                    )),
532                }
533            },
534        }
535    }
536
537    /// Create a standard generation configuration for a task
538    pub fn generation_config_for_task(task: &TaskType) -> GenerationConfig {
539        match task {
540            TaskType::TextGeneration => GenerationConfig {
541                max_new_tokens: 512,
542                temperature: 0.7,
543                top_p: 0.9,
544                repetition_penalty: 1.1,
545                ..GenerationConfig::default()
546            },
547            TaskType::CodeGeneration => GenerationConfig {
548                max_new_tokens: 1024,
549                temperature: 0.2,
550                top_p: 0.95,
551                repetition_penalty: 1.05,
552                ..GenerationConfig::default()
553            },
554            TaskType::Summarization => GenerationConfig {
555                max_new_tokens: 256,
556                temperature: 0.3,
557                top_p: 0.9,
558                repetition_penalty: 1.2,
559                ..GenerationConfig::default()
560            },
561            TaskType::QuestionAnswering => GenerationConfig {
562                max_new_tokens: 128,
563                temperature: 0.1,
564                top_p: 0.95,
565                repetition_penalty: 1.0,
566                early_stopping: true,
567                ..GenerationConfig::default()
568            },
569            _ => GenerationConfig::default(),
570        }
571    }
572
573    /// Validate model configuration
574    pub fn validate_config(config: &dyn DynConfig) -> Result<Vec<String>> {
575        let warnings = Vec::new();
576
577        // Basic validation
578        config.validate()?;
579
580        // Additional checks can be added here
581        // For example, checking if model size is reasonable, etc.
582
583        Ok(warnings)
584    }
585
586    /// Estimate model memory requirements
587    pub fn estimate_memory_requirements(
588        vocab_size: usize,
589        hidden_size: usize,
590        num_layers: usize,
591        context_length: usize,
592    ) -> MemoryEstimate {
593        // Rough estimation for transformer models
594        let embedding_params = vocab_size * hidden_size;
595        let layer_params = num_layers
596            * (
597                // Self-attention weights
598                hidden_size * hidden_size * 4 +
599            // Feed-forward weights (assuming 4x expansion)
600            hidden_size * hidden_size * 4 * 2 +
601            // Layer norms
602            hidden_size * 2
603            );
604        let output_head_params = vocab_size * hidden_size;
605
606        let total_params = embedding_params + layer_params + output_head_params;
607
608        // Estimate memory usage (parameters + gradients + optimizer states + activations)
609        let model_memory_gb = (total_params * 4) as f32 / 1_000_000_000.0; // 4 bytes per param (FP32)
610        let activation_memory_gb =
611            (context_length * hidden_size * num_layers * 4) as f32 / 1_000_000_000.0;
612
613        MemoryEstimate {
614            total_parameters: total_params,
615            model_memory_gb,
616            activation_memory_gb,
617            total_memory_gb: model_memory_gb + activation_memory_gb,
618            inference_memory_gb: model_memory_gb + activation_memory_gb * 0.5,
619        }
620    }
621}
622
/// Memory usage estimation
///
/// Produced by `ModelUtils::estimate_memory_requirements`; values assume
/// FP32 parameters and decimal gigabytes (1e9 bytes).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEstimate {
    /// Estimated total parameter count.
    pub total_parameters: usize,
    /// Memory for the parameters themselves.
    pub model_memory_gb: f32,
    /// Estimated activation memory during a forward pass.
    pub activation_memory_gb: f32,
    /// Model memory plus activation memory.
    pub total_memory_gb: f32,
    /// Inference estimate (model + roughly half the activation memory).
    pub inference_memory_gb: f32,
}
632
/// Generation strategy patterns
///
/// Declarative description of a decoding algorithm; complements the flat
/// knobs in `GenerationConfig`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GenerationStrategy {
    /// Greedy decoding (always pick the most likely token)
    Greedy,
    /// Sampling with temperature
    Sampling { temperature: f32 },
    /// Top-k sampling
    TopK { k: usize, temperature: f32 },
    /// Nucleus (top-p) sampling
    TopP { p: f32, temperature: f32 },
    /// Beam search
    BeamSearch { num_beams: usize },
    /// Diverse beam search
    DiverseBeamSearch {
        num_beams: usize,
        diversity_penalty: f32,
    },
    /// Contrastive search
    ContrastiveSearch { penalty_alpha: f32, top_k: usize },
}
654
/// Common architectural components
///
/// Object-safe building-block traits that concrete model implementations
/// can implement and compose.
pub mod components {
    use super::*;

    /// Standard transformer layer interface
    pub trait TransformerLayer {
        /// Forward pass through the layer
        /// (`attention_mask`, when given, masks out padded/future positions).
        fn forward(&self, input: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor>;

        /// Get layer configuration
        fn config(&self) -> &dyn DynConfig;
    }

    /// Standard attention mechanism interface
    pub trait AttentionMechanism {
        /// Compute attention weights from query/key/value tensors
        fn compute_attention(&self, query: &Tensor, key: &Tensor, value: &Tensor)
            -> Result<Tensor>;

        /// Apply attention mask to precomputed attention weights
        fn apply_mask(&self, attention_weights: &Tensor, mask: &Tensor) -> Result<Tensor>;
    }

    /// Standard feed-forward network interface
    pub trait FeedForwardNetwork {
        /// Forward pass through FFN
        fn forward(&self, input: &Tensor) -> Result<Tensor>;

        /// Get hidden dimensions
        fn hidden_size(&self) -> usize;
        // Inner (expanded) dimension of the FFN.
        fn intermediate_size(&self) -> usize;
    }

    /// Standard embedding layer interface
    pub trait EmbeddingLayer {
        /// Forward pass for token embeddings (token ids -> embedding vectors)
        fn forward(&self, input_ids: &Tensor) -> Result<Tensor>;

        /// Get vocabulary size
        fn vocab_size(&self) -> usize;

        /// Get embedding dimension
        fn embedding_dim(&self) -> usize;
    }
}
700
/// Model family implementation for registry
///
/// Object-safe counterpart of `ModelFamily` (whose methods are all
/// `Self: Sized`), so families can be stored as `Box<dyn ModelFamilyImpl>`
/// inside `ModelRegistry`. See `ModelFamilyWrapper` for the adapter.
pub trait ModelFamilyImpl: Send + Sync {
    /// Create configuration for a specific size and variant
    fn create_config(&self, size: &str, variant: Option<&str>) -> Result<Box<dyn DynConfig>>;

    /// Get family name
    fn family_name(&self) -> &'static str;

    /// Get available model sizes
    fn available_sizes(&self) -> Vec<&'static str>;

    /// Get available variants (base, instruct, chat, etc.)
    fn available_variants(&self) -> Vec<&'static str>;

    /// Get recommended use cases for this family
    fn use_cases(&self) -> Vec<&'static str>;

    /// Get model family metadata
    fn metadata(&self) -> ModelFamilyMetadata;
}
721
/// Wrapper for ModelFamily trait to make it work with ModelRegistry
/// (zero-sized adapter: `T` is carried only at the type level via PhantomData,
/// turning the static `ModelFamily` methods into object-safe instance methods).
pub struct ModelFamilyWrapper<T: ModelFamily>(std::marker::PhantomData<T>);
724
725impl<T: ModelFamily> Default for ModelFamilyWrapper<T> {
726    fn default() -> Self {
727        Self::new()
728    }
729}
730
impl<T: ModelFamily> ModelFamilyWrapper<T> {
    /// Creates a wrapper for family `T`; zero-sized, no runtime state.
    pub fn new() -> Self {
        Self(std::marker::PhantomData)
    }
}
736
/// Pure delegation: each instance method forwards to the corresponding
/// associated function on the wrapped `ModelFamily` type `T`.
impl<T: ModelFamily> ModelFamilyImpl for ModelFamilyWrapper<T> {
    fn create_config(&self, size: &str, variant: Option<&str>) -> Result<Box<dyn DynConfig>> {
        T::create_config(size, variant)
    }

    fn family_name(&self) -> &'static str {
        T::family_name()
    }

    fn available_sizes(&self) -> Vec<&'static str> {
        T::available_sizes()
    }

    fn available_variants(&self) -> Vec<&'static str> {
        T::available_variants()
    }

    fn use_cases(&self) -> Vec<&'static str> {
        T::use_cases()
    }

    fn metadata(&self) -> ModelFamilyMetadata {
        T::metadata()
    }
}
762
/// Model registry for dynamic model creation
///
/// Not thread-safe by itself; the global instance is wrapped in a `Mutex`
/// (see `get_global_registry`).
pub struct ModelRegistry {
    // Registered families, keyed by `ModelFamilyImpl::family_name()`.
    families: HashMap<String, Box<dyn ModelFamilyImpl>>,
}
767
768impl Default for ModelRegistry {
769    fn default() -> Self {
770        Self::new()
771    }
772}
773
774impl ModelRegistry {
775    pub fn new() -> Self {
776        Self {
777            families: HashMap::new(),
778        }
779    }
780
781    pub fn register_family<F: ModelFamilyImpl + 'static>(&mut self, family: F) {
782        let name = family.family_name().to_string();
783        self.families.insert(name, Box::new(family));
784    }
785
786    pub fn register_model_family<F: ModelFamily + 'static>(&mut self) {
787        let wrapper = ModelFamilyWrapper::<F>::new();
788        self.register_family(wrapper);
789    }
790
791    pub fn get_family(&self, name: &str) -> Option<&dyn ModelFamilyImpl> {
792        self.families.get(name).map(|f| f.as_ref())
793    }
794
795    pub fn list_families(&self) -> Vec<&str> {
796        self.families.keys().map(|s| s.as_str()).collect()
797    }
798
799    pub fn create_model(
800        &self,
801        family: &str,
802        size: &str,
803        variant: Option<&str>,
804    ) -> Result<Box<dyn DynConfig>> {
805        let family_impl = self
806            .get_family(family)
807            .ok_or_else(|| anyhow::anyhow!("Unknown model family: {}", family))?;
808
809        family_impl.create_config(size, variant)
810    }
811}
812
813/// Global model registry instance
814static MODEL_REGISTRY: OnceLock<Mutex<ModelRegistry>> = OnceLock::new();
815
816pub fn get_global_registry() -> &'static Mutex<ModelRegistry> {
817    MODEL_REGISTRY.get_or_init(|| Mutex::new(ModelRegistry::new()))
818}
819
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generation_config_default() {
        let cfg = GenerationConfig::default();
        assert_eq!(cfg.max_new_tokens, 100);
        assert_eq!(cfg.temperature, 1.0);
        assert!(cfg.do_sample, "defaults should enable sampling");
    }

    #[test]
    fn test_task_specific_generation_config() {
        // Code generation gets a long budget and a low temperature.
        let code = ModelUtils::generation_config_for_task(&TaskType::CodeGeneration);
        assert_eq!(code.temperature, 0.2);
        assert_eq!(code.max_new_tokens, 1024);

        // QA is near-greedy and stops early.
        let qa = ModelUtils::generation_config_for_task(&TaskType::QuestionAnswering);
        assert_eq!(qa.temperature, 0.1);
        assert!(qa.early_stopping);
    }

    #[test]
    fn test_memory_estimation() {
        let est = ModelUtils::estimate_memory_requirements(32000, 4096, 32, 2048);
        assert!(est.total_parameters > 0);
        assert!(est.model_memory_gb > 0.0);
        // Total includes activations, so it can never be below model memory.
        assert!(est.total_memory_gb >= est.model_memory_gb);
    }

    #[test]
    fn test_model_registry() {
        // A fresh registry starts with no families registered.
        assert!(ModelRegistry::new().list_families().is_empty());
    }

    #[test]
    fn test_initialization_strategy() {
        let strategy = InitializationStrategy::Normal { std: 0.02 };
        let InitializationStrategy::Normal { std } = strategy else {
            panic!("Wrong strategy type");
        };
        assert_eq!(std, 0.02);
    }

    #[test]
    fn test_architecture_types() {
        // Distinct variants must not compare equal.
        assert_ne!(
            ArchitectureType::EncoderOnly,
            ArchitectureType::DecoderOnly
        );
    }
}
873}