1use scirs2_core::ndarray::{Array, Array3, IxDyn};
2use scirs2_neural::losses::{
3 ContrastiveLoss, CrossEntropyLoss, FocalLoss, Loss, MeanSquaredError, TripletLoss,
4};
5
6#[allow(dead_code)]
7fn main() -> Result<(), Box<dyn std::error::Error>> {
8 println!("Loss functions example");
9 println!("\n--- Mean Squared Error Example ---");
11 let mse = MeanSquaredError::new();
12 let predictions = Array::from_vec(vec![1.0, 2.0, 3.0]).into_dyn();
14 let targets = Array::from_vec(vec![1.5, 1.8, 2.5]).into_dyn();
15 let loss = mse.forward(&predictions, &targets)?;
17 println!("Predictions: {predictions:?}");
18 println!("Targets: {targets:?}");
19 println!("MSE Loss: {loss:.4}");
20 let gradients = mse.backward(&predictions, &targets)?;
22 println!("MSE Gradients: {gradients:?}");
23 println!("\n--- Cross-Entropy Loss Example ---");
25 let ce = CrossEntropyLoss::new(1e-10);
26 let predictions = Array::from_shape_vec(IxDyn(&[2, 3]), vec![0.7, 0.2, 0.1, 0.3, 0.6, 0.1])?;
28 let targets = Array::from_shape_vec(IxDyn(&[2, 3]), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0])?;
29 let loss = ce.forward(&predictions, &targets)?;
30 println!("Predictions (probabilities):");
31 println!("{predictions:?}");
32 println!("Targets (one-hot):");
33 println!("{targets:?}");
34 println!("Cross-Entropy Loss: {loss:.4}");
35 let gradients = ce.backward(&predictions, &targets)?;
36 println!("Cross-Entropy Gradients:");
37 println!("{gradients:?}");
38 println!("\n--- Focal Loss Example ---");
40 let focal = FocalLoss::new(2.0, Some(0.25), 1e-10);
41 let loss = focal.forward(&predictions, &targets)?;
43 println!("Focal Loss (gamma=2.0, alpha=0.25): {loss:.4}");
44 let _gradients = focal.backward(&predictions, &targets)?;
45 println!("Focal Loss Gradients:");
46 println!("\n--- Contrastive Loss Example ---");
48 let contrastive = ContrastiveLoss::new(1.0);
49 let embeddings = Array::from_shape_vec(
52 IxDyn(&[2, 2, 3]),
53 vec![
54 0.1, 0.2, 0.3, 0.1, 0.3, 0.3, 0.5, 0.5, 0.5, 0.9, 0.8, 0.7, ],
59 )?;
60 let labels = Array::from_shape_vec(IxDyn(&[2, 1]), vec![1.0, 0.0])?;
62 let loss = contrastive.forward(&embeddings, &labels)?;
63 println!("Embeddings (batch_size x 2 x embedding_dim):");
64 println!("{embeddings:?}");
65 println!("Labels (1 for similar, 0 for dissimilar):");
66 println!("{labels:?}");
67 println!("Contrastive Loss (margin=1.0): {loss:.4}");
68 let gradients = contrastive.backward(&embeddings, &labels)?;
69 println!("Contrastive Loss Gradients (first few):");
70 let gradient_slice = gradients.slice(scirs2_core::ndarray::s![0, .., 0]);
71 println!("{gradient_slice:?}");
72 println!("\n--- Triplet Loss Example ---");
74 let triplet = TripletLoss::new(0.5);
75 let triplet_embeddings = Array3::from_shape_vec(
78 (2, 3, 3),
79 vec![
80 0.1, 0.2, 0.3, 0.1, 0.3, 0.3, 0.5, 0.5, 0.5, 0.6, 0.6, 0.6, 0.5, 0.6, 0.6, 0.1, 0.1, 0.1, ],
87 )?
88 .into_dyn();
89 let dummy_labels = Array::zeros(IxDyn(&[2, 1]));
91 let loss = triplet.forward(&triplet_embeddings, &dummy_labels)?;
92 println!("Embeddings (batch_size x 3 x embedding_dim):");
93 println!(" - First dimension: batch size");
94 println!(" - Second dimension: [anchor, positive, negative]");
95 println!(" - Third dimension: embedding components");
96 println!("Triplet Loss (margin=0.5): {loss:.4}");
97 let _gradients = triplet.backward(&triplet_embeddings, &dummy_labels)?;
98 println!("Triplet Loss Gradients (first few):");
99 Ok(())
100}