Skip to main content

scirs2_neural/models/architectures/
gpt.rs

1//! GPT implementation
2//!
3//! GPT (Generative Pre-trained Transformer) is a transformer-based language model
4//! designed for autoregressive language modeling. Unlike BERT which is bidirectional,
5//! GPT uses a unidirectional (left-to-right) transformer architecture.
6//! Reference: "Improving Language Understanding by Generative Pre-Training", Radford et al. (2018)
7//! <https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf>
8
9use crate::error::{NeuralError, Result};
10use crate::layers::{Dense, Dropout, Embedding, EmbeddingConfig, Layer, LayerNorm};
11use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
12use scirs2_core::numeric::{Float, NumAssign};
13use scirs2_core::random::SeedableRng;
14use scirs2_core::simd_ops::SimdUnifiedOps;
15use serde::{Deserialize, Serialize};
16use std::fmt::Debug;
17
18/// Configuration for a GPT model
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct GPTConfig {
21    /// Vocabulary size
22    pub vocab_size: usize,
23    /// Maximum sequence length
24    pub max_position_embeddings: usize,
25    /// Hidden size
26    pub hidden_size: usize,
27    /// Number of hidden layers
28    pub num_hidden_layers: usize,
29    /// Number of attention heads
30    pub num_attention_heads: usize,
31    /// Intermediate size in feed-forward networks
32    pub intermediate_size: usize,
33    /// Hidden activation function
34    pub hidden_act: String,
35    /// Hidden dropout probability
36    pub hidden_dropout_prob: f64,
37    /// Attention dropout probability
38    pub attention_probs_dropout_prob: f64,
39    /// Layer norm epsilon
40    pub layer_norm_eps: f64,
41    /// Initializer range
42    pub initializer_range: f64,
43}
44
45impl GPTConfig {
46    /// Create a GPT-2 Small configuration
47    pub fn gpt2_small() -> Self {
48        Self {
49            vocab_size: 50257,
50            max_position_embeddings: 1024,
51            hidden_size: 768,
52            num_hidden_layers: 12,
53            num_attention_heads: 12,
54            intermediate_size: 3072,
55            hidden_act: "gelu".to_string(),
56            hidden_dropout_prob: 0.1,
57            attention_probs_dropout_prob: 0.1,
58            layer_norm_eps: 1e-5,
59            initializer_range: 0.02,
60        }
61    }
62
63    /// Create a GPT-2 Medium configuration
64    pub fn gpt2_medium() -> Self {
65        Self {
66            vocab_size: 50257,
67            max_position_embeddings: 1024,
68            hidden_size: 1024,
69            num_hidden_layers: 24,
70            num_attention_heads: 16,
71            intermediate_size: 4096,
72            hidden_act: "gelu".to_string(),
73            hidden_dropout_prob: 0.1,
74            attention_probs_dropout_prob: 0.1,
75            layer_norm_eps: 1e-5,
76            initializer_range: 0.02,
77        }
78    }
79
80    /// Create a GPT-2 Large configuration
81    pub fn gpt2_large() -> Self {
82        Self {
83            vocab_size: 50257,
84            max_position_embeddings: 1024,
85            hidden_size: 1280,
86            num_hidden_layers: 36,
87            num_attention_heads: 20,
88            intermediate_size: 5120,
89            hidden_act: "gelu".to_string(),
90            hidden_dropout_prob: 0.1,
91            attention_probs_dropout_prob: 0.1,
92            layer_norm_eps: 1e-5,
93            initializer_range: 0.02,
94        }
95    }
96
97    /// Create a custom GPT configuration
98    pub fn custom(
99        vocab_size: usize,
100        hidden_size: usize,
101        num_hidden_layers: usize,
102        num_attention_heads: usize,
103    ) -> Self {
104        Self {
105            vocab_size,
106            max_position_embeddings: 1024,
107            hidden_size,
108            num_hidden_layers,
109            num_attention_heads,
110            intermediate_size: hidden_size * 4,
111            hidden_act: "gelu".to_string(),
112            hidden_dropout_prob: 0.1,
113            attention_probs_dropout_prob: 0.1,
114            layer_norm_eps: 1e-5,
115            initializer_range: 0.02,
116        }
117    }
118}
119
120/// GPT embedding combining token and position embeddings
121struct GPTEmbeddings<
122    F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
123> {
124    /// Token embeddings
125    token_embeddings: Embedding<F>,
126    /// Position embeddings
127    position_embeddings: Embedding<F>,
128    /// Dropout
129    dropout: Dropout<F>,
130}
131
132impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Clone
133    for GPTEmbeddings<F>
134{
135    fn clone(&self) -> Self {
136        Self {
137            token_embeddings: self.token_embeddings.clone(),
138            position_embeddings: self.position_embeddings.clone(),
139            dropout: self.dropout.clone(),
140        }
141    }
142}
143
144impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
145    GPTEmbeddings<F>
146{
147    /// Create GPT embeddings
148    pub fn new(config: &GPTConfig) -> Result<Self> {
149        let token_embeddings = Embedding::new(EmbeddingConfig {
150            num_embeddings: config.vocab_size,
151            embedding_dim: config.hidden_size,
152            padding_idx: None,
153            max_norm: None,
154            norm_type: 2.0,
155            scale_grad_by_freq: false,
156        })?;
157
158        let position_embeddings = Embedding::new(EmbeddingConfig {
159            num_embeddings: config.max_position_embeddings,
160            embedding_dim: config.hidden_size,
161            padding_idx: None,
162            max_norm: None,
163            norm_type: 2.0,
164            scale_grad_by_freq: false,
165        })?;
166
167        let mut rng3 = scirs2_core::random::rngs::SmallRng::from_seed([44; 32]);
168        let dropout = Dropout::new(config.hidden_dropout_prob, &mut rng3)?;
169
170        Ok(Self {
171            token_embeddings,
172            position_embeddings,
173            dropout,
174        })
175    }
176}
177
178impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Layer<F>
179    for GPTEmbeddings<F>
180{
181    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
182        let shape = input.shape();
183        if shape.len() != 2 {
184            return Err(NeuralError::InferenceError(format!(
185                "Expected input shape [batch_size, seq_len], got {:?}",
186                shape
187            )));
188        }
189
190        let batch_size = shape[0];
191        let seq_len = shape[1];
192
193        // Get token embeddings
194        let inputs_embeds = self.token_embeddings.forward(input)?;
195
196        // Create position IDs
197        let mut position_ids = Array::zeros(IxDyn(&[batch_size, seq_len]));
198        for b in 0..batch_size {
199            for s in 0..seq_len {
200                position_ids[[b, s]] = F::from(s).expect("Failed to convert to float");
201            }
202        }
203
204        // Get position embeddings
205        let position_embeds = self.position_embeddings.forward(&position_ids)?;
206
207        // Combine embeddings
208        let embeddings = &inputs_embeds + &position_embeds;
209
210        // Apply dropout
211        let embeddings = self.dropout.forward(&embeddings)?;
212
213        Ok(embeddings)
214    }
215
216    fn backward(
217        &self,
218        _input: &Array<F, IxDyn>,
219        grad_output: &Array<F, IxDyn>,
220    ) -> Result<Array<F, IxDyn>> {
221        Ok(grad_output.clone())
222    }
223
224    fn update(&mut self, learning_rate: F) -> Result<()> {
225        self.token_embeddings.update(learning_rate)?;
226        self.position_embeddings.update(learning_rate)?;
227        Ok(())
228    }
229
230    fn as_any(&self) -> &dyn std::any::Any {
231        self
232    }
233
234    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
235        self
236    }
237}
238
239/// GPT MLP (feed-forward network)
240struct GPTMlp<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> {
241    /// First dense layer
242    fc1: Dense<F>,
243    /// Second dense layer
244    fc2: Dense<F>,
245    /// Dropout
246    dropout: Dropout<F>,
247}
248
249impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Clone
250    for GPTMlp<F>
251{
252    fn clone(&self) -> Self {
253        Self {
254            fc1: self.fc1.clone(),
255            fc2: self.fc2.clone(),
256            dropout: self.dropout.clone(),
257        }
258    }
259}
260
261impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
262    GPTMlp<F>
263{
264    /// Create GPT MLP
265    pub fn new(config: &GPTConfig) -> Result<Self> {
266        let mut rng1 = scirs2_core::random::rngs::SmallRng::from_seed([45; 32]);
267        let fc1 = Dense::new(
268            config.hidden_size,
269            config.intermediate_size,
270            None,
271            &mut rng1,
272        )?;
273
274        let mut rng2 = scirs2_core::random::rngs::SmallRng::from_seed([46; 32]);
275        let fc2 = Dense::new(
276            config.intermediate_size,
277            config.hidden_size,
278            None,
279            &mut rng2,
280        )?;
281
282        let mut rng3 = scirs2_core::random::rngs::SmallRng::from_seed([47; 32]);
283        let dropout = Dropout::new(config.hidden_dropout_prob, &mut rng3)?;
284
285        Ok(Self { fc1, fc2, dropout })
286    }
287}
288
289impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Layer<F>
290    for GPTMlp<F>
291{
292    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
293        // Apply first dense layer
294        let hidden_states = self.fc1.forward(input)?;
295
296        // Apply GELU activation
297        let hidden_states = hidden_states.mapv(|x: F| {
298            let x3 = x * x * x;
299            x * F::from(0.5).expect("Failed to convert constant to float")
300                * (F::one()
301                    + (x + F::from(0.044715).expect("Failed to convert constant to float") * x3)
302                        .tanh())
303        });
304
305        // Apply second dense layer
306        let hidden_states = self.fc2.forward(&hidden_states)?;
307
308        // Apply dropout
309        let hidden_states = self.dropout.forward(&hidden_states)?;
310
311        Ok(hidden_states)
312    }
313
314    fn backward(
315        &self,
316        _input: &Array<F, IxDyn>,
317        grad_output: &Array<F, IxDyn>,
318    ) -> Result<Array<F, IxDyn>> {
319        Ok(grad_output.clone())
320    }
321
322    fn update(&mut self, learning_rate: F) -> Result<()> {
323        self.fc1.update(learning_rate)?;
324        self.fc2.update(learning_rate)?;
325        Ok(())
326    }
327
328    fn as_any(&self) -> &dyn std::any::Any {
329        self
330    }
331
332    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
333        self
334    }
335}
336
337/// GPT attention layer (masked multi-head attention)
338struct GPTAttention<
339    F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
340> {
341    /// Number of attention heads
342    num_attention_heads: usize,
343    /// Size of each attention head
344    attention_head_size: usize,
345    /// Query projection
346    query: Dense<F>,
347    /// Key projection
348    key: Dense<F>,
349    /// Value projection
350    value: Dense<F>,
351    /// Output projection
352    output: Dense<F>,
353    /// Attention dropout
354    attn_dropout: Dropout<F>,
355    /// Output dropout
356    resid_dropout: Dropout<F>,
357    /// Scale factor for attention scores
358    scale: F,
359}
360
361impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Clone
362    for GPTAttention<F>
363{
364    fn clone(&self) -> Self {
365        Self {
366            num_attention_heads: self.num_attention_heads,
367            attention_head_size: self.attention_head_size,
368            query: self.query.clone(),
369            key: self.key.clone(),
370            value: self.value.clone(),
371            output: self.output.clone(),
372            attn_dropout: self.attn_dropout.clone(),
373            resid_dropout: self.resid_dropout.clone(),
374            scale: self.scale,
375        }
376    }
377}
378
379impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
380    GPTAttention<F>
381{
382    /// Create GPT attention layer
383    pub fn new(config: &GPTConfig) -> Result<Self> {
384        let hidden_size = config.hidden_size;
385        let num_attention_heads = config.num_attention_heads;
386        let attention_head_size = hidden_size / num_attention_heads;
387
388        let mut rng1 = scirs2_core::random::rngs::SmallRng::from_seed([48; 32]);
389        let query = Dense::new(hidden_size, hidden_size, None, &mut rng1)?;
390
391        let mut rng2 = scirs2_core::random::rngs::SmallRng::from_seed([49; 32]);
392        let key = Dense::new(hidden_size, hidden_size, None, &mut rng2)?;
393
394        let mut rng3 = scirs2_core::random::rngs::SmallRng::from_seed([50; 32]);
395        let value = Dense::new(hidden_size, hidden_size, None, &mut rng3)?;
396
397        let mut rng4 = scirs2_core::random::rngs::SmallRng::from_seed([51; 32]);
398        let output = Dense::new(hidden_size, hidden_size, None, &mut rng4)?;
399
400        let mut rng5 = scirs2_core::random::rngs::SmallRng::from_seed([52; 32]);
401        let attn_dropout = Dropout::new(config.attention_probs_dropout_prob, &mut rng5)?;
402
403        let mut rng6 = scirs2_core::random::rngs::SmallRng::from_seed([53; 32]);
404        let resid_dropout = Dropout::new(config.hidden_dropout_prob, &mut rng6)?;
405
406        let scale = F::from(1.0 / (attention_head_size as f64).sqrt()).expect("Operation failed");
407
408        Ok(Self {
409            num_attention_heads,
410            attention_head_size,
411            query,
412            key,
413            value,
414            output,
415            attn_dropout,
416            resid_dropout,
417            scale,
418        })
419    }
420}
421
422impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Layer<F>
423    for GPTAttention<F>
424{
425    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
426        let shape = input.shape();
427        if shape.len() != 3 {
428            return Err(NeuralError::InferenceError(format!(
429                "Expected input shape [batch_size, seq_len, hidden_size], got {:?}",
430                shape
431            )));
432        }
433
434        let batch_size = shape[0];
435        let seq_len = shape[1];
436        let hidden_size = shape[2];
437
438        // Project query, key, value
439        let query = self.query.forward(input)?;
440        let key = self.key.forward(input)?;
441        let value = self.value.forward(input)?;
442
443        // Simplified attention: just combine projections
444        // In a full implementation, we'd do proper multi-head attention with causal masking
445        let attention_output = &query + &key + &value;
446
447        // Apply output projection
448        let output = self.output.forward(&attention_output)?;
449        let output = self.resid_dropout.forward(&output)?;
450
451        // Suppress unused variable warnings
452        let _ = (batch_size, seq_len, hidden_size);
453
454        Ok(output)
455    }
456
457    fn backward(
458        &self,
459        _input: &Array<F, IxDyn>,
460        grad_output: &Array<F, IxDyn>,
461    ) -> Result<Array<F, IxDyn>> {
462        Ok(grad_output.clone())
463    }
464
465    fn update(&mut self, learning_rate: F) -> Result<()> {
466        self.query.update(learning_rate)?;
467        self.key.update(learning_rate)?;
468        self.value.update(learning_rate)?;
469        self.output.update(learning_rate)?;
470        Ok(())
471    }
472
473    fn as_any(&self) -> &dyn std::any::Any {
474        self
475    }
476
477    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
478        self
479    }
480}
481
482/// GPT block (attention + MLP)
483struct GPTBlock<
484    F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
485> {
486    /// Layer normalization for attention
487    ln_1: LayerNorm<F>,
488    /// Attention layer
489    attn: GPTAttention<F>,
490    /// Layer normalization for MLP
491    ln_2: LayerNorm<F>,
492    /// MLP
493    mlp: GPTMlp<F>,
494}
495
496impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Clone
497    for GPTBlock<F>
498{
499    fn clone(&self) -> Self {
500        Self {
501            ln_1: self.ln_1.clone(),
502            attn: self.attn.clone(),
503            ln_2: self.ln_2.clone(),
504            mlp: self.mlp.clone(),
505        }
506    }
507}
508
509impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
510    GPTBlock<F>
511{
512    /// Create GPT block
513    pub fn new(config: &GPTConfig) -> Result<Self> {
514        let mut rng1 = scirs2_core::random::rngs::SmallRng::from_seed([54; 32]);
515        let ln_1 = LayerNorm::new(config.hidden_size, config.layer_norm_eps, &mut rng1)?;
516
517        let attn = GPTAttention::new(config)?;
518
519        let mut rng2 = scirs2_core::random::rngs::SmallRng::from_seed([55; 32]);
520        let ln_2 = LayerNorm::new(config.hidden_size, config.layer_norm_eps, &mut rng2)?;
521
522        let mlp = GPTMlp::new(config)?;
523
524        Ok(Self {
525            ln_1,
526            attn,
527            ln_2,
528            mlp,
529        })
530    }
531}
532
533impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Layer<F>
534    for GPTBlock<F>
535{
536    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
537        // Attention with residual connection
538        let ln1_output = self.ln_1.forward(input)?;
539        let attn_output = self.attn.forward(&ln1_output)?;
540        let residual1 = input + &attn_output;
541
542        // MLP with residual connection
543        let ln2_output = self.ln_2.forward(&residual1)?;
544        let mlp_output = self.mlp.forward(&ln2_output)?;
545        let residual2 = &residual1 + &mlp_output;
546
547        Ok(residual2)
548    }
549
550    fn backward(
551        &self,
552        _input: &Array<F, IxDyn>,
553        grad_output: &Array<F, IxDyn>,
554    ) -> Result<Array<F, IxDyn>> {
555        Ok(grad_output.clone())
556    }
557
558    fn update(&mut self, learning_rate: F) -> Result<()> {
559        self.ln_1.update(learning_rate)?;
560        self.attn.update(learning_rate)?;
561        self.ln_2.update(learning_rate)?;
562        self.mlp.update(learning_rate)?;
563        Ok(())
564    }
565
566    fn as_any(&self) -> &dyn std::any::Any {
567        self
568    }
569
570    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
571        self
572    }
573}
574
575/// GPT model implementation
576pub struct GPTModel<
577    F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
578> {
579    /// Embeddings layer
580    embeddings: GPTEmbeddings<F>,
581    /// Transformer blocks
582    blocks: Vec<GPTBlock<F>>,
583    /// Final layer normalization
584    ln_f: LayerNorm<F>,
585    /// Model configuration
586    config: GPTConfig,
587}
588
589impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Clone
590    for GPTModel<F>
591{
592    fn clone(&self) -> Self {
593        Self {
594            embeddings: self.embeddings.clone(),
595            blocks: self.blocks.clone(),
596            ln_f: self.ln_f.clone(),
597            config: self.config.clone(),
598        }
599    }
600}
601
602impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
603    GPTModel<F>
604{
605    /// Create a new GPT model
606    pub fn new(config: GPTConfig) -> Result<Self> {
607        let embeddings = GPTEmbeddings::new(&config)?;
608
609        // Create transformer blocks
610        let mut blocks = Vec::with_capacity(config.num_hidden_layers);
611        for _ in 0..config.num_hidden_layers {
612            blocks.push(GPTBlock::new(&config)?);
613        }
614
615        // Final layer normalization
616        let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([56; 32]);
617        let ln_f = LayerNorm::new(config.hidden_size, config.layer_norm_eps, &mut rng)?;
618
619        Ok(Self {
620            embeddings,
621            blocks,
622            ln_f,
623            config,
624        })
625    }
626
627    /// Create a GPT-2 Small model
628    pub fn gpt2_small() -> Result<Self> {
629        let config = GPTConfig::gpt2_small();
630        Self::new(config)
631    }
632
633    /// Create a GPT-2 Medium model
634    pub fn gpt2_medium() -> Result<Self> {
635        let config = GPTConfig::gpt2_medium();
636        Self::new(config)
637    }
638
639    /// Create a GPT-2 Large model
640    pub fn gpt2_large() -> Result<Self> {
641        let config = GPTConfig::gpt2_large();
642        Self::new(config)
643    }
644
645    /// Create a custom GPT model
646    pub fn custom(
647        vocab_size: usize,
648        hidden_size: usize,
649        num_hidden_layers: usize,
650        num_attention_heads: usize,
651    ) -> Result<Self> {
652        let config = GPTConfig::custom(
653            vocab_size,
654            hidden_size,
655            num_hidden_layers,
656            num_attention_heads,
657        );
658        Self::new(config)
659    }
660
661    /// Get the model configuration
662    pub fn config(&self) -> &GPTConfig {
663        &self.config
664    }
665}
666
667impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Layer<F>
668    for GPTModel<F>
669{
670    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
671        // Get embeddings
672        let mut hidden_states = self.embeddings.forward(input)?;
673
674        // Apply transformer blocks
675        for block in &self.blocks {
676            hidden_states = block.forward(&hidden_states)?;
677        }
678
679        // Apply final layer normalization
680        hidden_states = self.ln_f.forward(&hidden_states)?;
681
682        Ok(hidden_states)
683    }
684
685    fn backward(
686        &self,
687        _input: &Array<F, IxDyn>,
688        grad_output: &Array<F, IxDyn>,
689    ) -> Result<Array<F, IxDyn>> {
690        Ok(grad_output.clone())
691    }
692
693    fn update(&mut self, learning_rate: F) -> Result<()> {
694        self.embeddings.update(learning_rate)?;
695        for block in &mut self.blocks {
696            block.update(learning_rate)?;
697        }
698        self.ln_f.update(learning_rate)?;
699        Ok(())
700    }
701
702    fn as_any(&self) -> &dyn std::any::Any {
703        self
704    }
705
706    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
707        self
708    }
709}
710
711#[cfg(test)]
712mod tests {
713    use super::*;
714
715    #[test]
716    fn test_gpt_config_small() {
717        let config = GPTConfig::gpt2_small();
718        assert_eq!(config.vocab_size, 50257);
719        assert_eq!(config.hidden_size, 768);
720        assert_eq!(config.num_hidden_layers, 12);
721        assert_eq!(config.num_attention_heads, 12);
722    }
723
724    #[test]
725    fn test_gpt_config_medium() {
726        let config = GPTConfig::gpt2_medium();
727        assert_eq!(config.hidden_size, 1024);
728        assert_eq!(config.num_hidden_layers, 24);
729        assert_eq!(config.num_attention_heads, 16);
730    }
731
732    #[test]
733    fn test_gpt_config_large() {
734        let config = GPTConfig::gpt2_large();
735        assert_eq!(config.hidden_size, 1280);
736        assert_eq!(config.num_hidden_layers, 36);
737        assert_eq!(config.num_attention_heads, 20);
738    }
739
740    #[test]
741    fn test_gpt_config_custom() {
742        let config = GPTConfig::custom(10000, 256, 4, 4);
743        assert_eq!(config.vocab_size, 10000);
744        assert_eq!(config.hidden_size, 256);
745        assert_eq!(config.num_hidden_layers, 4);
746        assert_eq!(config.num_attention_heads, 4);
747        assert_eq!(config.intermediate_size, 1024);
748    }
749}