// ruvector_attention/training/loss.rs

/// How per-negative losses are reduced to a single scalar.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub enum Reduction {
    #[default]
    Mean,
    Sum,
    None,
}

/// A contrastive loss over an anchor, a positive, and a set of negatives.
pub trait Loss: Send + Sync {
    /// Computes the scalar loss.
    fn compute(&self, anchor: &[f32], positive: &[f32], negatives: &[&[f32]]) -> f32;

    /// Computes the loss together with its gradient with respect to `anchor`.
    fn compute_with_gradients(
        &self,
        anchor: &[f32],
        positive: &[f32],
        negatives: &[&[f32]],
    ) -> (f32, Vec<f32>);
}

/// InfoNCE loss over temperature-scaled cosine similarities.
pub struct InfoNCELoss {
    temperature: f32,
}

impl InfoNCELoss {
    pub fn new(temperature: f32) -> Self {
        Self {
            // Clamp the temperature away from zero so the scaled
            // similarities stay finite.
            temperature: temperature.max(0.01),
        }
    }

    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
        dot / (norm_a * norm_b)
    }
}

impl Loss for InfoNCELoss {
    fn compute(&self, anchor: &[f32], positive: &[f32], negatives: &[&[f32]]) -> f32 {
        let pos_sim = Self::cosine_similarity(anchor, positive) / self.temperature;

        let neg_sims: Vec<f32> = negatives
            .iter()
            .map(|n| Self::cosine_similarity(anchor, n) / self.temperature)
            .collect();

        // Numerically stable log-sum-exp over the positive and all negatives.
        let max_sim = neg_sims
            .iter()
            .copied()
            .chain(std::iter::once(pos_sim))
            .fold(f32::NEG_INFINITY, f32::max);

        let sum_exp: f32 =
            neg_sims.iter().map(|s| (s - max_sim).exp()).sum::<f32>() + (pos_sim - max_sim).exp();

        let log_sum_exp = max_sim + sum_exp.ln();

        // Negative log softmax probability of the positive.
        log_sum_exp - pos_sim
    }

    fn compute_with_gradients(
        &self,
        anchor: &[f32],
        positive: &[f32],
        negatives: &[&[f32]],
    ) -> (f32, Vec<f32>) {
        let dim = anchor.len();
        let pos_sim = Self::cosine_similarity(anchor, positive) / self.temperature;

        let neg_sims: Vec<f32> = negatives
            .iter()
            .map(|n| Self::cosine_similarity(anchor, n) / self.temperature)
            .collect();

        let max_sim = neg_sims
            .iter()
            .copied()
            .chain(std::iter::once(pos_sim))
            .fold(f32::NEG_INFINITY, f32::max);

        let pos_exp = (pos_sim - max_sim).exp();
        let neg_exps: Vec<f32> = neg_sims.iter().map(|s| (s - max_sim).exp()).collect();
        let total_exp: f32 = pos_exp + neg_exps.iter().sum::<f32>();

        // Softmax weights over the positive and all negatives.
        let pos_weight = pos_exp / total_exp;
        let neg_weights: Vec<f32> = neg_exps.iter().map(|e| e / total_exp).collect();

        let loss = -(pos_weight.ln());

        let norm_a: f32 = anchor.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
        let norm_p: f32 = positive.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);

        let mut gradients = vec![0.0f32; dim];

        // Gradient of cosine similarity with respect to the anchor:
        // d cos(a, x)/d a_i = x_i/(|a||x|) - a_i * dot(a, x)/(|a|^3 |x|).
        // The positive term enters the loss with coefficient (pos_weight - 1).
        let dot_ap: f32 = anchor.iter().zip(positive.iter()).map(|(a, p)| a * p).sum();
        for i in 0..dim {
            let d_sim = (positive[i] / (norm_a * norm_p))
                - (anchor[i] * dot_ap / (norm_a.powi(3) * norm_p));
            gradients[i] += (pos_weight - 1.0) * d_sim / self.temperature;
        }

        // Each negative enters with its softmax weight.
        for (neg, &weight) in negatives.iter().zip(neg_weights.iter()) {
            let norm_n: f32 = neg.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
            let dot_an: f32 = anchor.iter().zip(neg.iter()).map(|(a, n)| a * n).sum();

            for i in 0..dim {
                let d_sim =
                    (neg[i] / (norm_a * norm_n)) - (anchor[i] * dot_an / (norm_a.powi(3) * norm_n));
                gradients[i] += weight * d_sim / self.temperature;
            }
        }

        (loss, gradients)
    }
}

/// Margin-based (triplet-style) contrastive loss on Euclidean distances.
pub struct LocalContrastiveLoss {
    margin: f32,
    reduction: Reduction,
}

impl LocalContrastiveLoss {
    pub fn new(margin: f32) -> Self {
        Self {
            margin,
            reduction: Reduction::Mean,
        }
    }

    pub fn with_reduction(mut self, reduction: Reduction) -> Self {
        self.reduction = reduction;
        self
    }

    fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f32>()
            .sqrt()
    }
}

impl Loss for LocalContrastiveLoss {
    fn compute(&self, anchor: &[f32], positive: &[f32], negatives: &[&[f32]]) -> f32 {
        let d_pos = Self::euclidean_distance(anchor, positive);

        // Hinge loss per negative: active only while the negative is not at
        // least `margin` farther from the anchor than the positive.
        let losses: Vec<f32> = negatives
            .iter()
            .map(|neg| {
                let d_neg = Self::euclidean_distance(anchor, neg);
                (d_pos - d_neg + self.margin).max(0.0)
            })
            .collect();

        match self.reduction {
            Reduction::Mean => losses.iter().sum::<f32>() / losses.len().max(1) as f32,
            Reduction::Sum => losses.iter().sum(),
            Reduction::None => losses.first().copied().unwrap_or(0.0),
        }
    }

    fn compute_with_gradients(
        &self,
        anchor: &[f32],
        positive: &[f32],
        negatives: &[&[f32]],
    ) -> (f32, Vec<f32>) {
        let dim = anchor.len();
        let d_pos = Self::euclidean_distance(anchor, positive);

        let mut total_loss = 0.0f32;
        let mut gradients = vec![0.0f32; dim];

        for neg in negatives.iter() {
            let d_neg = Self::euclidean_distance(anchor, neg);
            let margin_loss = d_pos - d_neg + self.margin;

            // Only violating (hinge-active) negatives contribute gradient.
            if margin_loss > 0.0 {
                total_loss += margin_loss;

                for i in 0..dim {
                    // d d(a, x)/d a_i = (a_i - x_i) / d(a, x), guarded
                    // against a zero distance.
                    if d_pos > 1e-8 {
                        gradients[i] += (anchor[i] - positive[i]) / d_pos;
                    }
                    if d_neg > 1e-8 {
                        gradients[i] -= (anchor[i] - neg[i]) / d_neg;
                    }
                }
            }
        }

        // Normalize to match `compute`: Mean averages over all negatives,
        // not only the hinge-active ones.
        let n = negatives.len().max(1) as f32;
        let loss = match self.reduction {
            Reduction::Mean => {
                gradients.iter_mut().for_each(|g| *g /= n);
                total_loss / n
            }
            Reduction::Sum => total_loss,
            Reduction::None => total_loss / n,
        };

        (loss, gradients)
    }
}

/// Penalizes uneven variance across embedding dimensions: the variance of
/// the per-dimension batch variances, scaled by `weight`.
pub struct SpectralRegularization {
    weight: f32,
}

impl SpectralRegularization {
    pub fn new(weight: f32) -> Self {
        Self { weight }
    }

    pub fn compute_batch(&self, embeddings: &[&[f32]]) -> f32 {
        if embeddings.is_empty() {
            return 0.0;
        }

        let dim = embeddings[0].len();
        let n = embeddings.len();

        // Per-dimension variance of the batch.
        let vars: Vec<f32> = (0..dim)
            .map(|d| {
                let mean: f32 = embeddings.iter().map(|e| e[d]).sum::<f32>() / n as f32;
                embeddings
                    .iter()
                    .map(|e| (e[d] - mean).powi(2))
                    .sum::<f32>()
                    / n as f32
            })
            .collect();

        // Variance of the per-dimension variances: zero when every dimension
        // carries the same amount of energy.
        let avg_var = vars.iter().sum::<f32>() / dim as f32;
        let var_of_var: f32 =
            vars.iter().map(|v| (v - avg_var).powi(2)).sum::<f32>() / dim as f32;

        self.weight * var_of_var
    }
}

impl Loss for SpectralRegularization {
    fn compute(&self, anchor: &[f32], positive: &[f32], negatives: &[&[f32]]) -> f32 {
        let mut all_embeddings: Vec<&[f32]> = Vec::with_capacity(2 + negatives.len());
        all_embeddings.push(anchor);
        all_embeddings.push(positive);
        all_embeddings.extend(negatives.iter().copied());

        self.compute_batch(&all_embeddings)
    }

    fn compute_with_gradients(
        &self,
        anchor: &[f32],
        positive: &[f32],
        negatives: &[&[f32]],
    ) -> (f32, Vec<f32>) {
        // Gradients are not derived for this regularizer; callers receive
        // the penalty value with a zero gradient vector.
        let loss = self.compute(anchor, positive, negatives);
        let gradients = vec![0.0f32; anchor.len()];
        (loss, gradients)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_infonce_loss() {
        let loss = InfoNCELoss::new(0.07);

        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0];
        let negatives: Vec<Vec<f32>> = vec![vec![0.0, 1.0, 0.0], vec![0.0, 0.0, 1.0]];
        let neg_refs: Vec<&[f32]> = negatives.iter().map(|n| n.as_slice()).collect();

        let loss_val = loss.compute(&anchor, &positive, &neg_refs);
        assert!(loss_val >= 0.0);
    }
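
    /// A sanity check sketched as an illustration (test name and vectors are
    /// arbitrary choices): with the negatives held fixed, a positive aligned
    /// with the anchor should incur a smaller loss than an orthogonal one,
    /// since the loss is the negative log softmax probability of the positive.
    #[test]
    fn test_infonce_prefers_aligned_positive() {
        let loss = InfoNCELoss::new(0.07);

        let anchor = vec![1.0, 0.0, 0.0];
        let negatives: Vec<Vec<f32>> = vec![vec![0.0, 0.0, 1.0]];
        let neg_refs: Vec<&[f32]> = negatives.iter().map(|n| n.as_slice()).collect();

        let aligned = loss.compute(&anchor, &[1.0, 0.0, 0.0], &neg_refs);
        let orthogonal = loss.compute(&anchor, &[0.0, 1.0, 0.0], &neg_refs);

        assert!(aligned < orthogonal);
    }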

    #[test]
    fn test_infonce_gradients() {
        let loss = InfoNCELoss::new(0.1);

        let anchor = vec![0.5; 64];
        let positive = vec![0.6; 64];
        let negatives: Vec<Vec<f32>> = vec![vec![0.1; 64]; 5];
        let neg_refs: Vec<&[f32]> = negatives.iter().map(|n| n.as_slice()).collect();

        let (loss_val, grads) = loss.compute_with_gradients(&anchor, &positive, &neg_refs);

        assert!(loss_val >= 0.0);
        assert_eq!(grads.len(), 64);
    }
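
    /// A finite-difference sketch of gradient checking: the analytic anchor
    /// gradient should agree with a central difference of `compute`. The
    /// vectors, step size, and tolerance are loose choices for f32 math,
    /// not part of any fixed contract.
    #[test]
    fn test_infonce_gradient_matches_finite_difference() {
        let loss = InfoNCELoss::new(0.5);

        let anchor = vec![0.3, -0.2, 0.8, 0.1];
        let positive = vec![0.4, -0.1, 0.7, 0.2];
        let negatives: Vec<Vec<f32>> = vec![vec![-0.5, 0.9, 0.0, 0.3]];
        let neg_refs: Vec<&[f32]> = negatives.iter().map(|n| n.as_slice()).collect();

        let (_, grads) = loss.compute_with_gradients(&anchor, &positive, &neg_refs);

        let eps = 1e-3;
        for i in 0..anchor.len() {
            let mut plus = anchor.clone();
            let mut minus = anchor.clone();
            plus[i] += eps;
            minus[i] -= eps;
            let numeric = (loss.compute(&plus, &positive, &neg_refs)
                - loss.compute(&minus, &positive, &neg_refs))
                / (2.0 * eps);
            assert!(
                (grads[i] - numeric).abs() < 1e-2,
                "dim {}: analytic {} vs numeric {}",
                i,
                grads[i],
                numeric
            );
        }
    }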

    #[test]
    fn test_local_contrastive() {
        let loss = LocalContrastiveLoss::new(1.0);

        let anchor = vec![0.0, 0.0];
        let positive = vec![0.1, 0.0];
        let negatives: Vec<Vec<f32>> = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let neg_refs: Vec<&[f32]> = negatives.iter().map(|n| n.as_slice()).collect();

        let loss_val = loss.compute(&anchor, &positive, &neg_refs);
        assert!(loss_val >= 0.0);
    }
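
    /// Illustrates the hinge with arbitrary vectors: once every negative is
    /// farther from the anchor than the positive by more than the margin,
    /// both the loss and the anchor gradient should vanish.
    #[test]
    fn test_local_contrastive_zero_when_margin_satisfied() {
        let loss = LocalContrastiveLoss::new(1.0);

        let anchor = vec![0.0, 0.0];
        let positive = vec![0.1, 0.0];
        // d_pos = 0.1, d_neg = 5.0, so 0.1 - 5.0 + 1.0 < 0: hinge inactive.
        let negatives: Vec<Vec<f32>> = vec![vec![5.0, 0.0]];
        let neg_refs: Vec<&[f32]> = negatives.iter().map(|n| n.as_slice()).collect();

        let (loss_val, grads) = loss.compute_with_gradients(&anchor, &positive, &neg_refs);
        assert_eq!(loss_val, 0.0);
        assert!(grads.iter().all(|&g| g == 0.0));
    }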

    #[test]
    fn test_spectral_regularization() {
        let reg = SpectralRegularization::new(0.01);

        let embeddings: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32 * 0.1; 32]).collect();
        let emb_refs: Vec<&[f32]> = embeddings.iter().map(|e| e.as_slice()).collect();

        let loss_val = reg.compute_batch(&emb_refs);
        assert!(loss_val >= 0.0);
    }
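
    /// Illustrates what the regularizer measures, with arbitrary batches: a
    /// batch whose dimensions all share the same variance incurs (near) zero
    /// penalty, while a batch where only one dimension varies is penalized.
    #[test]
    fn test_spectral_regularization_prefers_even_variance() {
        let reg = SpectralRegularization::new(0.01);

        // Both dimensions have identical variance: var-of-var is zero.
        let even: Vec<Vec<f32>> = (0..4).map(|i| vec![i as f32; 2]).collect();
        let even_refs: Vec<&[f32]> = even.iter().map(|e| e.as_slice()).collect();
        assert!(reg.compute_batch(&even_refs) < 1e-6);

        // Only the first dimension varies: var-of-var is positive.
        let uneven: Vec<Vec<f32>> = (0..4).map(|i| vec![i as f32, 0.0]).collect();
        let uneven_refs: Vec<&[f32]> = uneven.iter().map(|e| e.as_slice()).collect();
        assert!(reg.compute_batch(&uneven_refs) > 0.0);
    }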
}