1use crate::{
10 AttentionConfig, DecoderLayerConfig, DecoderStackConfig, EncoderLayerConfig,
11 EncoderStackConfig, FeedForwardConfig, LayerNormConfig,
12};
13
/// Aggregate statistics for a configured transformer stack.
///
/// Produced by [`encoder_stack_stats`] / [`decoder_stack_stats`].
/// All counts are plain integers, so full structural equality (`Eq`) is
/// derived in addition to `PartialEq`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ModelStats {
    /// Total number of parameters in the stack.
    pub total_params: usize,
    /// Number of trainable parameters (the stats constructors in this file
    /// always set this equal to `total_params`).
    pub trainable_params: usize,
    /// Number of layers in the stack.
    pub num_layers: usize,
    /// Model (hidden) dimension.
    pub d_model: usize,
    /// Estimated memory footprint in bytes (4 bytes per parameter).
    pub memory_estimate: usize,
}
28
29impl ModelStats {
30 pub fn summary(&self) -> String {
32 format!(
33 "ModelStats:\n Total params: {}\n Trainable: {}\n Layers: {}\n d_model: {}\n Memory: {} MB",
34 Self::format_number(self.total_params),
35 Self::format_number(self.trainable_params),
36 self.num_layers,
37 self.d_model,
38 self.memory_estimate / (1024 * 1024)
39 )
40 }
41
42 fn format_number(n: usize) -> String {
43 if n >= 1_000_000_000 {
44 format!("{:.2}B", n as f64 / 1_000_000_000.0)
45 } else if n >= 1_000_000 {
46 format!("{:.2}M", n as f64 / 1_000_000.0)
47 } else if n >= 1_000 {
48 format!("{:.2}K", n as f64 / 1_000.0)
49 } else {
50 n.to_string()
51 }
52 }
53}
54
55pub fn count_attention_params(config: &AttentionConfig) -> usize {
57 let d_model = config.d_model;
58
59 let qkv_params = 3 * d_model * d_model;
61
62 let out_params = d_model * d_model;
64
65 let bias_params = 4 * d_model;
67
68 qkv_params + out_params + bias_params
69}
70
71pub fn count_ffn_params(config: &FeedForwardConfig) -> usize {
73 let d_model = config.d_model;
74 let d_ff = config.d_ff;
75
76 let layer1_params = d_model * d_ff + d_ff;
78
79 let layer2_params = d_ff * d_model + d_model;
81
82 layer1_params + layer2_params
83}
84
85pub fn count_layernorm_params(config: &LayerNormConfig) -> usize {
87 if config.elementwise_affine {
88 2 * config.normalized_shape
90 } else {
91 0
92 }
93}
94
95pub fn count_encoder_layer_params(config: &EncoderLayerConfig) -> usize {
97 let attention_params = count_attention_params(&config.attention);
98 let ffn_params = count_ffn_params(&config.feed_forward);
99 let ln1_params = count_layernorm_params(&config.layer_norm);
100 let ln2_params = count_layernorm_params(&config.layer_norm);
101
102 attention_params + ffn_params + ln1_params + ln2_params
103}
104
105pub fn count_decoder_layer_params(config: &DecoderLayerConfig) -> usize {
107 let self_attn_params = count_attention_params(&config.self_attention);
108 let cross_attn_params = count_attention_params(&config.cross_attention);
109 let ffn_params = count_ffn_params(&config.feed_forward);
110 let ln1_params = count_layernorm_params(&config.layer_norm);
111 let ln2_params = count_layernorm_params(&config.layer_norm);
112 let ln3_params = count_layernorm_params(&config.layer_norm);
113
114 self_attn_params + cross_attn_params + ffn_params + ln1_params + ln2_params + ln3_params
115}
116
117pub fn encoder_stack_stats(config: &EncoderStackConfig) -> ModelStats {
119 let layer_params = count_encoder_layer_params(&config.layer_config);
120 let total_layers_params = layer_params * config.num_layers;
121
122 let pos_encoding_params = match config.position_encoding.encoding_type {
124 crate::position::PositionEncodingType::Learned => {
125 config.position_encoding.max_seq_len * config.position_encoding.d_model
126 }
127 _ => 0, };
129
130 let final_norm_params = if config.final_layer_norm {
132 count_layernorm_params(&LayerNormConfig::new(config.layer_config.attention.d_model))
133 } else {
134 0
135 };
136
137 let total_params = total_layers_params + pos_encoding_params + final_norm_params;
138
139 let memory_estimate = total_params * 4;
141
142 ModelStats {
143 total_params,
144 trainable_params: total_params,
145 num_layers: config.num_layers,
146 d_model: config.layer_config.attention.d_model,
147 memory_estimate,
148 }
149}
150
151pub fn decoder_stack_stats(config: &DecoderStackConfig) -> ModelStats {
153 let layer_params = count_decoder_layer_params(&config.layer_config);
154 let total_layers_params = layer_params * config.num_layers;
155
156 let pos_encoding_params = match config.position_encoding.encoding_type {
158 crate::position::PositionEncodingType::Learned => {
159 config.position_encoding.max_seq_len * config.position_encoding.d_model
160 }
161 _ => 0,
162 };
163
164 let final_norm_params = if config.final_layer_norm {
166 count_layernorm_params(&LayerNormConfig::new(
167 config.layer_config.self_attention.d_model,
168 ))
169 } else {
170 0
171 };
172
173 let total_params = total_layers_params + pos_encoding_params + final_norm_params;
174 let memory_estimate = total_params * 4;
175
176 ModelStats {
177 total_params,
178 trainable_params: total_params,
179 num_layers: config.num_layers,
180 d_model: config.layer_config.self_attention.d_model,
181 memory_estimate,
182 }
183}
184
/// Rough FLOP estimate for the attention operation of one layer.
///
/// Scales quadratically with `seq_len`; the projection matmuls are not
/// included in this estimate.
pub fn attention_flops(batch_size: usize, seq_len: usize, d_model: usize) -> usize {
    let per_example = 4 * seq_len * seq_len * d_model;
    batch_size * per_example
}
191
/// Rough FLOP estimate for the two feed-forward matmuls of one layer
/// (a multiply-add is counted as 2 FLOPs).
pub fn ffn_flops(batch_size: usize, seq_len: usize, d_model: usize, d_ff: usize) -> usize {
    // 2 FLOPs per multiply-add, over two d_model×d_ff matmuls per token:
    // algebraically identical to 2 * b * s * (d_model*d_ff + d_ff*d_model).
    4 * batch_size * seq_len * d_model * d_ff
}
198
199pub fn layer_flops(batch_size: usize, seq_len: usize, config: &EncoderLayerConfig) -> usize {
201 let attn = attention_flops(batch_size, seq_len, config.attention.d_model);
202 let ffn = ffn_flops(
203 batch_size,
204 seq_len,
205 config.feed_forward.d_model,
206 config.feed_forward.d_ff,
207 );
208 attn + ffn
209}
210
211pub fn validate_encoder_decoder_compatibility(
213 encoder: &EncoderStackConfig,
214 decoder: &DecoderStackConfig,
215) -> Result<(), String> {
216 if encoder.layer_config.attention.d_model != decoder.layer_config.self_attention.d_model {
218 return Err(format!(
219 "d_model mismatch: encoder={}, decoder={}",
220 encoder.layer_config.attention.d_model, decoder.layer_config.self_attention.d_model
221 ));
222 }
223
224 if !decoder.layer_config.self_attention.causal {
226 return Err("Decoder self-attention must use causal masking".to_string());
227 }
228
229 Ok(())
230}
231
/// Ready-made stack configurations mirroring well-known published model
/// families (GPT-2/3, BERT, LLaMA, BLOOM, T5).
///
/// The positional arguments to `EncoderStackConfig::new` /
/// `DecoderStackConfig::new` appear to be
/// `(num_layers, d_model, n_heads, d_ff, max_seq_len)` — confirm against
/// the config constructors.
pub mod presets {
    use super::*;

