pub struct MeanSquaredError;
Expand description
Mean squared error loss function.
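The MSE is the average of the squared differences between predictions and targets: MSE = mean((predictions − targets)²), i.e. $\mathrm{MSE} = \frac{1}{n}\sum_{i=1}^{n}(p_i - t_i)^2$, where $p_i$ and $t_i$ are the $i$-th prediction and target and $n$ is the number of elements. It is commonly used for regression problems.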
§Examples
use scirs2_neural::losses::MeanSquaredError;
use scirs2_neural::losses::Loss;
use ndarray::{Array, arr1};
let mse = MeanSquaredError::new();
let predictions = arr1(&[1.0, 2.0, 3.0]).into_dyn();
let targets = arr1(&[1.5, 1.8, 2.5]).into_dyn();
// Forward pass to calculate loss
let loss = mse.forward(&predictions, &targets).unwrap();
// Backward pass to calculate gradients
let gradients = mse.backward(&predictions, &targets).unwrap();
Implementations§
Source§impl MeanSquaredError
impl MeanSquaredError
Source§pub fn new() -> Self
pub fn new() -> Self
Create a new mean squared error loss function
Examples found in repository?
examples/loss_functions_example.rs (line 11)
7fn main() -> Result<(), Box<dyn std::error::Error>> {
8 println!("Loss functions example");
9 // Mean Squared Error example
10 println!("\n--- Mean Squared Error Example ---");
11 let mse = MeanSquaredError::new();
12 // Create sample data for regression
13 let predictions = Array::from_vec(vec![1.0, 2.0, 3.0]).into_dyn();
14 let targets = Array::from_vec(vec![1.5, 1.8, 2.5]).into_dyn();
15 // Calculate loss
16 let loss = mse.forward(&predictions, &targets)?;
17 println!("Predictions: {predictions:?}");
18 println!("Targets: {targets:?}");
19 println!("MSE Loss: {loss:.4}");
20 // Calculate gradients
21 let gradients = mse.backward(&predictions, &targets)?;
22 println!("MSE Gradients: {gradients:?}");
23 // Cross-Entropy Loss example
24 println!("\n--- Cross-Entropy Loss Example ---");
25 let ce = CrossEntropyLoss::new(1e-10);
26 // Create sample data for multi-class classification
27 let predictions = Array::from_shape_vec(IxDyn(&[2, 3]), vec![0.7, 0.2, 0.1, 0.3, 0.6, 0.1])?;
28 let targets = Array::from_shape_vec(IxDyn(&[2, 3]), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0])?;
29 let loss = ce.forward(&predictions, &targets)?;
30 println!("Predictions (probabilities):");
31 println!("{predictions:?}");
32 println!("Targets (one-hot):");
33 println!("{targets:?}");
34 println!("Cross-Entropy Loss: {loss:.4}");
35 let gradients = ce.backward(&predictions, &targets)?;
36 println!("Cross-Entropy Gradients:");
37 println!("{gradients:?}");
38 // Focal Loss example
39 println!("\n--- Focal Loss Example ---");
40 let focal = FocalLoss::new(2.0, Some(0.25), 1e-10);
41 // Create sample data for imbalanced classification
42 let loss = focal.forward(&predictions, &targets)?;
43 println!("Focal Loss (gamma=2.0, alpha=0.25): {loss:.4}");
44 let _gradients = focal.backward(&predictions, &targets)?;
45 println!("Focal Loss Gradients:");
46 // Contrastive Loss example
47 println!("\n--- Contrastive Loss Example ---");
48 let contrastive = ContrastiveLoss::new(1.0);
49 // Create sample data for similarity learning
50 // Embedding pairs (batch_size x 2 x embedding_dim)
51 let embeddings = Array::from_shape_vec(
52 IxDyn(&[2, 2, 3]),
53 vec![
54 0.1, 0.2, 0.3, // First pair, first embedding
55 0.1, 0.3, 0.3, // First pair, second embedding (similar)
56 0.5, 0.5, 0.5, // Second pair, first embedding
57 0.9, 0.8, 0.7, // Second pair, second embedding (dissimilar)
58 ],
59 )?;
60 // Labels: 1 for similar pairs, 0 for dissimilar
61 let labels = Array::from_shape_vec(IxDyn(&[2, 1]), vec![1.0, 0.0])?;
62 let loss = contrastive.forward(&embeddings, &labels)?;
63 println!("Embeddings (batch_size x 2 x embedding_dim):");
64 println!("{embeddings:?}");
65 println!("Labels (1 for similar, 0 for dissimilar):");
66 println!("{labels:?}");
67 println!("Contrastive Loss (margin=1.0): {loss:.4}");
68 let gradients = contrastive.backward(&embeddings, &labels)?;
69 println!("Contrastive Loss Gradients (first few):");
70 let gradient_slice = gradients.slice(ndarray::s![0, .., 0]);
71 println!("{gradient_slice:?}");
72 // Triplet Loss example
73 println!("\n--- Triplet Loss Example ---");
74 let triplet = TripletLoss::new(0.5);
75 // Create sample data for triplet learning
76 // Embedding triplets (batch_size x 3 x embedding_dim)
77 let triplet_embeddings = Array3::from_shape_vec(
78 (2, 3, 3),
79 vec![
80 0.1, 0.2, 0.3, // First triplet, anchor
81 0.1, 0.3, 0.3, // First triplet, positive
82 0.5, 0.5, 0.5, // First triplet, negative
83 0.6, 0.6, 0.6, // Second triplet, anchor
84 0.5, 0.6, 0.6, // Second triplet, positive
85 0.1, 0.1, 0.1, // Second triplet, negative
86 ],
87 )?
88 .into_dyn();
89 // Dummy labels (not used by triplet loss)
90 let dummy_labels = Array::zeros(IxDyn(&[2, 1]));
91 let loss = triplet.forward(&triplet_embeddings, &dummy_labels)?;
92 println!("Embeddings (batch_size x 3 x embedding_dim):");
93 println!(" - First dimension: batch size");
94 println!(" - Second dimension: [anchor, positive, negative]");
95 println!(" - Third dimension: embedding components");
96 println!("Triplet Loss (margin=0.5): {loss:.4}");
97 let _gradients = triplet.backward(&triplet_embeddings, &dummy_labels)?;
98 println!("Triplet Loss Gradients (first few):");
99 Ok(())
100}
Trait Implementations§
Source§impl Clone for MeanSquaredError
impl Clone for MeanSquaredError
Source§fn clone(&self) -> MeanSquaredError
fn clone(&self) -> MeanSquaredError
Returns a duplicate of the value. Read more
1.0.0 · Source§fn clone_from(&mut self, source: &Self)
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from `source`. Read more
Source§impl Debug for MeanSquaredError
impl Debug for MeanSquaredError
Source§impl Default for MeanSquaredError
impl Default for MeanSquaredError
Source§impl<F: Float + Debug> Loss<F> for MeanSquaredError
impl<F: Float + Debug> Loss<F> for MeanSquaredError
impl Copy for MeanSquaredError
Auto Trait Implementations§
impl Freeze for MeanSquaredError
impl RefUnwindSafe for MeanSquaredError
impl Send for MeanSquaredError
impl Sync for MeanSquaredError
impl Unpin for MeanSquaredError
impl UnwindSafe for MeanSquaredError
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T
where
    T: ?Sized,
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more
Source§impl<T> CloneToUninit for T
where
    T: Clone,
impl<T> CloneToUninit for T
where
    T: Clone,
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left` is `true`. Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left(&self)` returns `true`. Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more