// tensorlogic_trustformers/utils.rs

1//! Utility functions for transformer models.
2//!
3//! This module provides helper functions for common transformer operations:
4//! - Parameter counting
5//! - Configuration validation
6//! - Dimension calculations
7//! - Model statistics
8
9use crate::{
10    AttentionConfig, DecoderLayerConfig, DecoderStackConfig, EncoderLayerConfig,
11    EncoderStackConfig, FeedForwardConfig, LayerNormConfig,
12};
13
/// Statistics about a transformer model
#[derive(Clone, Debug, PartialEq)]
pub struct ModelStats {
    /// Total number of parameters
    pub total_params: usize,
    /// Number of trainable parameters (here always equal to `total_params`;
    /// kept separate so frozen-parameter models can report a smaller value)
    pub trainable_params: usize,
    /// Number of layers in the stack
    pub num_layers: usize,
    /// Model (embedding) dimension
    pub d_model: usize,
    /// Memory footprint estimate in bytes (computed as 4 bytes per f32 parameter)
    pub memory_estimate: usize,
}
28
29impl ModelStats {
30    /// Format as human-readable string
31    pub fn summary(&self) -> String {
32        format!(
33            "ModelStats:\n  Total params: {}\n  Trainable: {}\n  Layers: {}\n  d_model: {}\n  Memory: {} MB",
34            Self::format_number(self.total_params),
35            Self::format_number(self.trainable_params),
36            self.num_layers,
37            self.d_model,
38            self.memory_estimate / (1024 * 1024)
39        )
40    }
41
42    fn format_number(n: usize) -> String {
43        if n >= 1_000_000_000 {
44            format!("{:.2}B", n as f64 / 1_000_000_000.0)
45        } else if n >= 1_000_000 {
46            format!("{:.2}M", n as f64 / 1_000_000.0)
47        } else if n >= 1_000 {
48            format!("{:.2}K", n as f64 / 1_000.0)
49        } else {
50            n.to_string()
51        }
52    }
53}
54
55/// Count parameters in attention layer
56pub fn count_attention_params(config: &AttentionConfig) -> usize {
57    let d_model = config.d_model;
58
59    // Q, K, V projection matrices: 3 * (d_model * d_model)
60    let qkv_params = 3 * d_model * d_model;
61
62    // Output projection: d_model * d_model
63    let out_params = d_model * d_model;
64
65    // Biases for Q, K, V, and output (optional, typically included)
66    let bias_params = 4 * d_model;
67
68    qkv_params + out_params + bias_params
69}
70
71/// Count parameters in feed-forward network
72pub fn count_ffn_params(config: &FeedForwardConfig) -> usize {
73    let d_model = config.d_model;
74    let d_ff = config.d_ff;
75
76    // First layer: d_model * d_ff + d_ff (weights + bias)
77    let layer1_params = d_model * d_ff + d_ff;
78
79    // Second layer: d_ff * d_model + d_model (weights + bias)
80    let layer2_params = d_ff * d_model + d_model;
81
82    layer1_params + layer2_params
83}
84
85/// Count parameters in layer normalization
86pub fn count_layernorm_params(config: &LayerNormConfig) -> usize {
87    if config.elementwise_affine {
88        // Gamma (scale) and beta (shift)
89        2 * config.normalized_shape
90    } else {
91        0
92    }
93}
94
95/// Count parameters in encoder layer
96pub fn count_encoder_layer_params(config: &EncoderLayerConfig) -> usize {
97    let attention_params = count_attention_params(&config.attention);
98    let ffn_params = count_ffn_params(&config.feed_forward);
99    let ln1_params = count_layernorm_params(&config.layer_norm);
100    let ln2_params = count_layernorm_params(&config.layer_norm);
101
102    attention_params + ffn_params + ln1_params + ln2_params
103}
104
105/// Count parameters in decoder layer
106pub fn count_decoder_layer_params(config: &DecoderLayerConfig) -> usize {
107    let self_attn_params = count_attention_params(&config.self_attention);
108    let cross_attn_params = count_attention_params(&config.cross_attention);
109    let ffn_params = count_ffn_params(&config.feed_forward);
110    let ln1_params = count_layernorm_params(&config.layer_norm);
111    let ln2_params = count_layernorm_params(&config.layer_norm);
112    let ln3_params = count_layernorm_params(&config.layer_norm);
113
114    self_attn_params + cross_attn_params + ffn_params + ln1_params + ln2_params + ln3_params
115}
116
117/// Get statistics for encoder stack
118pub fn encoder_stack_stats(config: &EncoderStackConfig) -> ModelStats {
119    let layer_params = count_encoder_layer_params(&config.layer_config);
120    let total_layers_params = layer_params * config.num_layers;
121
122    // Position encoding parameters (if learned)
123    let pos_encoding_params = match config.position_encoding.encoding_type {
124        crate::position::PositionEncodingType::Learned => {
125            config.position_encoding.max_seq_len * config.position_encoding.d_model
126        }
127        _ => 0, // Sinusoidal and relative don't have learned parameters
128    };
129
130    // Final layer norm (if enabled)
131    let final_norm_params = if config.final_layer_norm {
132        count_layernorm_params(&LayerNormConfig::new(config.layer_config.attention.d_model))
133    } else {
134        0
135    };
136
137    let total_params = total_layers_params + pos_encoding_params + final_norm_params;
138
139    // Memory estimate: 4 bytes per parameter (float32)
140    let memory_estimate = total_params * 4;
141
142    ModelStats {
143        total_params,
144        trainable_params: total_params,
145        num_layers: config.num_layers,
146        d_model: config.layer_config.attention.d_model,
147        memory_estimate,
148    }
149}
150
151/// Get statistics for decoder stack
152pub fn decoder_stack_stats(config: &DecoderStackConfig) -> ModelStats {
153    let layer_params = count_decoder_layer_params(&config.layer_config);
154    let total_layers_params = layer_params * config.num_layers;
155
156    // Position encoding parameters
157    let pos_encoding_params = match config.position_encoding.encoding_type {
158        crate::position::PositionEncodingType::Learned => {
159            config.position_encoding.max_seq_len * config.position_encoding.d_model
160        }
161        _ => 0,
162    };
163
164    // Final layer norm
165    let final_norm_params = if config.final_layer_norm {
166        count_layernorm_params(&LayerNormConfig::new(
167            config.layer_config.self_attention.d_model,
168        ))
169    } else {
170        0
171    };
172
173    let total_params = total_layers_params + pos_encoding_params + final_norm_params;
174    let memory_estimate = total_params * 4;
175
176    ModelStats {
177        total_params,
178        trainable_params: total_params,
179        num_layers: config.num_layers,
180        d_model: config.layer_config.self_attention.d_model,
181        memory_estimate,
182    }
183}
184
/// Estimate FLOPs for one attention operation.
///
/// Uses the standard approximation `4 * batch * seq_len^2 * d_model`
/// (quadratic in sequence length).
pub fn attention_flops(batch_size: usize, seq_len: usize, d_model: usize) -> usize {
    4 * batch_size * seq_len.pow(2) * d_model
}
191
/// Estimate FLOPs for the feed-forward network.
///
/// Two matmuls (`d_model → d_ff` and `d_ff → d_model`); the original form
/// `2 * batch * seq_len * (d_model*d_ff + d_ff*d_model)` simplifies to
/// `4 * batch * seq_len * d_model * d_ff`.
pub fn ffn_flops(batch_size: usize, seq_len: usize, d_model: usize, d_ff: usize) -> usize {
    let tokens = batch_size * seq_len;
    4 * tokens * d_model * d_ff
}
198
199/// Calculate total FLOPs for transformer layer
200pub fn layer_flops(batch_size: usize, seq_len: usize, config: &EncoderLayerConfig) -> usize {
201    let attn = attention_flops(batch_size, seq_len, config.attention.d_model);
202    let ffn = ffn_flops(
203        batch_size,
204        seq_len,
205        config.feed_forward.d_model,
206        config.feed_forward.d_ff,
207    );
208    attn + ffn
209}
210
211/// Validate configuration compatibility
212pub fn validate_encoder_decoder_compatibility(
213    encoder: &EncoderStackConfig,
214    decoder: &DecoderStackConfig,
215) -> Result<(), String> {
216    // Check d_model compatibility
217    if encoder.layer_config.attention.d_model != decoder.layer_config.self_attention.d_model {
218        return Err(format!(
219            "d_model mismatch: encoder={}, decoder={}",
220            encoder.layer_config.attention.d_model, decoder.layer_config.self_attention.d_model
221        ));
222    }
223
224    // Check that decoder uses causal masking
225    if !decoder.layer_config.self_attention.causal {
226        return Err("Decoder self-attention must use causal masking".to_string());
227    }
228
229    Ok(())
230}
231
232/// Helper to create common transformer configurations
233pub mod presets {
234    use super::*;
235
236    /// GPT-2 Small configuration (117M parameters)
237    pub fn gpt2_small() -> EncoderStackConfig {
238        EncoderStackConfig::new(
239            12,   // layers
240            768,  // d_model
241            12,   // n_heads
242            3072, // d_ff (4 * d_model)
243            1024, // max_seq_len
244        )
245        .unwrap()
246        .with_dropout(0.1)
247    }
248
249    /// BERT Base configuration (110M parameters)
250    pub fn bert_base() -> EncoderStackConfig {
251        EncoderStackConfig::new(
252            12,   // layers
253            768,  // d_model
254            12,   // n_heads
255            3072, // d_ff
256            512,  // max_seq_len
257        )
258        .unwrap()
259        .with_dropout(0.1)
260    }
261
262    /// Transformer Base (from "Attention Is All You Need")
263    pub fn transformer_base() -> (EncoderStackConfig, DecoderStackConfig) {
264        let encoder = EncoderStackConfig::new(6, 512, 8, 2048, 512)
265            .unwrap()
266            .with_dropout(0.1);
267
268        let decoder = DecoderStackConfig::new(6, 512, 8, 2048, 512)
269            .unwrap()
270            .with_dropout(0.1);
271
272        (encoder, decoder)
273    }
274
275    /// Small model for testing (faster)
276    pub fn tiny() -> EncoderStackConfig {
277        EncoderStackConfig::new(2, 128, 4, 512, 128)
278            .unwrap()
279            .with_dropout(0.0)
280    }
281
282    /// BERT Large configuration (340M parameters)
283    pub fn bert_large() -> EncoderStackConfig {
284        EncoderStackConfig::new(
285            24,   // layers
286            1024, // d_model
287            16,   // n_heads
288            4096, // d_ff
289            512,  // max_seq_len
290        )
291        .unwrap()
292        .with_dropout(0.1)
293    }
294
295    /// GPT-2 Medium configuration (345M parameters)
296    pub fn gpt2_medium() -> EncoderStackConfig {
297        EncoderStackConfig::new(
298            24,   // layers
299            1024, // d_model
300            16,   // n_heads
301            4096, // d_ff
302            1024, // max_seq_len
303        )
304        .unwrap()
305        .with_dropout(0.1)
306    }
307
308    /// GPT-2 Large configuration (774M parameters)
309    pub fn gpt2_large() -> EncoderStackConfig {
310        EncoderStackConfig::new(
311            36,   // layers
312            1280, // d_model
313            20,   // n_heads
314            5120, // d_ff
315            1024, // max_seq_len
316        )
317        .unwrap()
318        .with_dropout(0.1)
319    }
320
321    /// GPT-2 XL configuration (1.5B parameters)
322    pub fn gpt2_xl() -> EncoderStackConfig {
323        EncoderStackConfig::new(
324            48,   // layers
325            1600, // d_model
326            25,   // n_heads
327            6400, // d_ff
328            1024, // max_seq_len
329        )
330        .unwrap()
331        .with_dropout(0.1)
332    }
333
334    /// GPT-3 Small configuration (~125M parameters)
335    pub fn gpt3_small() -> EncoderStackConfig {
336        EncoderStackConfig::new(
337            12,   // layers
338            768,  // d_model
339            12,   // n_heads
340            3072, // d_ff
341            2048, // max_seq_len
342        )
343        .unwrap()
344        .with_dropout(0.0)
345    }
346
347    /// GPT-3 Medium configuration (~350M parameters)
348    pub fn gpt3_medium() -> EncoderStackConfig {
349        EncoderStackConfig::new(
350            24,   // layers
351            1024, // d_model
352            16,   // n_heads
353            4096, // d_ff
354            2048, // max_seq_len
355        )
356        .unwrap()
357        .with_dropout(0.0)
358    }
359
360    /// GPT-3 Large configuration (~760M parameters)
361    pub fn gpt3_large() -> EncoderStackConfig {
362        EncoderStackConfig::new(
363            24,   // layers
364            1536, // d_model
365            16,   // n_heads
366            6144, // d_ff
367            2048, // max_seq_len
368        )
369        .unwrap()
370        .with_dropout(0.0)
371    }
372
373    /// GPT-3 XL configuration (~1.3B parameters)
374    pub fn gpt3_xl() -> EncoderStackConfig {
375        EncoderStackConfig::new(
376            24,   // layers
377            2048, // d_model
378            16,   // n_heads (d_model must be divisible by n_heads)
379            8192, // d_ff
380            2048, // max_seq_len
381        )
382        .unwrap()
383        .with_dropout(0.0)
384    }
385
386    /// GPT-3 2.7B configuration
387    pub fn gpt3_2_7b() -> EncoderStackConfig {
388        EncoderStackConfig::new(
389            32,    // layers
390            2560,  // d_model
391            32,    // n_heads
392            10240, // d_ff
393            2048,  // max_seq_len
394        )
395        .unwrap()
396        .with_dropout(0.0)
397    }
398
399    /// GPT-3 6.7B configuration
400    pub fn gpt3_6_7b() -> EncoderStackConfig {
401        EncoderStackConfig::new(
402            32,    // layers
403            4096,  // d_model
404            32,    // n_heads
405            16384, // d_ff
406            2048,  // max_seq_len
407        )
408        .unwrap()
409        .with_dropout(0.0)
410    }
411
412    /// GPT-3 13B configuration
413    pub fn gpt3_13b() -> EncoderStackConfig {
414        EncoderStackConfig::new(
415            40,    // layers
416            5120,  // d_model (must be divisible by n_heads)
417            40,    // n_heads
418            20480, // d_ff (4 * d_model)
419            2048,  // max_seq_len
420        )
421        .unwrap()
422        .with_dropout(0.0)
423    }
424
425    /// GPT-3 175B configuration (davinci)
426    pub fn gpt3_175b() -> EncoderStackConfig {
427        EncoderStackConfig::new(
428            96,    // layers
429            12288, // d_model
430            96,    // n_heads
431            49152, // d_ff
432            2048,  // max_seq_len
433        )
434        .unwrap()
435        .with_dropout(0.0)
436    }
437
438    /// LLaMA 7B configuration
439    /// Uses RoPE (implemented separately in position module)
440    pub fn llama_7b() -> EncoderStackConfig {
441        EncoderStackConfig::new(
442            32,    // layers
443            4096,  // d_model
444            32,    // n_heads
445            11008, // d_ff (uses SwiGLU, ~2.7x d_model)
446            2048,  // max_seq_len (can be extended with RoPE)
447        )
448        .unwrap()
449        .with_dropout(0.0)
450        .with_learned_position_encoding() // Would use RoPE in practice
451    }
452
453    /// LLaMA 13B configuration
454    pub fn llama_13b() -> EncoderStackConfig {
455        EncoderStackConfig::new(
456            40,    // layers
457            5120,  // d_model
458            40,    // n_heads
459            13824, // d_ff
460            2048,  // max_seq_len
461        )
462        .unwrap()
463        .with_dropout(0.0)
464        .with_learned_position_encoding()
465    }
466
467    /// LLaMA 30B configuration
468    pub fn llama_30b() -> EncoderStackConfig {
469        EncoderStackConfig::new(
470            60,    // layers
471            6656,  // d_model
472            52,    // n_heads
473            17920, // d_ff
474            2048,  // max_seq_len
475        )
476        .unwrap()
477        .with_dropout(0.0)
478        .with_learned_position_encoding()
479    }
480
481    /// LLaMA 65B configuration
482    pub fn llama_65b() -> EncoderStackConfig {
483        EncoderStackConfig::new(
484            80,    // layers
485            8192,  // d_model
486            64,    // n_heads
487            22016, // d_ff
488            2048,  // max_seq_len
489        )
490        .unwrap()
491        .with_dropout(0.0)
492        .with_learned_position_encoding()
493    }
494
495    /// BLOOM 560M configuration (uses ALiBi)
496    pub fn bloom_560m() -> EncoderStackConfig {
497        EncoderStackConfig::new(
498            24,   // layers
499            1024, // d_model
500            16,   // n_heads
501            4096, // d_ff
502            2048, // max_seq_len
503        )
504        .unwrap()
505        .with_dropout(0.0)
506        // Note: BLOOM uses ALiBi position encoding (implemented in position module)
507    }
508
509    /// BLOOM 3B configuration
510    pub fn bloom_3b() -> EncoderStackConfig {
511        EncoderStackConfig::new(
512            30,    // layers
513            2560,  // d_model
514            32,    // n_heads
515            10240, // d_ff
516            2048,  // max_seq_len
517        )
518        .unwrap()
519        .with_dropout(0.0)
520    }
521
522    /// BLOOM 7B configuration
523    pub fn bloom_7b() -> EncoderStackConfig {
524        EncoderStackConfig::new(
525            30,    // layers
526            4096,  // d_model
527            32,    // n_heads
528            16384, // d_ff
529            2048,  // max_seq_len
530        )
531        .unwrap()
532        .with_dropout(0.0)
533    }
534
535    /// T5 Small configuration (60M parameters)
536    pub fn t5_small() -> (EncoderStackConfig, DecoderStackConfig) {
537        let encoder = EncoderStackConfig::new(6, 512, 8, 2048, 512)
538            .unwrap()
539            .with_dropout(0.1);
540
541        let decoder = DecoderStackConfig::new(6, 512, 8, 2048, 512)
542            .unwrap()
543            .with_dropout(0.1);
544
545        (encoder, decoder)
546    }
547
548    /// T5 Base configuration (220M parameters)
549    pub fn t5_base() -> (EncoderStackConfig, DecoderStackConfig) {
550        let encoder = EncoderStackConfig::new(12, 768, 12, 3072, 512)
551            .unwrap()
552            .with_dropout(0.1);
553
554        let decoder = DecoderStackConfig::new(12, 768, 12, 3072, 512)
555            .unwrap()
556            .with_dropout(0.1);
557
558        (encoder, decoder)
559    }
560
561    /// T5 Large configuration (770M parameters)
562    pub fn t5_large() -> (EncoderStackConfig, DecoderStackConfig) {
563        let encoder = EncoderStackConfig::new(24, 1024, 16, 4096, 512)
564            .unwrap()
565            .with_dropout(0.1);
566
567        let decoder = DecoderStackConfig::new(24, 1024, 16, 4096, 512)
568            .unwrap()
569            .with_dropout(0.1);
570
571        (encoder, decoder)
572    }
573
574    /// T5 XL configuration (3B parameters)
575    pub fn t5_xl() -> (EncoderStackConfig, DecoderStackConfig) {
576        let encoder = EncoderStackConfig::new(24, 2048, 32, 8192, 512)
577            .unwrap()
578            .with_dropout(0.1);
579
580        let decoder = DecoderStackConfig::new(24, 2048, 32, 8192, 512)
581            .unwrap()
582            .with_dropout(0.1);
583
584        (encoder, decoder)
585    }
586
587    /// T5 XXL configuration (11B parameters)
588    pub fn t5_xxl() -> (EncoderStackConfig, DecoderStackConfig) {
589        let encoder = EncoderStackConfig::new(24, 4096, 64, 16384, 512)
590            .unwrap()
591            .with_dropout(0.1);
592
593        let decoder = DecoderStackConfig::new(24, 4096, 64, 16384, 512)
594            .unwrap()
595            .with_dropout(0.1);
596
597        (encoder, decoder)
598    }
599}
600
// Unit tests: parameter counts are pinned to hand-computed exact values so
// any change to the counting formulas is caught immediately.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_count_attention_params() {
        let config = AttentionConfig::new(512, 8).unwrap();
        let params = count_attention_params(&config);

        // QKV: 3 * 512 * 512 = 786,432
        // Output: 512 * 512 = 262,144
        // Biases: 4 * 512 = 2,048
        // Total: 1,050,624
        assert_eq!(params, 1_050_624);
    }

    #[test]
    fn test_count_ffn_params() {
        let config = FeedForwardConfig::new(512, 2048);
        let params = count_ffn_params(&config);

        // Layer 1: 512 * 2048 + 2048 = 1,050,624
        // Layer 2: 2048 * 512 + 512 = 1,049,088
        // Total: 2,099,712
        assert_eq!(params, 2_099_712);
    }

    #[test]
    fn test_count_layernorm_params() {
        let config = LayerNormConfig::new(512);
        let params = count_layernorm_params(&config);
        assert_eq!(params, 1024); // gamma + beta

        // With the affine transform disabled there are no learned parameters.
        let config_no_affine = LayerNormConfig::new(512).with_elementwise_affine(false);
        let params = count_layernorm_params(&config_no_affine);
        assert_eq!(params, 0);
    }

    #[test]
    fn test_encoder_layer_params() {
        let config = EncoderLayerConfig::new(512, 8, 2048).unwrap();
        let params = count_encoder_layer_params(&config);

        // Attention: 1,050,624
        // FFN: 2,099,712
        // LN1: 1,024
        // LN2: 1,024
        // Total: 3,152,384
        assert_eq!(params, 3_152_384);
    }

    #[test]
    fn test_encoder_stack_stats() {
        let config = EncoderStackConfig::new(6, 512, 8, 2048, 512).unwrap();
        let stats = encoder_stack_stats(&config);

        assert_eq!(stats.num_layers, 6);
        assert_eq!(stats.d_model, 512);
        assert!(stats.total_params > 0);
        // All parameters are reported as trainable.
        assert_eq!(stats.trainable_params, stats.total_params);
    }

    #[test]
    fn test_decoder_stack_stats() {
        let config = DecoderStackConfig::new(6, 512, 8, 2048, 512).unwrap();
        let stats = decoder_stack_stats(&config);

        assert_eq!(stats.num_layers, 6);
        assert_eq!(stats.d_model, 512);
        // Decoder has more params than encoder (cross-attention)
        assert!(stats.total_params > 0);
    }

    #[test]
    fn test_flops_calculations() {
        let batch = 32;
        let seq_len = 128;
        let d_model = 512;
        let d_ff = 2048;

        let attn_flops = attention_flops(batch, seq_len, d_model);
        assert!(attn_flops > 0);

        let ffn_flops = ffn_flops(batch, seq_len, d_model, d_ff);
        assert!(ffn_flops > 0);
    }

    #[test]
    fn test_validate_compatibility() {
        let encoder = EncoderStackConfig::new(6, 512, 8, 2048, 512).unwrap();
        let decoder = DecoderStackConfig::new(6, 512, 8, 2048, 512).unwrap();

        assert!(validate_encoder_decoder_compatibility(&encoder, &decoder).is_ok());

        // Mismatched d_model
        let encoder_mismatch = EncoderStackConfig::new(6, 768, 8, 2048, 512).unwrap();
        assert!(validate_encoder_decoder_compatibility(&encoder_mismatch, &decoder).is_err());
    }

    #[test]
    fn test_presets() {
        let gpt2 = presets::gpt2_small();
        assert_eq!(gpt2.num_layers, 12);
        assert_eq!(gpt2.layer_config.attention.d_model, 768);

        let bert = presets::bert_base();
        assert_eq!(bert.num_layers, 12);
        assert_eq!(bert.layer_config.attention.d_model, 768);

        let (encoder, decoder) = presets::transformer_base();
        assert_eq!(encoder.num_layers, 6);
        assert_eq!(decoder.num_layers, 6);
        assert!(validate_encoder_decoder_compatibility(&encoder, &decoder).is_ok());
    }

    #[test]
    fn test_model_stats_summary() {
        let config = EncoderStackConfig::new(6, 512, 8, 2048, 512).unwrap();
        let stats = encoder_stack_stats(&config);
        let summary = stats.summary();

        assert!(summary.contains("ModelStats"));
        assert!(summary.contains("Total params"));
        assert!(summary.contains("Layers: 6"));
    }

    #[test]
    fn test_format_number() {
        let stats = ModelStats {
            total_params: 117_000_000,
            trainable_params: 117_000_000,
            num_layers: 12,
            d_model: 768,
            memory_estimate: 468_000_000,
        };

        let summary = stats.summary();
        // 117,000,000 should be abbreviated with the M suffix.
        assert!(summary.contains("117.00M"));
    }

    #[test]
    fn test_bert_large_preset() {
        let config = presets::bert_large();
        assert_eq!(config.num_layers, 24);
        assert_eq!(config.layer_config.attention.d_model, 1024);
        assert_eq!(config.layer_config.attention.n_heads, 16);
    }

    #[test]
    fn test_gpt2_variants() {
        let small = presets::gpt2_small();
        let medium = presets::gpt2_medium();
        let large = presets::gpt2_large();
        let xl = presets::gpt2_xl();

        // Verify parameter counts increase
        let small_stats = encoder_stack_stats(&small);
        let medium_stats = encoder_stack_stats(&medium);
        let large_stats = encoder_stack_stats(&large);
        let xl_stats = encoder_stack_stats(&xl);

        assert!(medium_stats.total_params > small_stats.total_params);
        assert!(large_stats.total_params > medium_stats.total_params);
        assert!(xl_stats.total_params > large_stats.total_params);
    }

    #[test]
    fn test_gpt3_variants() {
        let small = presets::gpt3_small();
        let medium = presets::gpt3_medium();
        let large = presets::gpt3_large();
        let xl = presets::gpt3_xl();

        assert_eq!(small.num_layers, 12);
        assert_eq!(medium.num_layers, 24);
        assert_eq!(large.num_layers, 24);
        assert_eq!(xl.num_layers, 24);

        // d_model increases
        assert!(medium.layer_config.attention.d_model > small.layer_config.attention.d_model);
        assert!(large.layer_config.attention.d_model > medium.layer_config.attention.d_model);
        assert!(xl.layer_config.attention.d_model > large.layer_config.attention.d_model);
    }

    #[test]
    fn test_gpt3_large_models() {
        let m2_7b = presets::gpt3_2_7b();
        let m6_7b = presets::gpt3_6_7b();
        let m13b = presets::gpt3_13b();
        let m175b = presets::gpt3_175b();

        assert_eq!(m2_7b.num_layers, 32);
        assert_eq!(m6_7b.num_layers, 32);
        assert_eq!(m13b.num_layers, 40);
        assert_eq!(m175b.num_layers, 96);

        // Verify d_model increases
        assert!(m6_7b.layer_config.attention.d_model > m2_7b.layer_config.attention.d_model);
        assert!(m13b.layer_config.attention.d_model > m6_7b.layer_config.attention.d_model);
        assert!(m175b.layer_config.attention.d_model > m13b.layer_config.attention.d_model);
    }

    #[test]
    fn test_llama_variants() {
        let m7b = presets::llama_7b();
        let m13b = presets::llama_13b();
        let m30b = presets::llama_30b();
        let m65b = presets::llama_65b();

        // Verify layer counts increase
        assert!(m13b.num_layers > m7b.num_layers);
        assert!(m30b.num_layers > m13b.num_layers);
        assert!(m65b.num_layers > m30b.num_layers);

        // Verify d_model increases
        assert!(m13b.layer_config.attention.d_model > m7b.layer_config.attention.d_model);
        assert!(m30b.layer_config.attention.d_model > m13b.layer_config.attention.d_model);
        assert!(m65b.layer_config.attention.d_model > m30b.layer_config.attention.d_model);

        // LLaMA uses learned PE (would be RoPE in practice)
        assert!(matches!(
            m7b.position_encoding.encoding_type,
            crate::position::PositionEncodingType::Learned
        ));
    }

    #[test]
    fn test_bloom_variants() {
        let m560m = presets::bloom_560m();
        let m3b = presets::bloom_3b();
        let m7b = presets::bloom_7b();

        assert_eq!(m560m.num_layers, 24);
        assert_eq!(m3b.num_layers, 30);
        assert_eq!(m7b.num_layers, 30);

        // Verify d_model increases
        assert!(m3b.layer_config.attention.d_model > m560m.layer_config.attention.d_model);
        assert!(m7b.layer_config.attention.d_model > m3b.layer_config.attention.d_model);
    }

    #[test]
    fn test_t5_variants() {
        let small = presets::t5_small();
        let base = presets::t5_base();
        let large = presets::t5_large();
        let xl = presets::t5_xl();
        let xxl = presets::t5_xxl();

        // Verify encoder-decoder compatibility
        assert!(validate_encoder_decoder_compatibility(&small.0, &small.1).is_ok());
        assert!(validate_encoder_decoder_compatibility(&base.0, &base.1).is_ok());
        assert!(validate_encoder_decoder_compatibility(&large.0, &large.1).is_ok());
        assert!(validate_encoder_decoder_compatibility(&xl.0, &xl.1).is_ok());
        assert!(validate_encoder_decoder_compatibility(&xxl.0, &xxl.1).is_ok());

        // Verify parameter counts increase
        let small_stats = encoder_stack_stats(&small.0);
        let base_stats = encoder_stack_stats(&base.0);
        let large_stats = encoder_stack_stats(&large.0);

        assert!(base_stats.total_params > small_stats.total_params);
        assert!(large_stats.total_params > base_stats.total_params);
    }

    #[test]
    fn test_all_presets_validate() {
        // Ensure all preset configurations are valid
        assert!(presets::tiny().validate().is_ok());
        assert!(presets::gpt2_small().validate().is_ok());
        assert!(presets::bert_base().validate().is_ok());
        assert!(presets::bert_large().validate().is_ok());
        assert!(presets::gpt2_medium().validate().is_ok());
        assert!(presets::gpt2_large().validate().is_ok());
        assert!(presets::gpt2_xl().validate().is_ok());
        assert!(presets::gpt3_small().validate().is_ok());
        assert!(presets::gpt3_medium().validate().is_ok());
        assert!(presets::gpt3_large().validate().is_ok());
        assert!(presets::gpt3_xl().validate().is_ok());
        assert!(presets::llama_7b().validate().is_ok());
        assert!(presets::llama_13b().validate().is_ok());
        assert!(presets::bloom_560m().validate().is_ok());
        assert!(presets::bloom_3b().validate().is_ok());

        let (enc, dec) = presets::transformer_base();
        assert!(enc.validate().is_ok());
        assert!(dec.validate().is_ok());
    }
}