1use anyhow::Result;
39use scirs2_core::random::*; use serde::{Deserialize, Serialize};
41use std::any::Any;
42use std::collections::HashMap;
43use std::sync::{Mutex, OnceLock};
44use trustformers_core::errors::Result as CoreResult;
45use trustformers_core::tensor::Tensor;
46use trustformers_core::traits::Config;
47
/// Decoding-time settings for autoregressive text generation.
///
/// Baseline values come from [`GenerationConfig::default`]; task-tuned presets
/// are produced by `ModelUtils::generation_config_for_task`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationConfig {
    /// Maximum number of tokens to generate beyond the prompt.
    pub max_new_tokens: usize,
    /// Optional hard cap on total sequence length (prompt + generated).
    pub max_length: Option<usize>,
    /// Softmax temperature applied before sampling.
    pub temperature: f32,
    /// Nucleus (top-p) cumulative-probability cutoff.
    pub top_p: f32,
    /// Optional top-k truncation of the candidate distribution.
    pub top_k: Option<usize>,
    /// Penalty applied to already-generated tokens (1.0 = no penalty).
    pub repetition_penalty: f32,
    /// Length penalty used when scoring beam candidates (1.0 = neutral).
    pub length_penalty: f32,
    /// Sample stochastically when true; deterministic decoding otherwise.
    pub do_sample: bool,
    /// Allow decoding to stop before `max_new_tokens` is reached.
    pub early_stopping: bool,
    /// Number of beams for beam search; `None` disables beam search.
    pub num_beams: Option<usize>,
    /// How many finished sequences to return per prompt.
    pub num_return_sequences: usize,
    /// Padding token id, if the tokenizer defines one.
    pub pad_token_id: Option<u32>,
    /// End-of-sequence token id that terminates generation, if defined.
    pub eos_token_id: Option<u32>,
    /// Reuse the key/value cache across decoding steps.
    pub use_cache: bool,
    /// Emit tokens incrementally rather than all at once.
    pub stream: bool,
}
67
impl Default for GenerationConfig {
    /// Baseline settings: sampling enabled at temperature 1.0 with nucleus
    /// cutoff 0.9, up to 100 new tokens, KV-cache on, and no beam search.
    fn default() -> Self {
        Self {
            max_new_tokens: 100,
            max_length: None,
            temperature: 1.0,
            top_p: 0.9,
            top_k: None,
            repetition_penalty: 1.0,
            length_penalty: 1.0,
            do_sample: true,
            early_stopping: false,
            num_beams: None,
            num_return_sequences: 1,
            pad_token_id: None,
            eos_token_id: None,
            use_cache: true,
            stream: false,
        }
    }
}
89
/// Weight-initialization schemes understood by `ModelUtils::initialize_weights`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum InitializationStrategy {
    /// Zero-mean Gaussian with the given standard deviation.
    Normal { std: f32 },
    /// Glorot/Xavier uniform: U(-limit, limit), limit = sqrt(6 / (fan_in + fan_out)).
    XavierUniform,
    /// Glorot/Xavier normal: N(0, std^2), std = sqrt(2 / (fan_in + fan_out)).
    XavierNormal,
    /// He/Kaiming uniform: U(-limit, limit), limit = sqrt(6 / fan_in).
    KaimingUniform,
    /// He/Kaiming normal: N(0, std^2), std = sqrt(2 / fan_in).
    KaimingNormal,
    /// Gaussian samples rejected until |x| <= `bounds`.
    TruncatedNormal { std: f32, bounds: f32 },
    /// Named scheme resolved at runtime ("zero", "ones", "identity").
    Custom(String),
}
108
impl Default for InitializationStrategy {
    /// Normal with std 0.02 — a common transformer initialization default.
    fn default() -> Self {
        Self::Normal { std: 0.02 }
    }
}
114
/// Object-safe view over a model configuration, allowing heterogeneous
/// configs to be stored and passed around as `Box<dyn DynConfig>`.
pub trait DynConfig {
    /// Checks the configuration for internal consistency.
    fn validate(&self) -> CoreResult<()>;

