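//! Loss functions example (`loss_functions_example/loss_functions_example.rs`).
//!
//! Runs the forward and backward passes of the MSE, cross-entropy, focal,
//! contrastive, and triplet losses from `scirs2_neural::losses` on small inputs.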

use scirs2_core::ndarray::{Array, Array3, IxDyn};
use scirs2_neural::losses::{
    ContrastiveLoss, CrossEntropyLoss, FocalLoss, Loss, MeanSquaredError, TripletLoss,
};

6#[allow(dead_code)]
7fn main() -> Result<(), Box<dyn std::error::Error>> {
8    println!("Loss functions example");
9    // Mean Squared Error example
10    println!("\n--- Mean Squared Error Example ---");
11    let mse = MeanSquaredError::new();
12    // Create sample data for regression
13    let predictions = Array::from_vec(vec![1.0, 2.0, 3.0]).into_dyn();
14    let targets = Array::from_vec(vec![1.5, 1.8, 2.5]).into_dyn();
15    // Calculate loss
16    let loss = mse.forward(&predictions, &targets)?;
17    println!("Predictions: {predictions:?}");
18    println!("Targets: {targets:?}");
19    println!("MSE Loss: {loss:.4}");
20    // Calculate gradients
21    let gradients = mse.backward(&predictions, &targets)?;
22    println!("MSE Gradients: {gradients:?}");
23    // Cross-Entropy Loss example
24    println!("\n--- Cross-Entropy Loss Example ---");
25    let ce = CrossEntropyLoss::new(1e-10);
26    // Create sample data for multi-class classification
27    let predictions = Array::from_shape_vec(IxDyn(&[2, 3]), vec![0.7, 0.2, 0.1, 0.3, 0.6, 0.1])?;
28    let targets = Array::from_shape_vec(IxDyn(&[2, 3]), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0])?;
29    let loss = ce.forward(&predictions, &targets)?;
30    println!("Predictions (probabilities):");
31    println!("{predictions:?}");
32    println!("Targets (one-hot):");
33    println!("{targets:?}");
34    println!("Cross-Entropy Loss: {loss:.4}");
35    let gradients = ce.backward(&predictions, &targets)?;
36    println!("Cross-Entropy Gradients:");
37    println!("{gradients:?}");
38    // Focal Loss example
39    println!("\n--- Focal Loss Example ---");
40    let focal = FocalLoss::new(2.0, Some(0.25), 1e-10);
41    // Create sample data for imbalanced classification
42    let loss = focal.forward(&predictions, &targets)?;
43    println!("Focal Loss (gamma=2.0, alpha=0.25): {loss:.4}");
44    let _gradients = focal.backward(&predictions, &targets)?;
45    println!("Focal Loss Gradients:");
46    // Contrastive Loss example
47    println!("\n--- Contrastive Loss Example ---");
48    let contrastive = ContrastiveLoss::new(1.0);
49    // Create sample data for similarity learning
50    // Embedding pairs (batch_size x 2 x embedding_dim)
51    let embeddings = Array::from_shape_vec(
52        IxDyn(&[2, 2, 3]),
53        vec![
54            0.1, 0.2, 0.3, // First pair, first embedding
55            0.1, 0.3, 0.3, // First pair, second embedding (similar)
56            0.5, 0.5, 0.5, // Second pair, first embedding
57            0.9, 0.8, 0.7, // Second pair, second embedding (dissimilar)
58        ],
59    )?;
60    // Labels: 1 for similar pairs, 0 for dissimilar
61    let labels = Array::from_shape_vec(IxDyn(&[2, 1]), vec![1.0, 0.0])?;
62    let loss = contrastive.forward(&embeddings, &labels)?;
63    println!("Embeddings (batch_size x 2 x embedding_dim):");
64    println!("{embeddings:?}");
65    println!("Labels (1 for similar, 0 for dissimilar):");
66    println!("{labels:?}");
67    println!("Contrastive Loss (margin=1.0): {loss:.4}");
68    let gradients = contrastive.backward(&embeddings, &labels)?;
69    println!("Contrastive Loss Gradients (first few):");
70    let gradient_slice = gradients.slice(scirs2_core::ndarray::s![0, .., 0]);
71    println!("{gradient_slice:?}");
72    // Triplet Loss example
73    println!("\n--- Triplet Loss Example ---");
74    let triplet = TripletLoss::new(0.5);
75    // Create sample data for triplet learning
76    // Embedding triplets (batch_size x 3 x embedding_dim)
77    let triplet_embeddings = Array3::from_shape_vec(
78        (2, 3, 3),
79        vec![
80            0.1, 0.2, 0.3, // First triplet, anchor
81            0.1, 0.3, 0.3, // First triplet, positive
82            0.5, 0.5, 0.5, // First triplet, negative
83            0.6, 0.6, 0.6, // Second triplet, anchor
84            0.5, 0.6, 0.6, // Second triplet, positive
85            0.1, 0.1, 0.1, // Second triplet, negative
86        ],
87    )?
88    .into_dyn();
89    // Dummy labels (not used by triplet loss)
90    let dummy_labels = Array::zeros(IxDyn(&[2, 1]));
91    let loss = triplet.forward(&triplet_embeddings, &dummy_labels)?;
92    println!("Embeddings (batch_size x 3 x embedding_dim):");
93    println!("  - First dimension: batch size");
94    println!("  - Second dimension: [anchor, positive, negative]");
95    println!("  - Third dimension: embedding components");
96    println!("Triplet Loss (margin=0.5): {loss:.4}");
97    let _gradients = triplet.backward(&triplet_embeddings, &dummy_labels)?;
98    println!("Triplet Loss Gradients (first few):");
99    Ok(())
100}