use ndarray::Array3;
use rand::rngs::SmallRng;
use rand::SeedableRng;
use scirs2_neural::layers::Layer;
use scirs2_neural::transformer::{Transformer, TransformerConfig};
use scirs2_neural::utils::PositionalEncodingType;

13fn main() -> Result<(), Box<dyn std::error::Error>> {
14 println!("Transformer Model Example");
15 println!("========================");
16
17 let mut rng = SmallRng::seed_from_u64(42);
19
20 let config = TransformerConfig {
22 d_model: 64, n_encoder_layers: 2, n_decoder_layers: 2, n_heads: 4, d_ff: 128, max_seq_len: 50, dropout: 0.1, pos_encoding_type: PositionalEncodingType::Sinusoidal, epsilon: 1e-5, };
32
33 println!("Creating transformer model with config:");
34 println!(" - d_model: {}", config.d_model);
35 println!(" - n_encoder_layers: {}", config.n_encoder_layers);
36 println!(" - n_decoder_layers: {}", config.n_decoder_layers);
37 println!(" - n_heads: {}", config.n_heads);
38 println!(" - d_ff: {}", config.d_ff);
39 println!(" - max_seq_len: {}", config.max_seq_len);
40
41 let transformer = Transformer::<f64>::new(config, &mut rng)?;
43
44 let batch_size = 2;
47 let src_seq_len = 10;
48 let tgt_seq_len = 8;
49 let d_model = 64;
50
51 println!("\nSample dimensions:");
52 println!(" - Batch size: {}", batch_size);
53 println!(" - Source sequence length: {}", src_seq_len);
54 println!(" - Target sequence length: {}", tgt_seq_len);
55
56 let src_embeddings = Array3::<f64>::from_elem((batch_size, src_seq_len, d_model), 0.1);
59 let tgt_embeddings = Array3::<f64>::from_elem((batch_size, tgt_seq_len, d_model), 0.1);
60
61 let src_embeddings_dyn = src_embeddings.clone().into_dyn();
63 let tgt_embeddings_dyn = tgt_embeddings.clone().into_dyn();
64
65 println!("\nRunning encoder-only inference...");
66 let encoder_output = transformer.forward(&src_embeddings_dyn)?;
68 println!("Encoder output shape: {:?}", encoder_output.shape());
69
70 println!("\nRunning full transformer inference (training mode)...");
71 let output_train = transformer.forward_train(&src_embeddings_dyn, &tgt_embeddings_dyn)?;
73 println!("Training output shape: {:?}", output_train.shape());
74
75 println!("\nRunning autoregressive inference (one step)...");
76 let first_token = Array3::<f64>::from_elem((batch_size, 1, d_model), 0.1);
79 let first_token_dyn = first_token.clone().into_dyn();
80 let output_inference = transformer.forward_inference(&src_embeddings_dyn, &first_token_dyn)?;
81 println!("Inference output shape: {:?}", output_inference.shape());
82
83 println!("\nExample completed successfully");
84 Ok(())
85}