transformer_example/transformer_example.rs

//! Transformer model example
//!
//! This example demonstrates how to create and use a transformer model
//! with the scirs2-neural crate.

use ndarray::Array3;
use rand::rngs::SmallRng;
use rand::SeedableRng;
use scirs2_neural::layers::Layer;
use scirs2_neural::transformer::{Transformer, TransformerConfig};
use scirs2_neural::utils::PositionalEncodingType;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("Transformer Model Example");
    println!("=========================");

    // Create a seeded RNG for reproducibility
    let mut rng = SmallRng::seed_from_u64(42);

    // Create a small transformer configuration for demonstration
    let config = TransformerConfig {
        d_model: 64,                                            // Embedding dimension
        n_encoder_layers: 2,                                    // Number of encoder layers
        n_decoder_layers: 2,                                    // Number of decoder layers
        n_heads: 4,                                             // Number of attention heads
        d_ff: 128,       // Feed-forward network hidden dimension
        max_seq_len: 50, // Maximum sequence length
        dropout: 0.1,    // Dropout rate
        pos_encoding_type: PositionalEncodingType::Sinusoidal, // Positional encoding type
        epsilon: 1e-5,   // Small constant for layer normalization
    };

    println!("Creating transformer model with config:");
    println!("  - d_model: {}", config.d_model);
    println!("  - n_encoder_layers: {}", config.n_encoder_layers);
    println!("  - n_decoder_layers: {}", config.n_decoder_layers);
    println!("  - n_heads: {}", config.n_heads);
    println!("  - d_ff: {}", config.d_ff);
    println!("  - max_seq_len: {}", config.max_seq_len);

    // Create the transformer model
    let transformer = Transformer::<f64>::new(config, &mut rng)?;

    // Create sample inputs
    // In a real application, these would be token embeddings
    let batch_size = 2;
    let src_seq_len = 10;
    let tgt_seq_len = 8;
    let d_model = 64; // Embedding dimension; must match config.d_model

    println!("\nSample dimensions:");
    println!("  - Batch size: {}", batch_size);
    println!("  - Source sequence length: {}", src_seq_len);
    println!("  - Target sequence length: {}", tgt_seq_len);

    // Create source and target sequence embeddings
    // In practice, these would come from embedding layers or tokenizers
    let src_embeddings = Array3::<f64>::from_elem((batch_size, src_seq_len, d_model), 0.1);
    let tgt_embeddings = Array3::<f64>::from_elem((batch_size, tgt_seq_len, d_model), 0.1);
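
    // Hypothetical sketch of how embeddings might be built from integer token
    // ids instead of constants: look up rows of a (vocab_size, d_model) weight
    // matrix with plain ndarray indexing (not a specific scirs2-neural API):
    //
    //     use ndarray::{s, Array2};
    //     let vocab_size = 1000;
    //     let embedding_table = Array2::<f64>::from_elem((vocab_size, d_model), 0.01);
    //     let token_ids = [[5usize, 17, 42], [7, 99, 3]]; // (batch, seq)
    //     let mut looked_up = Array3::<f64>::zeros((token_ids.len(), 3, d_model));
    //     for (b, seq) in token_ids.iter().enumerate() {
    //         for (t, &id) in seq.iter().enumerate() {
    //             looked_up.slice_mut(s![b, t, ..]).assign(&embedding_table.row(id));
    //         }
    //     }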

    // Convert to dyn format once and reuse
    let src_embeddings_dyn = src_embeddings.into_dyn();
    let tgt_embeddings_dyn = tgt_embeddings.into_dyn();

    println!("\nRunning encoder-only inference...");
    // Run encoder-only inference (useful for tasks like classification)
    let encoder_output = transformer.forward(&src_embeddings_dyn)?;
    println!("Encoder output shape: {:?}", encoder_output.shape());
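
    // Hypothetical sketch: for classification, the per-position encoder outputs
    // could be pooled into one feature vector per example, e.g. by averaging
    // over the sequence axis with plain ndarray (not a crate-specific API):
    //
    //     let pooled = encoder_output.mean_axis(ndarray::Axis(1)).unwrap(); // (batch, d_model)
    //     println!("Pooled feature shape: {:?}", pooled.shape());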

    println!("\nRunning full transformer inference (training mode)...");
    // Run a full encoder-decoder forward pass with teacher forcing (training mode)
    let output_train = transformer.forward_train(&src_embeddings_dyn, &tgt_embeddings_dyn)?;
    println!("Training output shape: {:?}", output_train.shape());

    println!("\nRunning autoregressive inference (one step)...");
    // Simulate autoregressive generation (one step)
    // In practice, tokens would be generated one at a time in a loop,
    // as sketched below
    let first_token = Array3::<f64>::from_elem((batch_size, 1, d_model), 0.1);
    let first_token_dyn = first_token.into_dyn();
    let output_inference = transformer.forward_inference(&src_embeddings_dyn, &first_token_dyn)?;
    println!("Inference output shape: {:?}", output_inference.shape());
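
    // Hypothetical sketch of a full autoregressive loop. Mapping each step's
    // output back to the next input embedding is application-specific (a real
    // model would sample a token and look up its embedding); here the last
    // decoder position is appended directly as a placeholder:
    //
    //     use ndarray::{concatenate, Axis, Slice};
    //     let mut generated = first_token_dyn.clone();
    //     for _ in 1..tgt_seq_len {
    //         let step = transformer.forward_inference(&src_embeddings_dyn, &generated)?;
    //         let last = step.slice_axis(Axis(1), Slice::from(-1..));
    //         generated = concatenate(Axis(1), &[generated.view(), last])?;
    //     }
    //     println!("Generated sequence shape: {:?}", generated.shape());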

    println!("\nExample completed successfully");
    Ok(())
}