    /// Name of the model architecture this config describes.
    fn architecture(&self) -> &'static str;

    /// Upcast used to downcast back to the concrete config type.
    fn as_any(&self) -> &dyn Any;
}
126
127impl<T: Config + 'static> DynConfig for T {
129 fn validate(&self) -> CoreResult<()> {
130 self.validate()
131 }
132
133 fn architecture(&self) -> &'static str {
134 self.architecture()
135 }
136
137 fn as_any(&self) -> &dyn Any {
138 self
139 }
140}
141
/// Static, compile-time description of a model family.
///
/// All methods are associated functions (`where Self: Sized`), so this trait
/// is not object safe; use `ModelFamilyWrapper` / `ModelFamilyImpl` when
/// dynamic dispatch is needed (e.g. in `ModelRegistry`).
pub trait ModelFamily: Send + Sync {
    /// Canonical family name; used as the registry key.
    fn family_name() -> &'static str
    where
        Self: Sized;

    /// Model sizes this family can instantiate.
    fn available_sizes() -> Vec<&'static str>
    where
        Self: Sized;

    /// Named variants of the family, if any.
    fn available_variants() -> Vec<&'static str>
    where
        Self: Sized;

    /// Builds a configuration for the requested size and optional variant.
    fn create_config(size: &str, variant: Option<&str>) -> Result<Box<dyn DynConfig>>
    where
        Self: Sized;

    /// Typical use cases the family is suited for.
    fn use_cases() -> Vec<&'static str>
    where
        Self: Sized;

    /// Descriptive metadata about the family.
    fn metadata() -> ModelFamilyMetadata
    where
        Self: Sized;
}
174
/// Descriptive information about a model family.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelFamilyMetadata {
    /// Canonical family name.
    pub family_name: String,
    /// Human-readable description.
    pub description: String,
    /// Citation or link for the originating paper, if any.
    pub paper_reference: Option<String>,
    /// Releasing organization, if known.
    pub organization: Option<String>,
    /// License identifier, if known.
    pub license: Option<String>,
    /// Release date, if known.
    pub release_date: Option<String>,
    /// High-level architecture category.
    pub architecture_type: ArchitectureType,
    /// Tasks the family supports.
    pub supported_tasks: Vec<TaskType>,
    /// Hardware needed to run models of this family.
    pub compute_requirements: ComputeRequirements,
}
188
/// High-level neural architecture categories.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ArchitectureType {
    /// Encoder-only (BERT-style).
    EncoderOnly,
    /// Decoder-only (GPT-style).
    DecoderOnly,
    /// Encoder-decoder (T5-style).
    EncoderDecoder,
    /// State-space sequence models (Mamba-style).
    StateSpace,
    /// Mixture of multiple architecture styles.
    Hybrid,
    /// Architectures consuming multiple input modalities.
    Multimodal,
}
205
/// Task categories a model may support.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum TaskType {
    /// Free-form text generation.
    TextGeneration,
    /// Assigning labels to text.
    TextClassification,
    /// Answering questions, typically from context.
    QuestionAnswering,
    /// Condensing text into a summary.
    Summarization,
    /// Translating between languages.
    Translation,
    /// Generating source code.
    CodeGeneration,
    /// Interpreting image inputs.
    ImageUnderstanding,
    /// Joint reasoning over multiple modalities.
    MultimodalUnderstanding,
    /// Application-defined task name.
    Custom(String),
}
228
/// Hardware requirements for running a model family.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComputeRequirements {
    /// Minimum GPU memory in GB.
    pub minimum_vram_gb: f32,
    /// Recommended GPU memory in GB.
    pub recommended_vram_gb: f32,
    /// Minimum system RAM in GB.
    pub minimum_ram_gb: f32,
    /// Free-form CPU requirements description.
    pub cpu_requirements: String,
    /// Free-form GPU requirements description, if a GPU is relevant.
    pub gpu_requirements: Option<String>,
    /// Whether CPU-only inference is viable.
    pub supports_cpu_inference: bool,
    /// Whether quantized execution is supported.
    pub supports_quantization: bool,
}
240
/// Text-generation interface for models that can complete prompts.
pub trait GenerativeModel {
    /// Generates a completion for a single prompt.
    fn generate(&self, prompt: &str, config: &GenerationConfig) -> Result<String>;