    /// GPT-2 small: 12 layers, d_model 768, 12 heads, 1024-token context.
    pub fn gpt2_small() -> EncoderStackConfig {
        EncoderStackConfig::new(12, 768, 12, 3072, 1024)
            .unwrap()
            .with_dropout(0.1)
    }

    /// BERT base: 12 layers, d_model 768, 12 heads, 512-token context.
    pub fn bert_base() -> EncoderStackConfig {
        EncoderStackConfig::new(12, 768, 12, 3072, 512)
            .unwrap()
            .with_dropout(0.1)
    }

    /// The original "Attention Is All You Need" base encoder–decoder pair:
    /// 6 layers, d_model 512, 8 heads on each side.
    pub fn transformer_base() -> (EncoderStackConfig, DecoderStackConfig) {
        let encoder = EncoderStackConfig::new(6, 512, 8, 2048, 512)
            .unwrap()
            .with_dropout(0.1);

        let decoder = DecoderStackConfig::new(6, 512, 8, 2048, 512)
            .unwrap()
            .with_dropout(0.1);

        (encoder, decoder)
    }

    /// Minimal configuration for quick tests: 2 layers, d_model 128,
    /// no dropout.
    pub fn tiny() -> EncoderStackConfig {
        EncoderStackConfig::new(2, 128, 4, 512, 128)
            .unwrap()
            .with_dropout(0.0)
    }

