syntaxdot_transformers/loss.rs

use syntaxdot_tch_ext::tensor::SumDim;
use tch::{Kind, Reduction, Tensor};

use crate::TransformerError;

/// Reduce (summarize) a loss tensor according to a `Reduction` strategy.
trait Reduce {
    type Error;

    fn reduce(&self, t: &Tensor) -> Result<Tensor, Self::Error>;
}

impl Reduce for Reduction {
    type Error = TransformerError;

    fn reduce(&self, t: &Tensor) -> Result<Tensor, Self::Error> {
        match self {
            Reduction::None => Ok(t.shallow_clone()),
            Reduction::Mean => Ok(t.f_mean(t.kind())?),
            Reduction::Sum => Ok(t.f_sum(t.kind())?),
            Reduction::Other(_) => unimplemented!(),
        }
    }
}

/// Cross-entropy loss function.
pub struct CrossEntropyLoss {
    ignore_index: i64,
    label_smoothing: Option<f64>,
    reduction: Reduction,
}

impl CrossEntropyLoss {
    /// Construct the cross-entropy loss function.
    ///
    /// Targets that have `ignore_index` as their value are not included in
    /// the loss computation. If `label_smoothing` is set to *p*, then the
    /// correct label gets probability *1-p* and the probability *p* is
    /// distributed evenly across the incorrect labels. `reduction` specifies
    /// how the per-instance losses should be reduced/summarized.
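    ///
    /// # Example
    ///
    /// A minimal construction sketch, mirroring the values used in the tests
    /// in this module (the import path assumes `loss` is a public module of
    /// this crate):
    ///
    /// ```no_run
    /// use tch::Reduction;
    ///
    /// use syntaxdot_transformers::loss::CrossEntropyLoss;
    ///
    /// // Ignore padding targets with the value -1, smooth labels with p = 0.1,
    /// // and average the losses over the batch.
    /// let loss_fn = CrossEntropyLoss::new(-1, Some(0.1), Reduction::Mean);
    /// ```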
    pub fn new(ignore_index: i64, label_smoothing: Option<f64>, reduction: Reduction) -> Self {
        CrossEntropyLoss {
            ignore_index,
            label_smoothing,
            reduction,
        }
    }

    /// Compute the cross-entropy loss.
    ///
    /// `logits` should be the unnormalized probabilities of shape
    /// `[batch_size, n_classes]` and `targets` the gold-standard labels
    /// with shape `[batch_size]`.
    ///
    /// The optional target mask has to be of shape `[batch_size, n_classes]`.
    /// If the mask is not provided, then all `n_classes` labels will be used
    /// in label smoothing.
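    ///
    /// # Example
    ///
    /// A small sketch with one instance and five classes, mirroring the tests
    /// in this module (the import path assumes `loss` is a public module of
    /// this crate):
    ///
    /// ```no_run
    /// use tch::{Reduction, Tensor};
    ///
    /// use syntaxdot_transformers::loss::CrossEntropyLoss;
    ///
    /// let logits = Tensor::of_slice(&[-1., -1., 1., -1., -1.]).view([1, 5]);
    /// let targets = Tensor::of_slice(&[2i64]).view([1]);
    /// let loss_fn = CrossEntropyLoss::new(-1, None, Reduction::None);
    ///
    /// // With `Reduction::None`, the result holds one loss per non-ignored instance.
    /// let loss = loss_fn.forward(&logits, &targets, None).unwrap();
    /// ```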
    pub fn forward(
        &self,
        logits: &Tensor,
        targets: &Tensor,
        target_mask: Option<&Tensor>,
    ) -> Result<Tensor, TransformerError> {
        let (_, n_classes) = logits.size2()?;
        let log_probs = logits.f_log_softmax(-1, logits.kind())?;

        match self.label_smoothing {
            Some(label_smoothing) => {
                let token_mask = targets.f_ne(self.ignore_index)?;

                // Do not attempt to use negative indices for the correct target.
                let targets_non_negative = targets.f_where_scalarother(&token_mask, 0)?;

                // Distribute the smoothing probability over the incorrect labels
                // and assign 1 - label_smoothing to the target label.
                let smoothed_targets = tch::no_grad(|| match target_mask {
                    None => {
                        Tensor::f_full_like(&log_probs, label_smoothing / (n_classes - 1) as f64)?
                            .f_scatter_value(
                                1,
                                &targets_non_negative.f_unsqueeze(1)?,
                                1. - label_smoothing,
                            )
                    }
                    Some(target_mask) => {
                        let batch_probs = label_smoothing
                            / target_mask
                                .f_sum_dim(-1, false, Kind::Float)?
                                .f_sub_scalar(1)?;
                        Tensor::f_zeros_like(&log_probs)?
                            // Set label probabilities to the batch smoothing probability.
                            .f_add_(&batch_probs.f_unsqueeze(-1)?)?
                            // Mask out padding.
                            .f_mul(&target_mask.to_kind(Kind::Float))?
                            // Assign probabilities to gold standard labels.
                            .f_scatter_value(
                                1,
                                &targets_non_negative.f_unsqueeze(1)?,
                                1. - label_smoothing,
                            )
                    }
                })?;

                // Per-instance cross-entropy against the smoothed target distribution.
                let losses = (smoothed_targets.f_neg()?.f_mul(&log_probs)?).f_sum_dim(
                    -1,
                    false,
                    log_probs.kind(),
                )?;

                // Exclude instances with the ignore index from the reduction.
                Ok(self.reduction.reduce(&losses.f_masked_select(&token_mask)?)?)
            }
            None => Ok(log_probs.f_nll_loss::<&Tensor>(
                targets,
                None,
                self.reduction,
                self.ignore_index,
            )?),
        }
    }
}

/// Mean squared error loss normalization.
pub enum MSELossNormalization {
    /// Take the mean of all losses.
    Mean,

    /// Normalize by the squared L2 norm of each target row.
    ///
    /// The resulting per-instance losses are averaged.
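    ///
    /// Concretely (as computed in `MSELoss::forward`), the loss for instance
    /// *i* is `||prediction_i - target_i||^2 / ||target_i||^2`, and the
    /// reported loss is the mean of these values over the batch.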
    SquaredL2Norm,
}

/// Mean squared error loss.
pub struct MSELoss {
    normalization: MSELossNormalization,
}

impl MSELoss {
    /// Construct the mean squared error loss function.
    pub fn new(normalization: MSELossNormalization) -> Self {
        MSELoss { normalization }
    }

    /// Compute the mean squared error loss.
    ///
    /// Computes the loss of `prediction` given `target`. Both tensors must
    /// be two-dimensional and have the same shape. The loss is returned as
    /// a scalar.
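    ///
    /// # Example
    ///
    /// A minimal sketch, mirroring the tests in this module (the import path
    /// assumes `loss` is a public module of this crate):
    ///
    /// ```no_run
    /// use tch::Tensor;
    ///
    /// use syntaxdot_transformers::loss::{MSELoss, MSELossNormalization};
    ///
    /// let prediction = Tensor::of_slice(&[-0.5, -0.5, 0.0, 1.0]).view([1, 4]);
    /// let target = Tensor::of_slice(&[-1.0, 0.0, 1.0, 1.0]).view([1, 4]);
    ///
    /// let mse_loss = MSELoss::new(MSELossNormalization::Mean);
    /// // Scalar tensor holding the mean of the element-wise squared errors (0.375).
    /// let loss = mse_loss.forward(&prediction, &target).unwrap();
    /// ```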
    pub fn forward(&self, prediction: &Tensor, target: &Tensor) -> Result<Tensor, tch::TchError> {
        let reduction = match self.normalization {
            MSELossNormalization::Mean => Reduction::Mean,
            MSELossNormalization::SquaredL2Norm => Reduction::None,
        };

        let loss = prediction.f_mse_loss(target, reduction);

        match self.normalization {
            MSELossNormalization::Mean => loss,
            MSELossNormalization::SquaredL2Norm => {
                // Squared L2 norm of each target row.
                let norm = target.f_frobenius_norm(&[1], true)?.f_square()?;
                let (batch_size, _) = target.size2()?;
                // Normalize the per-element losses, sum, and average over the batch.
                loss?
                    .f_div(&norm)?
                    .f_sum(Kind::Float)?
                    .f_div_scalar(batch_size)
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use std::convert::TryInto;

    use approx::assert_abs_diff_eq;
    use ndarray::{array, ArrayD};
    use tch::{Reduction, Tensor};

    use crate::loss::CrossEntropyLoss;

    use super::MSELoss;

    #[test]
    fn cross_entropy_loss_without_label_smoothing() {
        let logits = Tensor::of_slice(&[-1., -1., 1., -1., -1.]).view([1, 5]);
        let targets = Tensor::of_slice(&[2i64]).view([1]);
        let cross_entropy_loss = CrossEntropyLoss::new(-1, None, Reduction::None);
        let loss: ArrayD<f32> = (&cross_entropy_loss.forward(&logits, &targets, None).unwrap())
            .try_into()
            .unwrap();

        assert_abs_diff_eq!(loss, array![0.432653].into_dyn(), epsilon = 1e-6);
    }

    #[test]
    fn cross_entropy_with_label_smoothing() {
        let logits = Tensor::of_slice(&[-1., -1., 1., -1., -1.]).view([1, 5]);
        let targets = Tensor::of_slice(&[2i64]).view([1]);
        let cross_entropy_loss = CrossEntropyLoss::new(-1, Some(0.1), Reduction::None);
        let loss: ArrayD<f32> = (&cross_entropy_loss.forward(&logits, &targets, None).unwrap())
            .try_into()
            .unwrap();
        assert_abs_diff_eq!(loss, array![0.632653].into_dyn(), epsilon = 1e-6);
    }

    #[test]
    fn cross_entropy_with_label_smoothing_and_mask() {
        let logits = Tensor::of_slice(&[-1., -1., 1., -1., -1.]).view([1, 5]);
        let target_mask = Tensor::of_slice(&[true, false, true, false, true]).view([1, 5]);
        let targets = Tensor::of_slice(&[2i64]).view([1]);
        let cross_entropy_loss = CrossEntropyLoss::new(-1, Some(0.1), Reduction::None);
        let loss: ArrayD<f32> = (&cross_entropy_loss
            .forward(&logits, &targets, Some(&target_mask))
            .unwrap())
            .try_into()
            .unwrap();
        assert_abs_diff_eq!(loss, array![0.632653].into_dyn(), epsilon = 1e-6);
    }

    #[test]
    fn mse_loss_with_averaging() {
        let prediction = Tensor::of_slice(&[-0.5, -0.5, 0.0, 1.0]).view([1, 4]);
        let target = Tensor::of_slice(&[-1.0, 0.0, 1.0, 1.0]).view([1, 4]);
        let mse_loss = MSELoss::new(super::MSELossNormalization::Mean);
        let loss = &mse_loss.forward(&prediction, &target).unwrap();
        assert_abs_diff_eq!(f32::from(loss), 0.375f32, epsilon = 1e-6);
    }

    #[test]
    fn mse_loss_with_squared_l2_norm() {
        let prediction = Tensor::of_slice(&[-0.5, -0.5, 0.0, 1.0]).view([2, 2]);
        let target = Tensor::of_slice(&[-1.0, 0.0, 1.0, 1.0]).view([2, 2]);
        let mse_loss = MSELoss::new(super::MSELossNormalization::SquaredL2Norm);
        let loss = mse_loss.forward(&prediction, &target).unwrap();
        assert_abs_diff_eq!(f32::from(loss), 0.5, epsilon = 1e-6);
    }
}