    /// Generates completions for several prompts, one result per prompt.
    fn generate_batch(&self, prompts: &[&str], config: &GenerationConfig) -> Result<Vec<String>>;

    /// Generates incrementally, yielding output chunks as they are produced.
    fn generate_stream(
        &self,
        prompt: &str,
        config: &GenerationConfig,
    ) -> Result<Box<dyn Iterator<Item = Result<String>>>>;

    /// Maximum context window, in tokens, the model accepts.
    fn max_context_length(&self) -> usize;

    /// The model's configuration.
    fn config(&self) -> &dyn DynConfig;

    /// Whether the model supports the given task.
    fn supports_task(&self, task: &TaskType) -> bool;
}
265
/// Scoring and evaluation interface for language models.
pub trait EvaluableModel {
    /// Perplexity of `text` under the model.
    fn compute_perplexity(&self, text: &str) -> Result<f32>;

    /// Log-likelihood of `text` under the model.
    fn compute_log_likelihood(&self, text: &str) -> Result<f32>;

    /// Embedding representation of `text`.
    fn get_embeddings(&self, text: &str) -> Result<Tensor>;

    /// Runs the metrics requested in `evaluation_data` and aggregates them.
    fn evaluate(&self, evaluation_data: &EvaluationData) -> Result<EvaluationResults>;
}
280
/// Inputs for one model evaluation run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvaluationData {
    /// Prompts to evaluate on.
    pub prompts: Vec<String>,
    /// Reference outputs, for metrics that need ground truth.
    pub expected_outputs: Option<Vec<String>>,
    /// Task the prompts belong to.
    pub task_type: TaskType,
    /// Metrics to compute.
    pub metrics: Vec<EvaluationMetric>,
}
289
/// Supported evaluation metrics.
///
/// `Eq + Hash` so metrics can key the score map in `EvaluationResults`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum EvaluationMetric {
    /// Model perplexity.
    Perplexity,
    /// BLEU n-gram overlap score.
    BLEU,
    /// ROUGE overlap score.
    ROUGE,
    /// Exact string match against the reference.
    ExactMatch,
    /// F1 score.
    F1Score,
    /// Semantic similarity score.
    SemanticSimilarity,
    /// Application-defined metric name.
    Custom(String),
}
308
/// Aggregated outcome of an evaluation run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvaluationResults {
    /// Single aggregate score across all metrics.
    pub overall_score: f32,
    /// Score per requested metric.
    pub metric_scores: HashMap<EvaluationMetric, f32>,
    /// Optional per-sample scores, parallel to the input prompts.
    pub per_sample_scores: Option<Vec<f32>>,
    /// Free-form details about the run.
    pub metadata: HashMap<String, String>,
}
317
318pub struct ModelUtils;
320
321impl ModelUtils {
322 pub fn initialize_weights(
324 tensor: &mut Tensor,
325 strategy: &InitializationStrategy,
326 ) -> Result<()> {
327 match strategy {
328 InitializationStrategy::Normal { std } => {
329 let mut rng = thread_rng();
331 let normal = Normal::new(0.0, *std)
332 .map_err(|e| anyhow::anyhow!("Failed to create normal distribution: {}", e))?;
333
334 match tensor {
335 Tensor::F32(data) => {
336 for value in data.iter_mut() {
337 *value = normal.sample(&mut rng);
338 }
339 Ok(())
340 },
341 _ => Err(anyhow::anyhow!(
342 "Normal initialization only supports F32 tensors"
343 )),
344 }
345 },
346 InitializationStrategy::XavierUniform => {
347 let mut rng = thread_rng();
349 let shape = tensor.shape();
350
351 match tensor {
352 Tensor::F32(data) => {
353 let (fan_in, fan_out) = if shape.len() >= 2 {
354 (shape[shape.len() - 1], shape[shape.len() - 2])
355 } else {
356 (1, data.len())
357 };
358
359 let limit = (6.0 / (fan_in + fan_out) as f32).sqrt();
360 let uniform = Uniform::new(-limit, limit).map_err(|e| {
361 anyhow::anyhow!("Failed to create uniform distribution: {}", e)
362 })?;
363
364 for value in data.iter_mut() {
365 *value = uniform.sample(&mut rng);
366 }
367 Ok(())
368 },
369 _ => Err(anyhow::anyhow!(
370 "Xavier uniform initialization only supports F32 tensors"
371 )),
372 }
373 },
374 InitializationStrategy::XavierNormal => {
375 let mut rng = thread_rng();
377 let shape = tensor.shape();
378
379 match tensor {
380 Tensor::F32(data) => {
381 let (fan_in, fan_out) = if shape.len() >= 2 {
382 (shape[shape.len() - 1], shape[shape.len() - 2])
383 } else {
384 (1, data.len())
385 };
386
387 let std = (2.0 / (fan_in + fan_out) as f32).sqrt();
388 let normal = Normal::new(0.0, std).map_err(|e| {
389 anyhow::anyhow!("Failed to create normal distribution: {}", e)
390 })?;
391
392 for value in data.iter_mut() {
393 *value = normal.sample(&mut rng);
394 }
395 Ok(())
396 },
397 _ => Err(anyhow::anyhow!(
398 "Xavier normal initialization only supports F32 tensors"
399 )),
400 }
401 },
402 InitializationStrategy::KaimingUniform => {
403 let mut rng = thread_rng();
405 let shape = tensor.shape();
406
407 match tensor {
408 Tensor::F32(data) => {
409 let fan_in = if shape.len() >= 2 { shape[shape.len() - 1] } else { 1 };
410
411 let limit = (6.0 / fan_in as f32).sqrt();
412 let uniform = Uniform::new(-limit, limit).map_err(|e| {
413 anyhow::anyhow!("Failed to create uniform distribution: {}", e)
414 })?;
415
416 for value in data.iter_mut() {
417 *value = uniform.sample(&mut rng);
418 }
419 Ok(())
420 },
421 _ => Err(anyhow::anyhow!(
422 "Kaiming uniform initialization only supports F32 tensors"
423 )),
424 }
425 },
426 InitializationStrategy::KaimingNormal => {
427 let mut rng = thread_rng();
429 let shape = tensor.shape();
430
431 match tensor {
432 Tensor::F32(data) => {
433 let fan_in = if shape.len() >= 2 { shape[shape.len() - 1] } else { 1 };
434
435 let std = (2.0 / fan_in as f32).sqrt();
436 let normal = Normal::new(0.0, std).map_err(|e| {
437 anyhow::anyhow!("Failed to create normal distribution: {}", e)
438 })?;
439
440 for value in data.iter_mut() {
441 *value = normal.sample(&mut rng);
442 }
443 Ok(())
444 },
445 _ => Err(anyhow::anyhow!(
446 "Kaiming normal initialization only supports F32 tensors"
447 )),
448 }
449 },
450 InitializationStrategy::TruncatedNormal { std, bounds } => {
451 let mut rng = thread_rng();
453
454 match tensor {
455 Tensor::F32(data) => {
456 let normal = Normal::new(0.0, *std).map_err(|e| {
457 anyhow::anyhow!("Failed to create normal distribution: {}", e)
458 })?;
459
460 for value in data.iter_mut() {
461 loop {
462 let sample = normal.sample(&mut rng);
463 if sample.abs() <= *bounds {
464 *value = sample;
465 break;
466 }
467 }
469 }
470 Ok(())
471 },
472 _ => Err(anyhow::anyhow!(
473 "Truncated normal initialization only supports F32 tensors"
474 )),
475 }
476 },
477 InitializationStrategy::Custom(name) => {
478 match name.as_str() {
480 "zero" => {
481 match tensor {
483 Tensor::F32(data) => {
484 data.fill(0.0);
485 Ok(())
486 },
487 _ => Err(anyhow::anyhow!(
488 "Custom zero initialization only supports F32 tensors"
489 )),
490 }
491 },
492 "ones" => {
493 match tensor {
495 Tensor::F32(data) => {
496 data.fill(1.0);
497 Ok(())
498 },
499 _ => Err(anyhow::anyhow!(
500 "Custom ones initialization only supports F32 tensors"
501 )),
502 }
503 },
504 "identity" => {
505 let shape = tensor.shape();
507 match tensor {
508 Tensor::F32(data) => {
509 if shape.len() == 2 && shape[0] == shape[1] {
510 data.fill(0.0);
512 let dim = shape[0];
513 for i in 0..dim {
514 data[i * dim + i] = 1.0;
515 }
516 Ok(())
517 } else {
518 Err(anyhow::anyhow!(
519 "Identity initialization requires square matrix"
520 ))
521 }
522 },
523 _ => Err(anyhow::anyhow!(
524 "Custom identity initialization only supports F32 tensors"
525 )),
526 }
527 },
528 _ => Err(anyhow::anyhow!(
529 "Unknown custom initialization strategy: {}",
530 name
531 )),
532 }
533 },
534 }
535 }
536
537 pub fn generation_config_for_task(task: &TaskType) -> GenerationConfig {
539 match task {
540 TaskType::TextGeneration => GenerationConfig {
541 max_new_tokens: 512,
542 temperature: 0.7,
543 top_p: 0.9,
544 repetition_penalty: 1.1,
545 ..GenerationConfig::default()
546 },
547 TaskType::CodeGeneration => GenerationConfig {
548 max_new_tokens: 1024,
549 temperature: 0.2,
550 top_p: 0.95,
551 repetition_penalty: 1.05,
552 ..GenerationConfig::default()
553 },
554 TaskType::Summarization => GenerationConfig {
555 max_new_tokens: 256,
556 temperature: 0.3,
557 top_p: 0.9,
558 repetition_penalty: 1.2,
559 ..GenerationConfig::default()
560 },
561 TaskType::QuestionAnswering => GenerationConfig {
562 max_new_tokens: 128,
563 temperature: 0.1,
564 top_p: 0.95,
565 repetition_penalty: 1.0,
566 early_stopping: true,
567 ..GenerationConfig::default()
568 },
569 _ => GenerationConfig::default(),
570 }
571 }
572
573 pub fn validate_config(config: &dyn DynConfig) -> Result<Vec<String>> {
575 let warnings = Vec::new();
576
577 config.validate()?;
579
580 Ok(warnings)
584 }
585
586 pub fn estimate_memory_requirements(
588 vocab_size: usize,
589 hidden_size: usize,
590 num_layers: usize,
591 context_length: usize,
592 ) -> MemoryEstimate {
593 let embedding_params = vocab_size * hidden_size;
595 let layer_params = num_layers
596 * (
597 hidden_size * hidden_size * 4 +
599 hidden_size * hidden_size * 4 * 2 +
601 hidden_size * 2
603 );
604 let output_head_params = vocab_size * hidden_size;
605
606 let total_params = embedding_params + layer_params + output_head_params;
607
608 let model_memory_gb = (total_params * 4) as f32 / 1_000_000_000.0; let activation_memory_gb =
611 (context_length * hidden_size * num_layers * 4) as f32 / 1_000_000_000.0;
612
613 MemoryEstimate {
614 total_parameters: total_params,
615 model_memory_gb,
616 activation_memory_gb,
617 total_memory_gb: model_memory_gb + activation_memory_gb,
618 inference_memory_gb: model_memory_gb + activation_memory_gb * 0.5,
619 }
620 }
621}
622
/// Output of `ModelUtils::estimate_memory_requirements`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEstimate {
    /// Estimated total parameter count.
    pub total_parameters: usize,
    /// Memory for the FP32 weights, in decimal GB.
    pub model_memory_gb: f32,
    /// Memory for activations at full context length, in decimal GB.
    pub activation_memory_gb: f32,
    /// Weights plus activations.
    pub total_memory_gb: f32,
    /// Weights plus roughly half the activation memory (inference estimate).
    pub inference_memory_gb: f32,
}
632
/// Decoding strategies for text generation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GenerationStrategy {
    /// Always pick the highest-probability token.
    Greedy,
    /// Temperature-scaled sampling over the full vocabulary.
    Sampling { temperature: f32 },
    /// Sample from only the `k` most likely tokens.
    TopK { k: usize, temperature: f32 },
    /// Nucleus sampling over the smallest set with cumulative mass `p`.
    TopP { p: f32, temperature: f32 },
    /// Standard beam search with `num_beams` beams.
    BeamSearch { num_beams: usize },
    /// Beam search with a penalty that encourages diverse beam groups.
    DiverseBeamSearch {
        num_beams: usize,
        diversity_penalty: f32,
    },
    /// Contrastive search balancing confidence against degeneration.
    ContrastiveSearch { penalty_alpha: f32, top_k: usize },
}
654
/// Trait interfaces for the building blocks of a transformer model.
pub mod components {
    use super::*;