    /// BERT large: 24 layers, d_model 1024, 16 heads, 512-token context.
    pub fn bert_large() -> EncoderStackConfig {
        EncoderStackConfig::new(24, 1024, 16, 4096, 512)
            .unwrap()
            .with_dropout(0.1)
    }

    /// GPT-2 medium: 24 layers, d_model 1024, 16 heads.
    pub fn gpt2_medium() -> EncoderStackConfig {
        EncoderStackConfig::new(24, 1024, 16, 4096, 1024)
            .unwrap()
            .with_dropout(0.1)
    }

    /// GPT-2 large: 36 layers, d_model 1280, 20 heads.
    pub fn gpt2_large() -> EncoderStackConfig {
        EncoderStackConfig::new(36, 1280, 20, 5120, 1024)
            .unwrap()
            .with_dropout(0.1)
    }

    /// GPT-2 XL: 48 layers, d_model 1600, 25 heads.
    pub fn gpt2_xl() -> EncoderStackConfig {
        EncoderStackConfig::new(48, 1600, 25, 6400, 1024)
            .unwrap()
            .with_dropout(0.1)
    }

    /// GPT-3 small: 12 layers, d_model 768, 2048-token context, no dropout.
    pub fn gpt3_small() -> EncoderStackConfig {
        EncoderStackConfig::new(12, 768, 12, 3072, 2048)
            .unwrap()
            .with_dropout(0.0)
    }

    /// GPT-3 medium: 24 layers, d_model 1024.
    pub fn gpt3_medium() -> EncoderStackConfig {
        EncoderStackConfig::new(24, 1024, 16, 4096, 2048)
            .unwrap()
            .with_dropout(0.0)
    }

    /// GPT-3 large: 24 layers, d_model 1536.
    pub fn gpt3_large() -> EncoderStackConfig {
        EncoderStackConfig::new(24, 1536, 16, 6144, 2048)
            .unwrap()
            .with_dropout(0.0)
    }

    /// GPT-3 XL: 24 layers, d_model 2048.
    pub fn gpt3_xl() -> EncoderStackConfig {
        EncoderStackConfig::new(24, 2048, 16, 8192, 2048)
            .unwrap()
            .with_dropout(0.0)
    }

    /// GPT-3 2.7B: 32 layers, d_model 2560, 32 heads.
    pub fn gpt3_2_7b() -> EncoderStackConfig {
        EncoderStackConfig::new(32, 2560, 32, 10240, 2048)
            .unwrap()
            .with_dropout(0.0)
    }

    /// GPT-3 6.7B: 32 layers, d_model 4096, 32 heads.
    pub fn gpt3_6_7b() -> EncoderStackConfig {
        EncoderStackConfig::new(32, 4096, 32, 16384, 2048)
            .unwrap()
            .with_dropout(0.0)
    }

