1use crate::error::{NeuralError, Result};
10use crate::layers::{Dense, Dropout, Embedding, EmbeddingConfig, Layer, LayerNorm};
11use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
12use scirs2_core::numeric::{Float, NumAssign};
13use scirs2_core::random::SeedableRng;
14use scirs2_core::simd_ops::SimdUnifiedOps;
15use serde::{Deserialize, Serialize};
16use std::fmt::Debug;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct GPTConfig {
21 pub vocab_size: usize,
23 pub max_position_embeddings: usize,
25 pub hidden_size: usize,
27 pub num_hidden_layers: usize,
29 pub num_attention_heads: usize,
31 pub intermediate_size: usize,
33 pub hidden_act: String,
35 pub hidden_dropout_prob: f64,
37 pub attention_probs_dropout_prob: f64,
39 pub layer_norm_eps: f64,
41 pub initializer_range: f64,
43}
44
45impl GPTConfig {
46 pub fn gpt2_small() -> Self {
48 Self {
49 vocab_size: 50257,
50 max_position_embeddings: 1024,
51 hidden_size: 768,
52 num_hidden_layers: 12,
53 num_attention_heads: 12,
54 intermediate_size: 3072,
55 hidden_act: "gelu".to_string(),
56 hidden_dropout_prob: 0.1,
57 attention_probs_dropout_prob: 0.1,
58 layer_norm_eps: 1e-5,
59 initializer_range: 0.02,
60 }
61 }
62
63 pub fn gpt2_medium() -> Self {
65 Self {
66 vocab_size: 50257,
67 max_position_embeddings: 1024,
68 hidden_size: 1024,
69 num_hidden_layers: 24,
70 num_attention_heads: 16,
71 intermediate_size: 4096,
72 hidden_act: "gelu".to_string(),
73 hidden_dropout_prob: 0.1,
74 attention_probs_dropout_prob: 0.1,
75 layer_norm_eps: 1e-5,
76 initializer_range: 0.02,
77 }
78 }
79
80 pub fn gpt2_large() -> Self {
82 Self {
83 vocab_size: 50257,
84 max_position_embeddings: 1024,
85 hidden_size: 1280,
86 num_hidden_layers: 36,
87 num_attention_heads: 20,
88 intermediate_size: 5120,
89 hidden_act: "gelu".to_string(),
90 hidden_dropout_prob: 0.1,
91 attention_probs_dropout_prob: 0.1,
92 layer_norm_eps: 1e-5,
93 initializer_range: 0.02,
94 }
95 }
96
97 pub fn custom(
99 vocab_size: usize,
100 hidden_size: usize,
101 num_hidden_layers: usize,
102 num_attention_heads: usize,
103 ) -> Self {
104 Self {
105 vocab_size,
106 max_position_embeddings: 1024,
107 hidden_size,
108 num_hidden_layers,
109 num_attention_heads,
110 intermediate_size: hidden_size * 4,
111 hidden_act: "gelu".to_string(),
112 hidden_dropout_prob: 0.1,
113 attention_probs_dropout_prob: 0.1,
114 layer_norm_eps: 1e-5,
115 initializer_range: 0.02,
116 }
117 }
118}
119
120struct GPTEmbeddings<
122 F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
123> {
124 token_embeddings: Embedding<F>,
126 position_embeddings: Embedding<F>,
128 dropout: Dropout<F>,
130}
131
132impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Clone
133 for GPTEmbeddings<F>
134{
135 fn clone(&self) -> Self {
136 Self {
137 token_embeddings: self.token_embeddings.clone(),
138 position_embeddings: self.position_embeddings.clone(),
139 dropout: self.dropout.clone(),
140 }
141 }
142}
143
144impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
145 GPTEmbeddings<F>
146{
147 pub fn new(config: &GPTConfig) -> Result<Self> {
149 let token_embeddings = Embedding::new(EmbeddingConfig {
150 num_embeddings: config.vocab_size,
151 embedding_dim: config.hidden_size,
152 padding_idx: None,
153 max_norm: None,
154 norm_type: 2.0,
155 scale_grad_by_freq: false,
156 })?;
157
158 let position_embeddings = Embedding::new(EmbeddingConfig {
159 num_embeddings: config.max_position_embeddings,
160 embedding_dim: config.hidden_size,
161 padding_idx: None,
162 max_norm: None,
163 norm_type: 2.0,
164 scale_grad_by_freq: false,
165 })?;
166
167 let mut rng3 = scirs2_core::random::rngs::SmallRng::from_seed([44; 32]);
168 let dropout = Dropout::new(config.hidden_dropout_prob, &mut rng3)?;
169
170 Ok(Self {
171 token_embeddings,
172 position_embeddings,
173 dropout,
174 })
175 }
176}
177
178impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Layer<F>
179 for GPTEmbeddings<F>
180{
181 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
182 let shape = input.shape();
183 if shape.len() != 2 {
184 return Err(NeuralError::InferenceError(format!(
185 "Expected input shape [batch_size, seq_len], got {:?}",
186 shape
187 )));
188 }
189
190 let batch_size = shape[0];
191 let seq_len = shape[1];
192
193 let inputs_embeds = self.token_embeddings.forward(input)?;
195
196 let mut position_ids = Array::zeros(IxDyn(&[batch_size, seq_len]));
198 for b in 0..batch_size {
199 for s in 0..seq_len {
200 position_ids[[b, s]] = F::from(s).expect("Failed to convert to float");
201 }
202 }
203
204 let position_embeds = self.position_embeddings.forward(&position_ids)?;
206
207 let embeddings = &inputs_embeds + &position_embeds;
209
210 let embeddings = self.dropout.forward(&embeddings)?;
212
213 Ok(embeddings)
214 }
215
216 fn backward(
217 &self,
218 _input: &Array<F, IxDyn>,
219 grad_output: &Array<F, IxDyn>,
220 ) -> Result<Array<F, IxDyn>> {
221 Ok(grad_output.clone())
222 }
223
224 fn update(&mut self, learning_rate: F) -> Result<()> {
225 self.token_embeddings.update(learning_rate)?;
226 self.position_embeddings.update(learning_rate)?;
227 Ok(())
228 }
229
230 fn as_any(&self) -> &dyn std::any::Any {
231 self
232 }
233
234 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
235 self
236 }
237}
238
239struct GPTMlp<F: Float + Debug + ScalarOperand + Send + Sync + NumAssign + 'static> {
241 fc1: Dense<F>,
243 fc2: Dense<F>,
245 dropout: Dropout<F>,
247}
248
249impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Clone
250 for GPTMlp<F>
251{
252 fn clone(&self) -> Self {
253 Self {
254 fc1: self.fc1.clone(),
255 fc2: self.fc2.clone(),
256 dropout: self.dropout.clone(),
257 }
258 }
259}
260
261impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
262 GPTMlp<F>
263{
264 pub fn new(config: &GPTConfig) -> Result<Self> {
266 let mut rng1 = scirs2_core::random::rngs::SmallRng::from_seed([45; 32]);
267 let fc1 = Dense::new(
268 config.hidden_size,
269 config.intermediate_size,
270 None,
271 &mut rng1,
272 )?;
273
274 let mut rng2 = scirs2_core::random::rngs::SmallRng::from_seed([46; 32]);
275 let fc2 = Dense::new(
276 config.intermediate_size,
277 config.hidden_size,
278 None,
279 &mut rng2,
280 )?;
281
282 let mut rng3 = scirs2_core::random::rngs::SmallRng::from_seed([47; 32]);
283 let dropout = Dropout::new(config.hidden_dropout_prob, &mut rng3)?;
284
285 Ok(Self { fc1, fc2, dropout })
286 }
287}
288
289impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Layer<F>
290 for GPTMlp<F>
291{
292 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
293 let hidden_states = self.fc1.forward(input)?;
295
296 let hidden_states = hidden_states.mapv(|x: F| {
298 let x3 = x * x * x;
299 x * F::from(0.5).expect("Failed to convert constant to float")
300 * (F::one()
301 + (x + F::from(0.044715).expect("Failed to convert constant to float") * x3)
302 .tanh())
303 });
304
305 let hidden_states = self.fc2.forward(&hidden_states)?;
307
308 let hidden_states = self.dropout.forward(&hidden_states)?;
310
311 Ok(hidden_states)
312 }
313
314 fn backward(
315 &self,
316 _input: &Array<F, IxDyn>,
317 grad_output: &Array<F, IxDyn>,
318 ) -> Result<Array<F, IxDyn>> {
319 Ok(grad_output.clone())
320 }
321
322 fn update(&mut self, learning_rate: F) -> Result<()> {
323 self.fc1.update(learning_rate)?;
324 self.fc2.update(learning_rate)?;
325 Ok(())
326 }
327
328 fn as_any(&self) -> &dyn std::any::Any {
329 self
330 }
331
332 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
333 self
334 }
335}
336
337struct GPTAttention<
339 F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
340> {
341 num_attention_heads: usize,
343 attention_head_size: usize,
345 query: Dense<F>,
347 key: Dense<F>,
349 value: Dense<F>,
351 output: Dense<F>,
353 attn_dropout: Dropout<F>,
355 resid_dropout: Dropout<F>,
357 scale: F,
359}
360
361impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Clone
362 for GPTAttention<F>
363{
364 fn clone(&self) -> Self {
365 Self {
366 num_attention_heads: self.num_attention_heads,
367 attention_head_size: self.attention_head_size,
368 query: self.query.clone(),
369 key: self.key.clone(),
370 value: self.value.clone(),
371 output: self.output.clone(),
372 attn_dropout: self.attn_dropout.clone(),
373 resid_dropout: self.resid_dropout.clone(),
374 scale: self.scale,
375 }
376 }
377}
378
379impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
380 GPTAttention<F>
381{
382 pub fn new(config: &GPTConfig) -> Result<Self> {
384 let hidden_size = config.hidden_size;
385 let num_attention_heads = config.num_attention_heads;
386 let attention_head_size = hidden_size / num_attention_heads;
387
388 let mut rng1 = scirs2_core::random::rngs::SmallRng::from_seed([48; 32]);
389 let query = Dense::new(hidden_size, hidden_size, None, &mut rng1)?;
390
391 let mut rng2 = scirs2_core::random::rngs::SmallRng::from_seed([49; 32]);
392 let key = Dense::new(hidden_size, hidden_size, None, &mut rng2)?;
393
394 let mut rng3 = scirs2_core::random::rngs::SmallRng::from_seed([50; 32]);
395 let value = Dense::new(hidden_size, hidden_size, None, &mut rng3)?;
396
397 let mut rng4 = scirs2_core::random::rngs::SmallRng::from_seed([51; 32]);
398 let output = Dense::new(hidden_size, hidden_size, None, &mut rng4)?;
399
400 let mut rng5 = scirs2_core::random::rngs::SmallRng::from_seed([52; 32]);
401 let attn_dropout = Dropout::new(config.attention_probs_dropout_prob, &mut rng5)?;
402
403 let mut rng6 = scirs2_core::random::rngs::SmallRng::from_seed([53; 32]);
404 let resid_dropout = Dropout::new(config.hidden_dropout_prob, &mut rng6)?;
405
406 let scale = F::from(1.0 / (attention_head_size as f64).sqrt()).expect("Operation failed");
407
408 Ok(Self {
409 num_attention_heads,
410 attention_head_size,
411 query,
412 key,
413 value,
414 output,
415 attn_dropout,
416 resid_dropout,
417 scale,
418 })
419 }
420}
421
422impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Layer<F>
423 for GPTAttention<F>
424{
425 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
426 let shape = input.shape();
427 if shape.len() != 3 {
428 return Err(NeuralError::InferenceError(format!(
429 "Expected input shape [batch_size, seq_len, hidden_size], got {:?}",
430 shape
431 )));
432 }
433
434 let batch_size = shape[0];
435 let seq_len = shape[1];
436 let hidden_size = shape[2];
437
438 let query = self.query.forward(input)?;
440 let key = self.key.forward(input)?;
441 let value = self.value.forward(input)?;
442
443 let attention_output = &query + &key + &value;
446
447 let output = self.output.forward(&attention_output)?;
449 let output = self.resid_dropout.forward(&output)?;
450
451 let _ = (batch_size, seq_len, hidden_size);
453
454 Ok(output)
455 }
456
457 fn backward(
458 &self,
459 _input: &Array<F, IxDyn>,
460 grad_output: &Array<F, IxDyn>,
461 ) -> Result<Array<F, IxDyn>> {
462 Ok(grad_output.clone())
463 }
464
465 fn update(&mut self, learning_rate: F) -> Result<()> {
466 self.query.update(learning_rate)?;
467 self.key.update(learning_rate)?;
468 self.value.update(learning_rate)?;
469 self.output.update(learning_rate)?;
470 Ok(())
471 }
472
473 fn as_any(&self) -> &dyn std::any::Any {
474 self
475 }
476
477 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
478 self
479 }
480}
481
482struct GPTBlock<
484 F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
485> {
486 ln_1: LayerNorm<F>,
488 attn: GPTAttention<F>,
490 ln_2: LayerNorm<F>,
492 mlp: GPTMlp<F>,
494}
495
496impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Clone
497 for GPTBlock<F>
498{
499 fn clone(&self) -> Self {
500 Self {
501 ln_1: self.ln_1.clone(),
502 attn: self.attn.clone(),
503 ln_2: self.ln_2.clone(),
504 mlp: self.mlp.clone(),
505 }
506 }
507}
508
509impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
510 GPTBlock<F>
511{
512 pub fn new(config: &GPTConfig) -> Result<Self> {
514 let mut rng1 = scirs2_core::random::rngs::SmallRng::from_seed([54; 32]);
515 let ln_1 = LayerNorm::new(config.hidden_size, config.layer_norm_eps, &mut rng1)?;
516
517 let attn = GPTAttention::new(config)?;
518
519 let mut rng2 = scirs2_core::random::rngs::SmallRng::from_seed([55; 32]);
520 let ln_2 = LayerNorm::new(config.hidden_size, config.layer_norm_eps, &mut rng2)?;
521
522 let mlp = GPTMlp::new(config)?;
523
524 Ok(Self {
525 ln_1,
526 attn,
527 ln_2,
528 mlp,
529 })
530 }
531}
532
533impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Layer<F>
534 for GPTBlock<F>
535{
536 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
537 let ln1_output = self.ln_1.forward(input)?;
539 let attn_output = self.attn.forward(&ln1_output)?;
540 let residual1 = input + &attn_output;
541
542 let ln2_output = self.ln_2.forward(&residual1)?;
544 let mlp_output = self.mlp.forward(&ln2_output)?;
545 let residual2 = &residual1 + &mlp_output;
546
547 Ok(residual2)
548 }
549
550 fn backward(
551 &self,
552 _input: &Array<F, IxDyn>,
553 grad_output: &Array<F, IxDyn>,
554 ) -> Result<Array<F, IxDyn>> {
555 Ok(grad_output.clone())
556 }
557
558 fn update(&mut self, learning_rate: F) -> Result<()> {
559 self.ln_1.update(learning_rate)?;
560 self.attn.update(learning_rate)?;
561 self.ln_2.update(learning_rate)?;
562 self.mlp.update(learning_rate)?;
563 Ok(())
564 }
565
566 fn as_any(&self) -> &dyn std::any::Any {
567 self
568 }
569
570 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
571 self
572 }
573}
574
575pub struct GPTModel<
577 F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static,
578> {
579 embeddings: GPTEmbeddings<F>,
581 blocks: Vec<GPTBlock<F>>,
583 ln_f: LayerNorm<F>,
585 config: GPTConfig,
587}
588
589impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Clone
590 for GPTModel<F>
591{
592 fn clone(&self) -> Self {
593 Self {
594 embeddings: self.embeddings.clone(),
595 blocks: self.blocks.clone(),
596 ln_f: self.ln_f.clone(),
597 config: self.config.clone(),
598 }
599 }
600}
601
602impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static>
603 GPTModel<F>
604{
605 pub fn new(config: GPTConfig) -> Result<Self> {
607 let embeddings = GPTEmbeddings::new(&config)?;
608
609 let mut blocks = Vec::with_capacity(config.num_hidden_layers);
611 for _ in 0..config.num_hidden_layers {
612 blocks.push(GPTBlock::new(&config)?);
613 }
614
615 let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([56; 32]);
617 let ln_f = LayerNorm::new(config.hidden_size, config.layer_norm_eps, &mut rng)?;
618
619 Ok(Self {
620 embeddings,
621 blocks,
622 ln_f,
623 config,
624 })
625 }
626
627 pub fn gpt2_small() -> Result<Self> {
629 let config = GPTConfig::gpt2_small();
630 Self::new(config)
631 }
632
633 pub fn gpt2_medium() -> Result<Self> {
635 let config = GPTConfig::gpt2_medium();
636 Self::new(config)
637 }
638
639 pub fn gpt2_large() -> Result<Self> {
641 let config = GPTConfig::gpt2_large();
642 Self::new(config)
643 }
644
645 pub fn custom(
647 vocab_size: usize,
648 hidden_size: usize,
649 num_hidden_layers: usize,
650 num_attention_heads: usize,
651 ) -> Result<Self> {
652 let config = GPTConfig::custom(
653 vocab_size,
654 hidden_size,
655 num_hidden_layers,
656 num_attention_heads,
657 );
658 Self::new(config)
659 }
660
661 pub fn config(&self) -> &GPTConfig {
663 &self.config
664 }
665}
666
667impl<F: Float + Debug + ScalarOperand + Send + Sync + SimdUnifiedOps + NumAssign + 'static> Layer<F>
668 for GPTModel<F>
669{
670 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
671 let mut hidden_states = self.embeddings.forward(input)?;
673
674 for block in &self.blocks {
676 hidden_states = block.forward(&hidden_states)?;
677 }
678
679 hidden_states = self.ln_f.forward(&hidden_states)?;
681
682 Ok(hidden_states)
683 }
684
685 fn backward(
686 &self,
687 _input: &Array<F, IxDyn>,
688 grad_output: &Array<F, IxDyn>,
689 ) -> Result<Array<F, IxDyn>> {
690 Ok(grad_output.clone())
691 }
692
693 fn update(&mut self, learning_rate: F) -> Result<()> {
694 self.embeddings.update(learning_rate)?;
695 for block in &mut self.blocks {
696 block.update(learning_rate)?;
697 }
698 self.ln_f.update(learning_rate)?;
699 Ok(())
700 }
701
702 fn as_any(&self) -> &dyn std::any::Any {
703 self
704 }
705
706 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
707 self
708 }
709}
710
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpt_config_small() {
        let config = GPTConfig::gpt2_small();
        assert_eq!(config.vocab_size, 50257);
        assert_eq!(config.hidden_size, 768);
        assert_eq!(config.num_hidden_layers, 12);
        assert_eq!(config.num_attention_heads, 12);
    }

    #[test]
    fn test_gpt_config_medium() {
        let config = GPTConfig::gpt2_medium();
        assert_eq!(config.hidden_size, 1024);
        assert_eq!(config.num_hidden_layers, 24);
        assert_eq!(config.num_attention_heads, 16);
    }

    #[test]
    fn test_gpt_config_large() {
        let config = GPTConfig::gpt2_large();
        assert_eq!(config.hidden_size, 1280);
        assert_eq!(config.num_hidden_layers, 36);
        assert_eq!(config.num_attention_heads, 20);
    }

    #[test]
    fn test_gpt_config_custom() {
        let config = GPTConfig::custom(10000, 256, 4, 4);
        assert_eq!(config.vocab_size, 10000);
        assert_eq!(config.hidden_size, 256);
        assert_eq!(config.num_hidden_layers, 4);
        assert_eq!(config.num_attention_heads, 4);
        assert_eq!(config.intermediate_size, 1024);
    }

    /// The attention layer splits `hidden_size` evenly across heads, and the MLP width
    /// is 4x the hidden size. Every preset must satisfy both.
    #[test]
    fn test_gpt_presets_have_consistent_dimensions() {
        for config in [
            GPTConfig::gpt2_small(),
            GPTConfig::gpt2_medium(),
            GPTConfig::gpt2_large(),
        ] {
            assert_eq!(config.vocab_size, 50257);
            assert_eq!(config.max_position_embeddings, 1024);
            assert!(config.num_attention_heads > 0);
            assert_eq!(config.hidden_size % config.num_attention_heads, 0);
            assert_eq!(config.intermediate_size, config.hidden_size * 4);
        }
    }
}