    /// One layer in a transformer stack.
    pub trait TransformerLayer {
        /// Runs the layer on `input`, optionally applying `attention_mask`.
        fn forward(&self, input: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor>;

        /// Configuration backing this layer.
        fn config(&self) -> &dyn DynConfig;
    }

    /// Attention computation over query/key/value tensors.
    pub trait AttentionMechanism {
        /// Computes the attention output from `query`, `key`, and `value`.
        fn compute_attention(&self, query: &Tensor, key: &Tensor, value: &Tensor)
            -> Result<Tensor>;

        /// Applies `mask` to precomputed attention weights.
        fn apply_mask(&self, attention_weights: &Tensor, mask: &Tensor) -> Result<Tensor>;
    }

    /// Position-wise feed-forward sublayer.
    pub trait FeedForwardNetwork {
        /// Runs the feed-forward transform on `input`.
        fn forward(&self, input: &Tensor) -> Result<Tensor>;

        /// Input/output width of the network.
        fn hidden_size(&self) -> usize;
        /// Width of the intermediate (expansion) projection.
        fn intermediate_size(&self) -> usize;
    }

    /// Token-id to embedding-vector lookup.
    pub trait EmbeddingLayer {
        /// Maps `input_ids` to embedding vectors.
        fn forward(&self, input_ids: &Tensor) -> Result<Tensor>;

