// embedding_example/embedding_example.rs
use ndarray::{Array, IxDyn};
use scirs2_neural::layers::{
    Embedding, EmbeddingConfig, Layer, PatchEmbedding, PositionalEmbedding,
};

6fn main() -> Result<(), Box<dyn std::error::Error>> {
7 println!("Running embedding examples...");
8
9 println!("\n--- Basic Embedding Example ---");
11 let config = EmbeddingConfig {
12 num_embeddings: 10,
13 embedding_dim: 5,
14 padding_idx: Some(0),
15 max_norm: None,
16 norm_type: 2.0,
17 scale_grad_by_freq: false,
18 sparse: false,
19 };
20
21 let embedding = Embedding::<f32>::new(config)?;
22
23 let indices = Array::from_shape_vec(IxDyn(&[2, 3]), vec![1, 2, 0, 3, 0, 4])?;
25
26 let output = embedding.forward(&indices.mapv(|x| x as f32))?;
28
29 println!("Input indices shape: {:?}", indices.shape());
30 println!("Output embeddings shape: {:?}", output.shape());
31 println!(
32 "First embedding vector: {:?}",
33 output.slice(ndarray::s![0, 0, ..]).to_owned()
34 );
35
36 println!("\n--- Positional Embedding Example ---");
38
39 let pos_embedding = PositionalEmbedding::<f32>::new(10, 8, false)?;
41
42 let token_embeddings = Array::from_shape_fn(IxDyn(&[2, 5, 8]), |_| 1.0f32);
44
45 let output = pos_embedding.forward(&token_embeddings)?;
47
48 println!("Token embeddings shape: {:?}", token_embeddings.shape());
49 println!(
50 "Output with positional encoding shape: {:?}",
51 output.shape()
52 );
53 println!(
54 "First token before positional encoding: {:?}",
55 token_embeddings.slice(ndarray::s![0, 0, ..]).to_owned()
56 );
57 println!(
58 "First token after positional encoding: {:?}",
59 output.slice(ndarray::s![0, 0, ..]).to_owned()
60 );
61
62 println!("\n--- Patch Embedding Example ---");
64
65 let patch_embedding = PatchEmbedding::<f32>::new((32, 32), (8, 8), 3, 96, true)?;
67
68 let image_input = Array::from_shape_fn(IxDyn(&[1, 3, 32, 32]), |_| rand::random::<f32>());
70
71 let output = patch_embedding.forward(&image_input)?;
73
74 println!("Input image shape: {:?}", image_input.shape());
75 println!("Patch embeddings shape: {:?}", output.shape());
76 println!("Number of patches: {}", patch_embedding.num_patches());
77 println!("Embedding dimension: {}", patch_embedding.embedding_dim);
78
79 println!(
81 "First patch embedding (first 5 values): {:?}",
82 output.slice(ndarray::s![0, 0, ..5]).to_owned()
83 );
84
85 println!("\nAll embedding examples completed successfully!");
86
87 Ok(())
88}