ruvector_attention/training/loss.rs

//! Loss functions for attention-based learning
//!
//! Includes contrastive losses optimized for representation learning.

/// Reduction method for loss computation
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub enum Reduction {
    #[default]
    Mean,
    Sum,
    None,
}

/// Loss trait for attention training
pub trait Loss: Send + Sync {
    /// Compute the loss value
    fn compute(&self, anchor: &[f32], positive: &[f32], negatives: &[&[f32]]) -> f32;

    /// Compute the loss together with its gradient with respect to the anchor
    fn compute_with_gradients(
        &self,
        anchor: &[f32],
        positive: &[f32],
        negatives: &[&[f32]],
    ) -> (f32, Vec<f32>);
}

/// InfoNCE contrastive loss
///
/// L = -log( exp(sim(a,p)/τ) / (exp(sim(a,p)/τ) + Σ_i exp(sim(a,n_i)/τ)) )
///
/// where `sim` is cosine similarity and τ is the temperature. The positive
/// pair is included in the denominator, matching `compute` below.
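///
/// # Example
///
/// A minimal usage sketch (the module path is assumed from this file's
/// location and may differ in the actual crate):
///
/// ```ignore
/// use ruvector_attention::training::loss::{InfoNCELoss, Loss};
///
/// let loss = InfoNCELoss::new(0.07);
/// let anchor = vec![1.0f32, 0.0];
/// let positive = vec![0.9f32, 0.1];
/// let negatives: Vec<Vec<f32>> = vec![vec![0.0, 1.0]];
/// let neg_refs: Vec<&[f32]> = negatives.iter().map(|n| n.as_slice()).collect();
/// let value = loss.compute(&anchor, &positive, &neg_refs);
/// assert!(value >= 0.0);
/// ```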
pub struct InfoNCELoss {
    temperature: f32,
}

impl InfoNCELoss {
    pub fn new(temperature: f32) -> Self {
        Self {
            // Clamp to avoid dividing by a zero or near-zero temperature
            temperature: temperature.max(0.01),
        }
    }

    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
        dot / (norm_a * norm_b)
    }
}

impl Loss for InfoNCELoss {
    fn compute(&self, anchor: &[f32], positive: &[f32], negatives: &[&[f32]]) -> f32 {
        let pos_sim = Self::cosine_similarity(anchor, positive) / self.temperature;

        let neg_sims: Vec<f32> = negatives
            .iter()
            .map(|n| Self::cosine_similarity(anchor, n) / self.temperature)
            .collect();

        // Stable log-sum-exp over the positive and all negatives
        let max_sim = neg_sims
            .iter()
            .copied()
            .chain(std::iter::once(pos_sim))
            .fold(f32::NEG_INFINITY, f32::max);

        let sum_exp: f32 =
            neg_sims.iter().map(|s| (s - max_sim).exp()).sum::<f32>() + (pos_sim - max_sim).exp();

        let log_sum_exp = max_sim + sum_exp.ln();

        // Equivalent to -log softmax(pos_sim) over the positive and negatives
        log_sum_exp - pos_sim
    }

    fn compute_with_gradients(
        &self,
        anchor: &[f32],
        positive: &[f32],
        negatives: &[&[f32]],
    ) -> (f32, Vec<f32>) {
        let dim = anchor.len();
        let pos_sim = Self::cosine_similarity(anchor, positive) / self.temperature;

        let neg_sims: Vec<f32> = negatives
            .iter()
            .map(|n| Self::cosine_similarity(anchor, n) / self.temperature)
            .collect();

        // Compute softmax weights over the positive and all negatives
        let max_sim = neg_sims
            .iter()
            .copied()
            .chain(std::iter::once(pos_sim))
            .fold(f32::NEG_INFINITY, f32::max);

        let pos_exp = (pos_sim - max_sim).exp();
        let neg_exps: Vec<f32> = neg_sims.iter().map(|s| (s - max_sim).exp()).collect();
        let total_exp: f32 = pos_exp + neg_exps.iter().sum::<f32>();

        let pos_weight = pos_exp / total_exp;
        let neg_weights: Vec<f32> = neg_exps.iter().map(|e| e / total_exp).collect();

        // Loss value: -log softmax(pos_sim), identical to `compute`
        let loss = -(pos_weight.ln());

        // Gradient with respect to the anchor (similarities are cosine / τ):
        // ∂L/∂anchor = (p_pos - 1) * ∂sim(a,p)/∂a + Σ p_neg_i * ∂sim(a,n_i)/∂a
        // (checked numerically by the finite-difference test below)
        let norm_a: f32 = anchor.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
        let norm_p: f32 = positive.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);

        let mut gradients = vec![0.0f32; dim];

        // Gradient of the cosine term from the positive
        let dot_ap: f32 = anchor.iter().zip(positive.iter()).map(|(a, p)| a * p).sum();
        for i in 0..dim {
            let d_sim = (positive[i] / (norm_a * norm_p))
                - (anchor[i] * dot_ap / (norm_a.powi(3) * norm_p));
            gradients[i] += (pos_weight - 1.0) * d_sim / self.temperature;
        }

        // Gradient of the cosine terms from the negatives
        for (neg, &weight) in negatives.iter().zip(neg_weights.iter()) {
            let norm_n: f32 = neg.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
            let dot_an: f32 = anchor.iter().zip(neg.iter()).map(|(a, n)| a * n).sum();

            for i in 0..dim {
                let d_sim =
                    (neg[i] / (norm_a * norm_n)) - (anchor[i] * dot_an / (norm_a.powi(3) * norm_n));
                gradients[i] += weight * d_sim / self.temperature;
            }
        }

        (loss, gradients)
    }
}

/// Local contrastive loss for neighborhood preservation
///
/// A triplet-style hinge loss on Euclidean distances:
/// L = max(0, d(a,p) - d(a,n) + margin)
pub struct LocalContrastiveLoss {
    margin: f32,
    reduction: Reduction,
}

impl LocalContrastiveLoss {
    pub fn new(margin: f32) -> Self {
        Self {
            margin,
            reduction: Reduction::Mean,
        }
    }

    pub fn with_reduction(mut self, reduction: Reduction) -> Self {
        self.reduction = reduction;
        self
    }

    fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f32>()
            .sqrt()
    }
}

impl Loss for LocalContrastiveLoss {
    fn compute(&self, anchor: &[f32], positive: &[f32], negatives: &[&[f32]]) -> f32 {
        let d_pos = Self::euclidean_distance(anchor, positive);

        let losses: Vec<f32> = negatives
            .iter()
            .map(|neg| {
                let d_neg = Self::euclidean_distance(anchor, neg);
                (d_pos - d_neg + self.margin).max(0.0)
            })
            .collect();

        match self.reduction {
            Reduction::Mean => losses.iter().sum::<f32>() / losses.len().max(1) as f32,
            Reduction::Sum => losses.iter().sum(),
            // The scalar API cannot return per-negative losses, so `None`
            // falls back to the first negative's hinge loss.
            Reduction::None => losses.first().copied().unwrap_or(0.0),
        }
    }

    fn compute_with_gradients(
        &self,
        anchor: &[f32],
        positive: &[f32],
        negatives: &[&[f32]],
    ) -> (f32, Vec<f32>) {
        let dim = anchor.len();
        let d_pos = Self::euclidean_distance(anchor, positive);

        let mut total_loss = 0.0f32;
        let mut gradients = vec![0.0f32; dim];
        let mut active_count = 0;

        for neg in negatives.iter() {
            let d_neg = Self::euclidean_distance(anchor, neg);
            let margin_loss = d_pos - d_neg + self.margin;

            if margin_loss > 0.0 {
                total_loss += margin_loss;
                active_count += 1;

                // Gradient of an active hinge: ∂L/∂a = (a - p)/d_pos - (a - n)/d_neg
                for i in 0..dim {
                    if d_pos > 1e-8 {
                        gradients[i] += (anchor[i] - positive[i]) / d_pos;
                    }
                    if d_neg > 1e-8 {
                        gradients[i] -= (anchor[i] - neg[i]) / d_neg;
                    }
                }
            }
        }

        // Normalize so the returned value matches `compute`: Mean averages
        // over all negatives (inactive hinges contribute zero), Sum keeps the
        // total, and any other case falls back to the per-negative average.
        let loss = match self.reduction {
            Reduction::Mean if active_count > 0 => {
                let n = negatives.len() as f32;
                gradients.iter_mut().for_each(|g| *g /= n);
                total_loss / n
            }
            Reduction::Sum => total_loss,
            _ => total_loss / negatives.len().max(1) as f32,
        };

        (loss, gradients)
    }
}

/// Spectral regularization for smooth representations
pub struct SpectralRegularization {
    weight: f32,
}

impl SpectralRegularization {
    pub fn new(weight: f32) -> Self {
        Self { weight }
    }

    /// Compute the regularization penalty for a batch of embeddings.
    ///
    /// Uses a diagonal approximation of the covariance matrix: the penalty is
    /// the variance of the per-dimension variances, which encourages the
    /// embedding dimensions to carry a similar amount of variance.
    pub fn compute_batch(&self, embeddings: &[&[f32]]) -> f32 {
        if embeddings.is_empty() {
            return 0.0;
        }

        let dim = embeddings[0].len();
        let n = embeddings.len() as f32;

        // Per-dimension variances (diagonal of the covariance matrix)
        let variances: Vec<f32> = (0..dim)
            .map(|d| {
                let mean: f32 = embeddings.iter().map(|e| e[d]).sum::<f32>() / n;
                embeddings.iter().map(|e| (e[d] - mean).powi(2)).sum::<f32>() / n
            })
            .collect();

        // Regularization: encourage uniform variance across dimensions
        let avg_var = variances.iter().sum::<f32>() / dim as f32;
        let var_of_var: f32 =
            variances.iter().map(|v| (v - avg_var).powi(2)).sum::<f32>() / dim as f32;

        self.weight * var_of_var
    }
}

impl Loss for SpectralRegularization {
    fn compute(&self, anchor: &[f32], positive: &[f32], negatives: &[&[f32]]) -> f32 {
        let mut all_embeddings: Vec<&[f32]> = Vec::with_capacity(2 + negatives.len());
        all_embeddings.push(anchor);
        all_embeddings.push(positive);
        all_embeddings.extend(negatives.iter().copied());

        self.compute_batch(&all_embeddings)
    }

    fn compute_with_gradients(
        &self,
        anchor: &[f32],
        positive: &[f32],
        negatives: &[&[f32]],
    ) -> (f32, Vec<f32>) {
        let loss = self.compute(anchor, positive, negatives);
        // Simplified: no gradient for the spectral term (typically used as an
        // auxiliary regularizer)
        let gradients = vec![0.0f32; anchor.len()];
        (loss, gradients)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_infonce_loss() {
        let loss = InfoNCELoss::new(0.07);

        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0];
        let negatives: Vec<Vec<f32>> = vec![vec![0.0, 1.0, 0.0], vec![0.0, 0.0, 1.0]];
        let neg_refs: Vec<&[f32]> = negatives.iter().map(|n| n.as_slice()).collect();

        let loss_val = loss.compute(&anchor, &positive, &neg_refs);
        assert!(loss_val >= 0.0);
    }

    #[test]
    fn test_infonce_gradients() {
        let loss = InfoNCELoss::new(0.1);

        let anchor = vec![0.5; 64];
        let positive = vec![0.6; 64];
        let negatives: Vec<Vec<f32>> = vec![vec![0.1; 64]; 5];
        let neg_refs: Vec<&[f32]> = negatives.iter().map(|n| n.as_slice()).collect();

        let (loss_val, grads) = loss.compute_with_gradients(&anchor, &positive, &neg_refs);

        assert!(loss_val >= 0.0);
        assert_eq!(grads.len(), 64);
    }
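
    // Numerically checks the analytic anchor gradient from
    // `compute_with_gradients` against a central finite difference of
    // `compute` (sanity check with a loose tolerance for f32 rounding).
    #[test]
    fn test_infonce_gradient_matches_finite_difference() {
        let loss = InfoNCELoss::new(0.5);

        let anchor = vec![0.6, -0.2, 0.3, 0.5];
        let positive = vec![0.1, 0.4, -0.3, 0.6];
        let negatives: Vec<Vec<f32>> =
            vec![vec![-0.4, 0.2, 0.5, 0.1], vec![0.3, 0.7, 0.2, -0.5]];
        let neg_refs: Vec<&[f32]> = negatives.iter().map(|n| n.as_slice()).collect();

        let (_, grads) = loss.compute_with_gradients(&anchor, &positive, &neg_refs);

        let eps = 1e-2f32;
        for i in 0..anchor.len() {
            let mut plus = anchor.clone();
            let mut minus = anchor.clone();
            plus[i] += eps;
            minus[i] -= eps;
            let numeric = (loss.compute(&plus, &positive, &neg_refs)
                - loss.compute(&minus, &positive, &neg_refs))
                / (2.0 * eps);
            assert!((numeric - grads[i]).abs() < 1e-2);
        }
    }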

    #[test]
    fn test_local_contrastive() {
        let loss = LocalContrastiveLoss::new(1.0);

        let anchor = vec![0.0, 0.0];
        let positive = vec![0.1, 0.0]; // Close
        let negatives: Vec<Vec<f32>> = vec![vec![2.0, 0.0], vec![0.0, 2.0]]; // Far
        let neg_refs: Vec<&[f32]> = negatives.iter().map(|n| n.as_slice()).collect();

        let loss_val = loss.compute(&anchor, &positive, &neg_refs);
        assert!(loss_val >= 0.0);
    }
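
    // Checks the hinge behavior directly: a negative far beyond the margin
    // gives zero loss, while a negative closer than the positive gives a
    // positive loss.
    #[test]
    fn test_local_contrastive_hinge() {
        let loss = LocalContrastiveLoss::new(0.5);

        let anchor = vec![0.0, 0.0];
        let positive = vec![0.1, 0.0];

        let far: Vec<Vec<f32>> = vec![vec![5.0, 0.0]];
        let far_refs: Vec<&[f32]> = far.iter().map(|n| n.as_slice()).collect();
        let far_loss = loss.compute(&anchor, &positive, &far_refs);
        assert!(far_loss.abs() < f32::EPSILON);

        let near: Vec<Vec<f32>> = vec![vec![0.05, 0.0]];
        let near_refs: Vec<&[f32]> = near.iter().map(|n| n.as_slice()).collect();
        assert!(loss.compute(&anchor, &positive, &near_refs) > 0.0);
    }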

    #[test]
    fn test_spectral_regularization() {
        let reg = SpectralRegularization::new(0.01);

        let embeddings: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32 * 0.1; 32]).collect();
        let emb_refs: Vec<&[f32]> = embeddings.iter().map(|e| e.as_slice()).collect();

        let loss_val = reg.compute_batch(&emb_refs);
        assert!(loss_val >= 0.0);
    }
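
    // Complements the test above with a batch whose dimensions have unequal
    // variances, which should receive a strictly positive penalty.
    #[test]
    fn test_spectral_regularization_nonuniform_variance() {
        let reg = SpectralRegularization::new(1.0);

        // Dimension 0 varies across the batch, dimension 1 is constant.
        let embeddings: Vec<Vec<f32>> = vec![vec![0.0, 1.0], vec![2.0, 1.0], vec![4.0, 1.0]];
        let emb_refs: Vec<&[f32]> = embeddings.iter().map(|e| e.as_slice()).collect();

        assert!(reg.compute_batch(&emb_refs) > 0.0);
    }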
}