gpt_example/
gpt_example.rs

1use ndarray::{Array, IxDyn};
2use scirs2_neural::layers::Layer;
3use scirs2_neural::models::{GPTConfig, GPTModel};
4
5fn main() -> Result<(), Box<dyn std::error::Error>> {
6    println!("GPT Model Example");
7
8    // Create a small GPT model for demonstration
9    println!("Creating a small GPT model...");
10
11    let config = GPTConfig::custom(
12        10000, // vocab_size
13        128,   // hidden_size
14        2,     // num_hidden_layers
15        2,     // num_attention_heads
16    );
17
18    let model = GPTModel::<f32>::new(config)?;
19
20    // Create dummy input (batch_size=2, seq_len=16)
21    // Input tensor contains token IDs
22    let input = Array::from_shape_fn(
23        IxDyn(&[2, 16]),
24        |_| rand::random::<f32>() * 100.0, // Random token IDs between 0 and 100
25    );
26
27    println!("Input shape: {:?}", input.shape());
28
29    // Forward pass to get hidden states
30    let hidden_states = model.forward(&input)?;
31
32    println!("Hidden states shape: {:?}", hidden_states.shape());
33
34    // Calculate logits for next-token prediction
35    let logits = model.logits(&input)?;
36
37    println!("Logits shape: {:?}", logits.shape());
38    println!("Vocabulary size: {}", logits.shape()[2]);
39
40    // Let's create a GPT-2 Small model
41    println!("\nCreating a GPT-2 Small model...");
42
43    let gpt2_small = GPTModel::<f32>::gpt2_small()?;
44
45    // Create dummy input for a longer sequence
46    let small_input = Array::from_shape_fn(
47        IxDyn(&[1, 32]),
48        |_| rand::random::<f32>() * 1000.0, // Random token IDs
49    );
50
51    println!("GPT-2 Small input shape: {:?}", small_input.shape());
52
53    // Forward pass
54    let small_hidden_states = gpt2_small.forward(&small_input)?;
55
56    println!(
57        "GPT-2 Small hidden states shape: {:?}",
58        small_hidden_states.shape()
59    );
60    println!(
61        "GPT-2 Small hidden dimension: {}",
62        small_hidden_states.shape()[2]
63    );
64
65    // For text generation (logits for next token prediction)
66    let small_logits = gpt2_small.logits(&small_input)?;
67    println!("GPT-2 Small logits shape: {:?}", small_logits.shape());
68    println!("GPT-2 Small vocabulary size: {}", small_logits.shape()[2]);
69
70    println!("\nGPT example completed successfully!");
71
72    Ok(())
73}