    /// GPT-3 13B: 40 layers, d_model 5120, 40 heads.
    pub fn gpt3_13b() -> EncoderStackConfig {
        EncoderStackConfig::new(40, 5120, 40, 20480, 2048)
            .unwrap()
            .with_dropout(0.0)
    }

    /// GPT-3 175B: 96 layers, d_model 12288, 96 heads.
    pub fn gpt3_175b() -> EncoderStackConfig {
        EncoderStackConfig::new(96, 12288, 96, 49152, 2048)
            .unwrap()
            .with_dropout(0.0)
    }

    /// LLaMA 7B: 32 layers, d_model 4096, learned position encoding.
    /// NOTE(review): LLaMA uses rotary embeddings in the original paper;
    /// `with_learned_position_encoding` may be an approximation here.
    pub fn llama_7b() -> EncoderStackConfig {
        EncoderStackConfig::new(32, 4096, 32, 11008, 2048)
            .unwrap()
            .with_dropout(0.0)
            .with_learned_position_encoding()
    }

    /// LLaMA 13B: 40 layers, d_model 5120, learned position encoding.
    pub fn llama_13b() -> EncoderStackConfig {
        EncoderStackConfig::new(40, 5120, 40, 13824, 2048)
            .unwrap()
            .with_dropout(0.0)
            .with_learned_position_encoding()
    }

    /// LLaMA 30B: 60 layers, d_model 6656, learned position encoding.
    pub fn llama_30b() -> EncoderStackConfig {
        EncoderStackConfig::new(60, 6656, 52, 17920, 2048)
            .unwrap()
            .with_dropout(0.0)
            .with_learned_position_encoding()
    }

    /// LLaMA 65B: 80 layers, d_model 8192, learned position encoding.
    pub fn llama_65b() -> EncoderStackConfig {
        EncoderStackConfig::new(80, 8192, 64, 22016, 2048)
            .unwrap()
            .with_dropout(0.0)
            .with_learned_position_encoding()
    }

    /// BLOOM 560M: 24 layers, d_model 1024.
    pub fn bloom_560m() -> EncoderStackConfig {
        EncoderStackConfig::new(24, 1024, 16, 4096, 2048)
            .unwrap()
            .with_dropout(0.0)
    }

    /// BLOOM 3B: 30 layers, d_model 2560.
    pub fn bloom_3b() -> EncoderStackConfig {
        EncoderStackConfig::new(30, 2560, 32, 10240, 2048)
            .unwrap()
            .with_dropout(0.0)
    }

    /// BLOOM 7B: 30 layers, d_model 4096.
    pub fn bloom_7b() -> EncoderStackConfig {
        EncoderStackConfig::new(30, 4096, 32, 16384, 2048)
            .unwrap()
            .with_dropout(0.0)
    }

    /// T5 small encoder–decoder pair: 6 layers, d_model 512 on each side.
    pub fn t5_small() -> (EncoderStackConfig, DecoderStackConfig) {
        let encoder = EncoderStackConfig::new(6, 512, 8, 2048, 512)
            .unwrap()
            .with_dropout(0.1);

        let decoder = DecoderStackConfig::new(6, 512, 8, 2048, 512)
            .unwrap()
            .with_dropout(0.1);

        (encoder, decoder)
    }

    /// T5 base encoder–decoder pair: 12 layers, d_model 768 on each side.
    pub fn t5_base() -> (EncoderStackConfig, DecoderStackConfig) {
        let encoder = EncoderStackConfig::new(12, 768, 12, 3072, 512)
            .unwrap()
            .with_dropout(0.1);

        let decoder = DecoderStackConfig::new(12, 768, 12, 3072, 512)
            .unwrap()
            .with_dropout(0.1);

        (encoder, decoder)
    }