        /// Number of entries in the embedding table.
        fn vocab_size(&self) -> usize;

        /// Dimensionality of each embedding vector.
        fn embedding_dim(&self) -> usize;
    }
}
700
/// Object-safe counterpart of `ModelFamily`, suitable for storage in the
/// registry as `Box<dyn ModelFamilyImpl>`. See `ModelFamilyWrapper` for the
/// adapter from the static trait.
pub trait ModelFamilyImpl: Send + Sync {
    /// Builds a configuration for the requested size and optional variant.
    fn create_config(&self, size: &str, variant: Option<&str>) -> Result<Box<dyn DynConfig>>;

    /// Canonical family name; used as the registry key.
    fn family_name(&self) -> &'static str;

    /// Model sizes this family can instantiate.
    fn available_sizes(&self) -> Vec<&'static str>;

    /// Named variants of the family, if any.
    fn available_variants(&self) -> Vec<&'static str>;

    /// Typical use cases the family is suited for.
    fn use_cases(&self) -> Vec<&'static str>;

    /// Descriptive metadata about the family.
    fn metadata(&self) -> ModelFamilyMetadata;
}
721
722pub struct ModelFamilyWrapper<T: ModelFamily>(std::marker::PhantomData<T>);
724
725impl<T: ModelFamily> Default for ModelFamilyWrapper<T> {
726 fn default() -> Self {
727 Self::new()
728 }
729}
730
731impl<T: ModelFamily> ModelFamilyWrapper<T> {
732 pub fn new() -> Self {
733 Self(std::marker::PhantomData)
734 }
735}
736
737impl<T: ModelFamily> ModelFamilyImpl for ModelFamilyWrapper<T> {
738 fn create_config(&self, size: &str, variant: Option<&str>) -> Result<Box<dyn DynConfig>> {
739 T::create_config(size, variant)
740 }
741
742 fn family_name(&self) -> &'static str {
743 T::family_name()
744 }
745
746 fn available_sizes(&self) -> Vec<&'static str> {
747 T::available_sizes()
748 }
749
750 fn available_variants(&self) -> Vec<&'static str> {
751 T::available_variants()
752 }
753
754 fn use_cases(&self) -> Vec<&'static str> {
755 T::use_cases()
756 }
757
758 fn metadata(&self) -> ModelFamilyMetadata {
759 T::metadata()
760 }
761}
762
/// Registry mapping family names to dynamically-dispatched family factories.
pub struct ModelRegistry {
    // Keyed by the value of `ModelFamilyImpl::family_name()`.
    families: HashMap<String, Box<dyn ModelFamilyImpl>>,
}
767
768impl Default for ModelRegistry {
769 fn default() -> Self {
770 Self::new()
771 }
772}
773
774impl ModelRegistry {
775 pub fn new() -> Self {
776 Self {
777 families: HashMap::new(),
778 }
779 }
780
781 pub fn register_family<F: ModelFamilyImpl + 'static>(&mut self, family: F) {
782 let name = family.family_name().to_string();
783 self.families.insert(name, Box::new(family));
784 }
785
786 pub fn register_model_family<F: ModelFamily + 'static>(&mut self) {
787 let wrapper = ModelFamilyWrapper::<F>::new();
788 self.register_family(wrapper);
789 }
790
791 pub fn get_family(&self, name: &str) -> Option<&dyn ModelFamilyImpl> {
792 self.families.get(name).map(|f| f.as_ref())
793 }
794
795 pub fn list_families(&self) -> Vec<&str> {
796 self.families.keys().map(|s| s.as_str()).collect()
797 }
798
799 pub fn create_model(
800 &self,
801 family: &str,
802 size: &str,
803 variant: Option<&str>,
804 ) -> Result<Box<dyn DynConfig>> {
805 let family_impl = self
806 .get_family(family)
807 .ok_or_else(|| anyhow::anyhow!("Unknown model family: {}", family))?;
808
809 family_impl.create_config(size, variant)
810 }
811}
812
813static MODEL_REGISTRY: OnceLock<Mutex<ModelRegistry>> = OnceLock::new();
815
816pub fn get_global_registry() -> &'static Mutex<ModelRegistry> {
817 MODEL_REGISTRY.get_or_init(|| Mutex::new(ModelRegistry::new()))
818}
819
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generation_config_default() {
        let defaults = GenerationConfig::default();
        assert_eq!(defaults.max_new_tokens, 100);
        assert_eq!(defaults.temperature, 1.0);
        assert!(defaults.do_sample);
    }

    #[test]
    fn test_task_specific_generation_config() {
        // Code generation should be long-form and low-temperature.
        let code = ModelUtils::generation_config_for_task(&TaskType::CodeGeneration);
        assert_eq!(code.temperature, 0.2);
        assert_eq!(code.max_new_tokens, 1024);

        // QA should be near-deterministic and stop early.
        let qa = ModelUtils::generation_config_for_task(&TaskType::QuestionAnswering);
        assert_eq!(qa.temperature, 0.1);
        assert!(qa.early_stopping);
    }

    #[test]
    fn test_memory_estimation() {
        let est = ModelUtils::estimate_memory_requirements(32000, 4096, 32, 2048);
        assert!(est.total_parameters > 0);
        assert!(est.model_memory_gb > 0.0);
        assert!(est.total_memory_gb >= est.model_memory_gb);
    }

    #[test]
    fn test_model_registry() {
        // A fresh registry starts with no families.
        assert!(ModelRegistry::new().list_families().is_empty());
    }

    #[test]
    fn test_initialization_strategy() {
        let strategy = InitializationStrategy::Normal { std: 0.02 };
        assert!(matches!(
            strategy,
            InitializationStrategy::Normal { std } if std == 0.02
        ));
    }

    #[test]
    fn test_architecture_types() {
        assert_ne!(ArchitectureType::EncoderOnly, ArchitectureType::DecoderOnly);
    }
}