use syntaxdot_tch_ext::tensor::SumDim;
use tch::{Kind, Reduction, Tensor};

use crate::TransformerError;

trait Reduce {
    type Error;

    fn reduce(&self, t: &Tensor) -> Result<Tensor, Self::Error>;
}

impl Reduce for Reduction {
    type Error = TransformerError;

    fn reduce(&self, t: &Tensor) -> Result<Tensor, Self::Error> {
        match self {
            Reduction::None => Ok(t.shallow_clone()),
            Reduction::Mean => Ok(t.f_mean(t.kind())?),
            Reduction::Sum => Ok(t.f_sum(t.kind())?),
            Reduction::Other(_) => unimplemented!(),
        }
    }
}

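/// Cross-entropy loss with optional label smoothing.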
pub struct CrossEntropyLoss {
    ignore_index: i64,
    label_smoothing: Option<f64>,
    reduction: Reduction,
}

impl CrossEntropyLoss {
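    /// Construct a cross-entropy loss.
    ///
    /// Targets equal to `ignore_index` are excluded from the loss. When
    /// `label_smoothing` is `Some(p)`, probability mass `p` is distributed
    /// over the non-target classes and `1 - p` is assigned to the target
    /// class. `reduction` determines how the per-token losses are combined.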
    pub fn new(ignore_index: i64, label_smoothing: Option<f64>, reduction: Reduction) -> Self {
        CrossEntropyLoss {
            ignore_index,
            label_smoothing,
            reduction,
        }
    }

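    /// Compute the cross-entropy between the given logits and target labels.
    ///
    /// `logits` must have shape `[batch_size, n_classes]` and `targets`
    /// shape `[batch_size]`. When label smoothing is enabled, `target_mask`
    /// can be used to restrict the smoothing probability to the classes
    /// marked in the mask; without label smoothing the mask is ignored.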
    pub fn forward(
        &self,
        logits: &Tensor,
        targets: &Tensor,
        target_mask: Option<&Tensor>,
    ) -> Result<Tensor, TransformerError> {
        let (_, n_classes) = logits.size2()?;
        let log_probs = logits.f_log_softmax(-1, logits.kind())?;

        match self.label_smoothing {
            Some(label_smoothing) => {
                let token_mask = targets.f_ne(self.ignore_index)?;

                // Replace ignored targets by class 0, so that they can be
                // used as scatter indices.
                let targets_non_negative =
                    targets.f_where_scalarother(&targets.f_ne(self.ignore_index)?, 0)?;

                let smoothed_targets = tch::no_grad(|| match target_mask {
                    // Without a target mask, distribute the smoothing
                    // probability over all non-target classes.
                    None => {
                        Tensor::f_full_like(&log_probs, label_smoothing / (n_classes - 1) as f64)?
                            .f_scatter_value(
                                1,
                                &targets_non_negative.f_unsqueeze(1)?,
                                1. - label_smoothing,
                            )
                    }
                    Some(target_mask) => {
                        // Distribute the smoothing probability over the
                        // classes permitted by the mask, excluding the target.
                        let batch_probs = label_smoothing
                            / target_mask
                                .f_sum_dim(-1, false, Kind::Float)?
                                .f_sub_scalar(1)?;
                        Tensor::f_zeros_like(&log_probs)?
                            .f_add_(&batch_probs.f_unsqueeze(-1)?)?
                            // Zero out the classes that are masked out.
                            .f_mul(&target_mask.to_kind(Kind::Float))?
                            // Assign the remaining probability mass to the
                            // target classes.
                            .f_scatter_value(
                                1,
                                &targets_non_negative.f_unsqueeze(1)?,
                                1. - label_smoothing,
                            )
                    }
                })?;
                let losses = (smoothed_targets.f_neg()?.f_mul(&log_probs)?).f_sum_dim(
                    -1,
                    false,
                    log_probs.kind(),
                )?;

                Ok(self
                    .reduction
                    .reduce(&losses.f_masked_select(&token_mask)?)?)
            }
            None => Ok(log_probs.f_nll_loss::<&Tensor>(
                targets,
                None,
                self.reduction,
                self.ignore_index,
            )?),
        }
    }
}

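/// Normalization applied to the mean squared error loss.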
pub enum MSELossNormalization {
    /// Use the mean of the squared errors.
    Mean,

    /// Divide the squared errors of each batch item by the squared L2 norm
    /// of its target, then average over the batch.
    SquaredL2Norm,
}

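/// Mean squared error loss.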
pub struct MSELoss {
    normalization: MSELossNormalization,
}

impl MSELoss {
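    /// Construct an MSE loss with the given normalization.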
    pub fn new(normalization: MSELossNormalization) -> Self {
        MSELoss { normalization }
    }

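    /// Compute the MSE loss between `prediction` and `target`.
    ///
    /// Both tensors must have the same shape; the `SquaredL2Norm`
    /// normalization additionally requires two-dimensional
    /// `[batch_size, n_features]` tensors.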
    pub fn forward(&self, prediction: &Tensor, target: &Tensor) -> Result<Tensor, tch::TchError> {
        let reduction = match self.normalization {
            MSELossNormalization::Mean => Reduction::Mean,
            MSELossNormalization::SquaredL2Norm => Reduction::None,
        };

        let loss = prediction.f_mse_loss(target, reduction);

        match self.normalization {
            MSELossNormalization::Mean => loss,
            MSELossNormalization::SquaredL2Norm => {
                // Normalize the squared errors of each batch item by the
                // squared L2 norm of its target and average over the batch.
                let norm = target.f_frobenius_norm(&[1], true)?.f_square()?;
                let (batch_size, _) = target.size2()?;
                loss?
                    .f_div(&norm)?
                    .f_sum(Kind::Float)?
                    .f_div_scalar(batch_size)
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use std::convert::TryInto;

    use approx::assert_abs_diff_eq;
    use ndarray::{array, ArrayD};
    use tch::{Reduction, Tensor};

    use crate::loss::CrossEntropyLoss;

    use super::MSELoss;

    #[test]
    fn cross_entropy_loss_without_label_smoothing() {
        let logits = Tensor::of_slice(&[-1., -1., 1., -1., -1.]).view([1, 5]);
        let targets = Tensor::of_slice(&[2i64]).view([1]);
        let cross_entropy_loss = CrossEntropyLoss::new(-1, None, Reduction::None);
        let loss: ArrayD<f32> = (&cross_entropy_loss.forward(&logits, &targets, None).unwrap())
            .try_into()
            .unwrap();

        assert_abs_diff_eq!(loss, array![0.432653].into_dyn(), epsilon = 1e-6);
    }

    #[test]
    fn cross_entropy_with_label_smoothing() {
        let logits = Tensor::of_slice(&[-1., -1., 1., -1., -1.]).view([1, 5]);
        let targets = Tensor::of_slice(&[2i64]).view([1]);
        let cross_entropy_loss = CrossEntropyLoss::new(-1, Some(0.1), Reduction::None);
        let loss: ArrayD<f32> = (&cross_entropy_loss.forward(&logits, &targets, None).unwrap())
            .try_into()
            .unwrap();
        assert_abs_diff_eq!(loss, array![0.632653].into_dyn(), epsilon = 1e-6);
    }

    #[test]
    fn cross_entropy_with_label_smoothing_and_mask() {
        let logits = Tensor::of_slice(&[-1., -1., 1., -1., -1.]).view([1, 5]);
        let target_mask = Tensor::of_slice(&[true, false, true, false, true]).view([1, 5]);
        let targets = Tensor::of_slice(&[2i64]).view([1]);
        let cross_entropy_loss = CrossEntropyLoss::new(-1, Some(0.1), Reduction::None);
        let loss: ArrayD<f32> = (&cross_entropy_loss
            .forward(&logits, &targets, Some(&target_mask))
            .unwrap())
            .try_into()
            .unwrap();
        assert_abs_diff_eq!(loss, array![0.632653].into_dyn(), epsilon = 1e-6);
    }

    #[test]
    fn mse_loss_with_averaging() {
        let prediction = Tensor::of_slice(&[-0.5, -0.5, 0.0, 1.0]).view([1, 4]);
        let target = Tensor::of_slice(&[-1.0, 0.0, 1.0, 1.0]).view([1, 4]);
        let mse_loss = MSELoss::new(super::MSELossNormalization::Mean);
        let loss = &mse_loss.forward(&prediction, &target).unwrap();
        assert_abs_diff_eq!(f32::from(loss), 0.375f32, epsilon = 1e-6);
    }

    #[test]
    fn mse_loss_with_squared_l2_norm() {
        let prediction = Tensor::of_slice(&[-0.5, -0.5, 0.0, 1.0]).view([2, 2]);
        let target = Tensor::of_slice(&[-1.0, 0.0, 1.0, 1.0]).view([2, 2]);
        let mse_loss = MSELoss::new(super::MSELossNormalization::SquaredL2Norm);
        let loss = mse_loss.forward(&prediction, &target).unwrap();
        assert_abs_diff_eq!(f32::from(loss), 0.5, epsilon = 1e-6);
    }
}