    /// T5 large encoder–decoder pair: 24 layers, d_model 1024 on each side.
    pub fn t5_large() -> (EncoderStackConfig, DecoderStackConfig) {
        let encoder = EncoderStackConfig::new(24, 1024, 16, 4096, 512)
            .unwrap()
            .with_dropout(0.1);

        let decoder = DecoderStackConfig::new(24, 1024, 16, 4096, 512)
            .unwrap()
            .with_dropout(0.1);

        (encoder, decoder)
    }

    /// T5 XL encoder–decoder pair: 24 layers, d_model 2048 on each side.
    pub fn t5_xl() -> (EncoderStackConfig, DecoderStackConfig) {
        let encoder = EncoderStackConfig::new(24, 2048, 32, 8192, 512)
            .unwrap()
            .with_dropout(0.1);

        let decoder = DecoderStackConfig::new(24, 2048, 32, 8192, 512)
            .unwrap()
            .with_dropout(0.1);

        (encoder, decoder)
    }

    /// T5 XXL encoder–decoder pair: 24 layers, d_model 4096 on each side.
    pub fn t5_xxl() -> (EncoderStackConfig, DecoderStackConfig) {
        let encoder = EncoderStackConfig::new(24, 4096, 64, 16384, 512)
            .unwrap()
            .with_dropout(0.1);

        let decoder = DecoderStackConfig::new(24, 4096, 64, 16384, 512)
            .unwrap()
            .with_dropout(0.1);

        (encoder, decoder)
    }
}
600
// Unit tests for the parameter-counting, FLOP-estimation, validation and
// preset helpers defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_count_attention_params() {
        let config = AttentionConfig::new(512, 8).unwrap();
        let params = count_attention_params(&config);

        // 4 * 512 * 512 weights + 4 * 512 biases.
        assert_eq!(params, 1_050_624);
    }

    #[test]
    fn test_count_ffn_params() {
        let config = FeedForwardConfig::new(512, 2048);
        let params = count_ffn_params(&config);

        // (512*2048 + 2048) + (2048*512 + 512).
        assert_eq!(params, 2_099_712);
    }

    #[test]
    fn test_count_layernorm_params() {
        let config = LayerNormConfig::new(512);
        let params = count_layernorm_params(&config);
        // Gain + bias vectors when the affine transform is enabled.
        assert_eq!(params, 1024);

        // No parameters when the affine transform is disabled.
        let config_no_affine = LayerNormConfig::new(512).with_elementwise_affine(false);
        let params = count_layernorm_params(&config_no_affine);
        assert_eq!(params, 0);
    }

    #[test]
    fn test_encoder_layer_params() {
        let config = EncoderLayerConfig::new(512, 8, 2048).unwrap();
        let params = count_encoder_layer_params(&config);

        // Attention (1_050_624) + FFN (2_099_712) + two layer norms (2048).
        assert_eq!(params, 3_152_384);
    }

    #[test]
    fn test_encoder_stack_stats() {
        let config = EncoderStackConfig::new(6, 512, 8, 2048, 512).unwrap();
        let stats = encoder_stack_stats(&config);

        assert_eq!(stats.num_layers, 6);
        assert_eq!(stats.d_model, 512);
        assert!(stats.total_params > 0);
        // All parameters are reported as trainable.
        assert_eq!(stats.trainable_params, stats.total_params);
    }

    #[test]
    fn test_decoder_stack_stats() {
        let config = DecoderStackConfig::new(6, 512, 8, 2048, 512).unwrap();
        let stats = decoder_stack_stats(&config);

        assert_eq!(stats.num_layers, 6);
        assert_eq!(stats.d_model, 512);
        assert!(stats.total_params > 0);
    }

    #[test]
    fn test_flops_calculations() {
        let batch = 32;
        let seq_len = 128;
        let d_model = 512;
        let d_ff = 2048;

        // Smoke tests only: both estimates must be non-zero.
        let attn_flops = attention_flops(batch, seq_len, d_model);
        assert!(attn_flops > 0);

        let ffn_flops = ffn_flops(batch, seq_len, d_model, d_ff);
        assert!(ffn_flops > 0);
    }

    #[test]
    fn test_validate_compatibility() {
        let encoder = EncoderStackConfig::new(6, 512, 8, 2048, 512).unwrap();
        let decoder = DecoderStackConfig::new(6, 512, 8, 2048, 512).unwrap();

        // Matching d_model -> compatible.
        assert!(validate_encoder_decoder_compatibility(&encoder, &decoder).is_ok());

        // d_model 768 vs 512 must be rejected.
        let encoder_mismatch = EncoderStackConfig::new(6, 768, 8, 2048, 512).unwrap();
        assert!(validate_encoder_decoder_compatibility(&encoder_mismatch, &decoder).is_err());
    }

    #[test]
    fn test_presets() {
        let gpt2 = presets::gpt2_small();
        assert_eq!(gpt2.num_layers, 12);
        assert_eq!(gpt2.layer_config.attention.d_model, 768);

        let bert = presets::bert_base();
        assert_eq!(bert.num_layers, 12);
        assert_eq!(bert.layer_config.attention.d_model, 768);

        let (encoder, decoder) = presets::transformer_base();
        assert_eq!(encoder.num_layers, 6);
        assert_eq!(decoder.num_layers, 6);
        assert!(validate_encoder_decoder_compatibility(&encoder, &decoder).is_ok());
    }

    #[test]
    fn test_model_stats_summary() {
        let config = EncoderStackConfig::new(6, 512, 8, 2048, 512).unwrap();
        let stats = encoder_stack_stats(&config);
        let summary = stats.summary();

        assert!(summary.contains("ModelStats"));
        assert!(summary.contains("Total params"));
        assert!(summary.contains("Layers: 6"));
    }

    #[test]
    fn test_format_number() {
        let stats = ModelStats {
            total_params: 117_000_000,
            trainable_params: 117_000_000,
            num_layers: 12,
            d_model: 768,
            memory_estimate: 468_000_000,
        };

        // 117_000_000 should be abbreviated with an M suffix.
        let summary = stats.summary();
        assert!(summary.contains("117.00M"));
    }

    #[test]
    fn test_bert_large_preset() {
        let config = presets::bert_large();
        assert_eq!(config.num_layers, 24);
        assert_eq!(config.layer_config.attention.d_model, 1024);
        assert_eq!(config.layer_config.attention.n_heads, 16);
    }

    #[test]
    fn test_gpt2_variants() {
        let small = presets::gpt2_small();
        let medium = presets::gpt2_medium();
        let large = presets::gpt2_large();
        let xl = presets::gpt2_xl();

        let small_stats = encoder_stack_stats(&small);
        let medium_stats = encoder_stack_stats(&medium);
        let large_stats = encoder_stack_stats(&large);
        let xl_stats = encoder_stack_stats(&xl);

        // Parameter counts must grow strictly with model size.
        assert!(medium_stats.total_params > small_stats.total_params);
        assert!(large_stats.total_params > medium_stats.total_params);
        assert!(xl_stats.total_params > large_stats.total_params);
    }

    #[test]
    fn test_gpt3_variants() {
        let small = presets::gpt3_small();
        let medium = presets::gpt3_medium();
        let large = presets::gpt3_large();
        let xl = presets::gpt3_xl();

        assert_eq!(small.num_layers, 12);
        assert_eq!(medium.num_layers, 24);
        assert_eq!(large.num_layers, 24);
        assert_eq!(xl.num_layers, 24);

        // d_model must grow strictly with model size.
        assert!(medium.layer_config.attention.d_model > small.layer_config.attention.d_model);
        assert!(large.layer_config.attention.d_model > medium.layer_config.attention.d_model);
        assert!(xl.layer_config.attention.d_model > large.layer_config.attention.d_model);
    }

    #[test]
    fn test_gpt3_large_models() {
        let m2_7b = presets::gpt3_2_7b();
        let m6_7b = presets::gpt3_6_7b();
        let m13b = presets::gpt3_13b();
        let m175b = presets::gpt3_175b();

        assert_eq!(m2_7b.num_layers, 32);
        assert_eq!(m6_7b.num_layers, 32);
        assert_eq!(m13b.num_layers, 40);
        assert_eq!(m175b.num_layers, 96);

        // d_model must grow strictly with model size.
        assert!(m6_7b.layer_config.attention.d_model > m2_7b.layer_config.attention.d_model);
        assert!(m13b.layer_config.attention.d_model > m6_7b.layer_config.attention.d_model);
        assert!(m175b.layer_config.attention.d_model > m13b.layer_config.attention.d_model);
    }

    #[test]
    fn test_llama_variants() {
        let m7b = presets::llama_7b();
        let m13b = presets::llama_13b();
        let m30b = presets::llama_30b();
        let m65b = presets::llama_65b();

        // Both depth and width must grow strictly with model size.
        assert!(m13b.num_layers > m7b.num_layers);
        assert!(m30b.num_layers > m13b.num_layers);
        assert!(m65b.num_layers > m30b.num_layers);

        assert!(m13b.layer_config.attention.d_model > m7b.layer_config.attention.d_model);
        assert!(m30b.layer_config.attention.d_model > m13b.layer_config.attention.d_model);
        assert!(m65b.layer_config.attention.d_model > m30b.layer_config.attention.d_model);

        // LLaMA presets opt into learned position encodings.
        assert!(matches!(
            m7b.position_encoding.encoding_type,
            crate::position::PositionEncodingType::Learned
        ));
    }

    #[test]
    fn test_bloom_variants() {
        let m560m = presets::bloom_560m();
        let m3b = presets::bloom_3b();
        let m7b = presets::bloom_7b();

        assert_eq!(m560m.num_layers, 24);
        assert_eq!(m3b.num_layers, 30);
        assert_eq!(m7b.num_layers, 30);

        // d_model must grow strictly with model size.
        assert!(m3b.layer_config.attention.d_model > m560m.layer_config.attention.d_model);
        assert!(m7b.layer_config.attention.d_model > m3b.layer_config.attention.d_model);
    }

    #[test]
    fn test_t5_variants() {
        let small = presets::t5_small();
        let base = presets::t5_base();
        let large = presets::t5_large();
        let xl = presets::t5_xl();
        let xxl = presets::t5_xxl();

        // Every encoder-decoder pair must be internally compatible.
        assert!(validate_encoder_decoder_compatibility(&small.0, &small.1).is_ok());
        assert!(validate_encoder_decoder_compatibility(&base.0, &base.1).is_ok());
        assert!(validate_encoder_decoder_compatibility(&large.0, &large.1).is_ok());
        assert!(validate_encoder_decoder_compatibility(&xl.0, &xl.1).is_ok());
        assert!(validate_encoder_decoder_compatibility(&xxl.0, &xxl.1).is_ok());

        // Encoder parameter counts must grow strictly with model size.
        let small_stats = encoder_stack_stats(&small.0);
        let base_stats = encoder_stack_stats(&base.0);
        let large_stats = encoder_stack_stats(&large.0);

        assert!(base_stats.total_params > small_stats.total_params);
        assert!(large_stats.total_params > base_stats.total_params);
    }

    #[test]
    fn test_all_presets_validate() {
        // Every preset must produce a self-consistent configuration.
        assert!(presets::tiny().validate().is_ok());
        assert!(presets::gpt2_small().validate().is_ok());
        assert!(presets::bert_base().validate().is_ok());
        assert!(presets::bert_large().validate().is_ok());
        assert!(presets::gpt2_medium().validate().is_ok());
        assert!(presets::gpt2_large().validate().is_ok());
        assert!(presets::gpt2_xl().validate().is_ok());
        assert!(presets::gpt3_small().validate().is_ok());
        assert!(presets::gpt3_medium().validate().is_ok());
        assert!(presets::gpt3_large().validate().is_ok());
        assert!(presets::gpt3_xl().validate().is_ok());
        assert!(presets::llama_7b().validate().is_ok());
        assert!(presets::llama_13b().validate().is_ok());
        assert!(presets::bloom_560m().validate().is_ok());
        assert!(presets::bloom_3b().validate().is_ok());

        let (enc, dec) = presets::transformer_base();
        assert!(enc.validate().is_ok());
        assert!(dec.validate().is_ok());
    }
}