gpt_example/
gpt_example.rs1use ndarray::{Array, IxDyn};
2use scirs2_neural::layers::Layer;
3use scirs2_neural::models::{GPTConfig, GPTModel};
4
5fn main() -> Result<(), Box<dyn std::error::Error>> {
6 println!("GPT Model Example");
7
8 println!("Creating a small GPT model...");
10
11 let config = GPTConfig::custom(
12 10000, 128, 2, 2, );
17
18 let model = GPTModel::<f32>::new(config)?;
19
20 let input = Array::from_shape_fn(
23 IxDyn(&[2, 16]),
24 |_| rand::random::<f32>() * 100.0, );
26
27 println!("Input shape: {:?}", input.shape());
28
29 let hidden_states = model.forward(&input)?;
31
32 println!("Hidden states shape: {:?}", hidden_states.shape());
33
34 let logits = model.logits(&input)?;
36
37 println!("Logits shape: {:?}", logits.shape());
38 println!("Vocabulary size: {}", logits.shape()[2]);
39
40 println!("\nCreating a GPT-2 Small model...");
42
43 let gpt2_small = GPTModel::<f32>::gpt2_small()?;
44
45 let small_input = Array::from_shape_fn(
47 IxDyn(&[1, 32]),
48 |_| rand::random::<f32>() * 1000.0, );
50
51 println!("GPT-2 Small input shape: {:?}", small_input.shape());
52
53 let small_hidden_states = gpt2_small.forward(&small_input)?;
55
56 println!(
57 "GPT-2 Small hidden states shape: {:?}",
58 small_hidden_states.shape()
59 );
60 println!(
61 "GPT-2 Small hidden dimension: {}",
62 small_hidden_states.shape()[2]
63 );
64
65 let small_logits = gpt2_small.logits(&small_input)?;
67 println!("GPT-2 Small logits shape: {:?}", small_logits.shape());
68 println!("GPT-2 Small vocabulary size: {}", small_logits.shape()[2]);
69
70 println!("\nGPT example completed successfully!");
71
72 Ok(())
73}