pub struct CrossEntropyLoss { /* private fields */ }
Expand description
Cross-entropy loss function.
The cross-entropy loss is defined as `L = -sum(y_true * log(y_pred))`, where `y_true` are the target probabilities and `y_pred` are the predicted probabilities. It is commonly used for classification problems.
§Examples
use scirs2_neural::losses::CrossEntropyLoss;
use scirs2_neural::losses::Loss;
use scirs2_core::ndarray::{Array, arr1, arr2};
let ce = CrossEntropyLoss::new(1e-10);
// One-hot encoded targets and softmax'd predictions for a 3-class problem
let predictions = arr2(&[
[0.7, 0.2, 0.1], // First sample, class probabilities
[0.3, 0.6, 0.1] // Second sample, class probabilities
]).into_dyn();
let targets = arr2(&[
[1.0, 0.0, 0.0], // First sample, true class is 0
[0.0, 1.0, 0.0] // Second sample, true class is 1
]).into_dyn();
// Forward pass to calculate loss
let loss = ce.forward(&predictions, &targets).unwrap();
// Backward pass to calculate gradients
let gradients = ce.backward(&predictions, &targets).unwrap();
Implementations§
Source§impl CrossEntropyLoss
impl CrossEntropyLoss
Source§pub fn new(epsilon: f64) -> Self
pub fn new(epsilon: f64) -> Self
Create a new cross-entropy loss function
§Arguments
- `epsilon` - Small value to add to predictions to avoid log(0)
Examples found in repository?
examples/loss_functions_example.rs (line 25)
7 fn main() -> Result<(), Box<dyn std::error::Error>> {
8 println!("Loss functions example");
9 // Mean Squared Error example
10 println!("\n--- Mean Squared Error Example ---");
11 let mse = MeanSquaredError::new();
12 // Create sample data for regression
13 let predictions = Array::from_vec(vec![1.0, 2.0, 3.0]).into_dyn();
14 let targets = Array::from_vec(vec![1.5, 1.8, 2.5]).into_dyn();
15 // Calculate loss
16 let loss = mse.forward(&predictions, &targets)?;
17 println!("Predictions: {predictions:?}");
18 println!("Targets: {targets:?}");
19 println!("MSE Loss: {loss:.4}");
20 // Calculate gradients
21 let gradients = mse.backward(&predictions, &targets)?;
22 println!("MSE Gradients: {gradients:?}");
23 // Cross-Entropy Loss example
24 println!("\n--- Cross-Entropy Loss Example ---");
25 let ce = CrossEntropyLoss::new(1e-10);
26 // Create sample data for multi-class classification
27 let predictions = Array::from_shape_vec(IxDyn(&[2, 3]), vec![0.7, 0.2, 0.1, 0.3, 0.6, 0.1])?;
28 let targets = Array::from_shape_vec(IxDyn(&[2, 3]), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0])?;
29 let loss = ce.forward(&predictions, &targets)?;
30 println!("Predictions (probabilities):");
31 println!("{predictions:?}");
32 println!("Targets (one-hot):");
33 println!("{targets:?}");
34 println!("Cross-Entropy Loss: {loss:.4}");
35 let gradients = ce.backward(&predictions, &targets)?;
36 println!("Cross-Entropy Gradients:");
37 println!("{gradients:?}");
38 // Focal Loss example
39 println!("\n--- Focal Loss Example ---");
40 let focal = FocalLoss::new(2.0, Some(0.25), 1e-10);
41 // Create sample data for imbalanced classification
42 let loss = focal.forward(&predictions, &targets)?;
43 println!("Focal Loss (gamma=2.0, alpha=0.25): {loss:.4}");
44 let _gradients = focal.backward(&predictions, &targets)?;
45 println!("Focal Loss Gradients:");
46 // Contrastive Loss example
47 println!("\n--- Contrastive Loss Example ---");
48 let contrastive = ContrastiveLoss::new(1.0);
49 // Create sample data for similarity learning
50 // Embedding pairs (batch_size x 2 x embedding_dim)
51 let embeddings = Array::from_shape_vec(
52 IxDyn(&[2, 2, 3]),
53 vec![
54 0.1, 0.2, 0.3, // First pair, first embedding
55 0.1, 0.3, 0.3, // First pair, second embedding (similar)
56 0.5, 0.5, 0.5, // Second pair, first embedding
57 0.9, 0.8, 0.7, // Second pair, second embedding (dissimilar)
58 ],
59 )?;
60 // Labels: 1 for similar pairs, 0 for dissimilar
61 let labels = Array::from_shape_vec(IxDyn(&[2, 1]), vec![1.0, 0.0])?;
62 let loss = contrastive.forward(&embeddings, &labels)?;
63 println!("Embeddings (batch_size x 2 x embedding_dim):");
64 println!("{embeddings:?}");
65 println!("Labels (1 for similar, 0 for dissimilar):");
66 println!("{labels:?}");
67 println!("Contrastive Loss (margin=1.0): {loss:.4}");
68 let gradients = contrastive.backward(&embeddings, &labels)?;
69 println!("Contrastive Loss Gradients (first few):");
70 let gradient_slice = gradients.slice(scirs2_core::ndarray::s![0, .., 0]);
71 println!("{gradient_slice:?}");
72 // Triplet Loss example
73 println!("\n--- Triplet Loss Example ---");
74 let triplet = TripletLoss::new(0.5);
75 // Create sample data for triplet learning
76 // Embedding triplets (batch_size x 3 x embedding_dim)
77 let triplet_embeddings = Array3::from_shape_vec(
78 (2, 3, 3),
79 vec![
80 0.1, 0.2, 0.3, // First triplet, anchor
81 0.1, 0.3, 0.3, // First triplet, positive
82 0.5, 0.5, 0.5, // First triplet, negative
83 0.6, 0.6, 0.6, // Second triplet, anchor
84 0.5, 0.6, 0.6, // Second triplet, positive
85 0.1, 0.1, 0.1, // Second triplet, negative
86 ],
87 )?
88 .into_dyn();
89 // Dummy labels (not used by triplet loss)
90 let dummy_labels = Array::zeros(IxDyn(&[2, 1]));
91 let loss = triplet.forward(&triplet_embeddings, &dummy_labels)?;
92 println!("Embeddings (batch_size x 3 x embedding_dim):");
93 println!(" - First dimension: batch size");
94 println!(" - Second dimension: [anchor, positive, negative]");
95 println!(" - Third dimension: embedding components");
96 println!("Triplet Loss (margin=0.5): {loss:.4}");
97 let _gradients = triplet.backward(&triplet_embeddings, &dummy_labels)?;
98 println!("Triplet Loss Gradients (first few):");
99 Ok(())
100 }
Trait Implementations§
Source§impl Clone for CrossEntropyLoss
impl Clone for CrossEntropyLoss
Source§fn clone(&self) -> CrossEntropyLoss
fn clone(&self) -> CrossEntropyLoss
Returns a duplicate of the value. Read more
1.0.0 · Source§fn clone_from(&mut self, source: &Self)
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from `source`. Read more
Source§impl Debug for CrossEntropyLoss
impl Debug for CrossEntropyLoss
Source§impl Default for CrossEntropyLoss
impl Default for CrossEntropyLoss
Source§impl<F: Float + Debug> Loss<F> for CrossEntropyLoss
impl<F: Float + Debug> Loss<F> for CrossEntropyLoss
impl Copy for CrossEntropyLoss
Auto Trait Implementations§
impl Freeze for CrossEntropyLoss
impl RefUnwindSafe for CrossEntropyLoss
impl Send for CrossEntropyLoss
impl Sync for CrossEntropyLoss
impl Unpin for CrossEntropyLoss
impl UnwindSafe for CrossEntropyLoss
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T where
T: ?Sized,
impl<T> BorrowMut<T> for T where
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more
Source§impl<T> CloneToUninit for T where
T: Clone,
impl<T> CloneToUninit for T where
T: Clone,
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left` is `true`. Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left(&self)` returns `true`. Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more