ruvector_attention/training/
mining.rs

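//! Negative-sample mining strategies.
//!
//! Provides methods for selecting informative negative samples for an anchor:
//! random selection, hardest (most similar) negatives, semi-hard (margin-band)
//! negatives, softmax-weighted sampling, and in-batch negatives.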
4
/// Strategy used to choose negatives from a candidate pool.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum MiningStrategy {
    /// Uniformly random candidates (the default).
    #[default]
    Random,
    /// The candidates most cosine-similar to the anchor.
    HardNegative,
    /// Candidates whose Euclidean distance to the anchor lies strictly between
    /// `d(anchor, positive)` and `d(anchor, positive) + margin`.
    SemiHard,
    /// Candidates sampled without replacement with weights
    /// `softmax(cosine_similarity / temperature)`.
    DistanceWeighted,
}
14
15/// Trait for negative sample mining
16pub trait NegativeMiner: Send + Sync {
17    /// Mine negatives for an anchor from a candidate pool
18    fn mine(
19        &self,
20        anchor: &[f32],
21        positive: &[f32],
22        candidates: &[&[f32]],
23        num_negatives: usize,
24    ) -> Vec<usize>;
25
26    /// Get mining strategy
27    fn strategy(&self) -> MiningStrategy;
28}
29
30/// Hard negative miner that selects closest negatives
31pub struct HardNegativeMiner {
32    strategy: MiningStrategy,
33    margin: f32,
34    temperature: f32,
35}
36
37impl HardNegativeMiner {
38    pub fn new(strategy: MiningStrategy) -> Self {
39        Self {
40            strategy,
41            margin: 0.1,
42            temperature: 1.0,
43        }
44    }
45
46    pub fn with_margin(mut self, margin: f32) -> Self {
47        self.margin = margin;
48        self
49    }
50
51    pub fn with_temperature(mut self, temp: f32) -> Self {
52        self.temperature = temp;
53        self
54    }
55
56    fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
57        a.iter()
58            .zip(b.iter())
59            .map(|(x, y)| (x - y).powi(2))
60            .sum::<f32>()
61            .sqrt()
62    }
63
64    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
65        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
66        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
67        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
68        dot / (norm_a * norm_b)
69    }
70
71    /// Select random indices
72    fn random_selection(num_candidates: usize, num_select: usize, seed: u64) -> Vec<usize> {
73        let mut indices: Vec<usize> = (0..num_candidates).collect();
74        let mut current_seed = seed;
75
76        // Fisher-Yates shuffle
77        for i in (1..indices.len()).rev() {
78            current_seed = current_seed.wrapping_mul(6364136223846793005).wrapping_add(1);
79            let j = (current_seed as usize) % (i + 1);
80            indices.swap(i, j);
81        }
82
83        indices.truncate(num_select.min(num_candidates));
84        indices
85    }
86
87    /// Select hardest negatives (closest to anchor)
88    fn hard_negative_selection(
89        &self,
90        anchor: &[f32],
91        candidates: &[&[f32]],
92        num_select: usize,
93    ) -> Vec<usize> {
94        let mut indexed_sims: Vec<(usize, f32)> = candidates
95            .iter()
96            .enumerate()
97            .map(|(i, c)| (i, Self::cosine_similarity(anchor, c)))
98            .collect();
99
100        // Sort by similarity descending (higher sim = harder negative)
101        indexed_sims.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
102
103        indexed_sims
104            .into_iter()
105            .take(num_select.min(candidates.len()))
106            .map(|(i, _)| i)
107            .collect()
108    }
109
110    /// Select semi-hard negatives (within margin of positive)
111    fn semi_hard_selection(
112        &self,
113        anchor: &[f32],
114        positive: &[f32],
115        candidates: &[&[f32]],
116        num_select: usize,
117    ) -> Vec<usize> {
118        let d_pos = Self::euclidean_distance(anchor, positive);
119
120        let mut semi_hard: Vec<(usize, f32)> = candidates
121            .iter()
122            .enumerate()
123            .filter_map(|(i, c)| {
124                let d_neg = Self::euclidean_distance(anchor, c);
125                // Semi-hard: d_pos < d_neg < d_pos + margin
126                if d_neg > d_pos && d_neg < d_pos + self.margin {
127                    Some((i, d_neg))
128                } else {
129                    None
130                }
131            })
132            .collect();
133
134        // Sort by distance (prefer harder ones)
135        semi_hard.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
136
137        let mut result: Vec<usize> = semi_hard.into_iter().map(|(i, _)| i).collect();
138
139        // If not enough semi-hard, fill with hard negatives
140        if result.len() < num_select {
141            let hard = self.hard_negative_selection(anchor, candidates, num_select - result.len());
142            for idx in hard {
143                if !result.contains(&idx) {
144                    result.push(idx);
145                }
146            }
147        }
148
149        result.truncate(num_select);
150        result
151    }
152
153    /// Distance-weighted sampling
154    fn distance_weighted_selection(
155        &self,
156        anchor: &[f32],
157        candidates: &[&[f32]],
158        num_select: usize,
159    ) -> Vec<usize> {
160        if candidates.is_empty() {
161            return vec![];
162        }
163
164        // Compute weights based on similarity (closer = higher weight)
165        let sims: Vec<f32> = candidates
166            .iter()
167            .map(|c| Self::cosine_similarity(anchor, c) / self.temperature)
168            .collect();
169
170        // Softmax weights
171        let max_sim = sims.iter().copied().fold(f32::NEG_INFINITY, f32::max);
172        let exp_sims: Vec<f32> = sims.iter().map(|s| (s - max_sim).exp()).collect();
173        let sum_exp: f32 = exp_sims.iter().sum();
174        let probs: Vec<f32> = exp_sims.iter().map(|e| e / sum_exp).collect();
175
176        // Sample without replacement using the probabilities
177        let mut remaining: Vec<(usize, f32)> = probs.into_iter().enumerate().collect();
178        let mut selected = Vec::with_capacity(num_select);
179        let mut seed = 42u64;
180
181        while selected.len() < num_select && !remaining.is_empty() {
182            // Random value
183            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
184            let r = (seed as f32) / (u64::MAX as f32);
185
186            // Select based on cumulative probability
187            let total: f32 = remaining.iter().map(|(_, p)| p).sum();
188            let mut cumsum = 0.0;
189            let mut select_idx = 0;
190
191            for (i, (_, p)) in remaining.iter().enumerate() {
192                cumsum += p / total;
193                if r < cumsum {
194                    select_idx = i;
195                    break;
196                }
197            }
198
199            let (orig_idx, _) = remaining.remove(select_idx);
200            selected.push(orig_idx);
201        }
202
203        selected
204    }
205}
206
207impl NegativeMiner for HardNegativeMiner {
208    fn mine(
209        &self,
210        anchor: &[f32],
211        positive: &[f32],
212        candidates: &[&[f32]],
213        num_negatives: usize,
214    ) -> Vec<usize> {
215        match self.strategy {
216            MiningStrategy::Random => {
217                Self::random_selection(candidates.len(), num_negatives, 42)
218            }
219            MiningStrategy::HardNegative => {
220                self.hard_negative_selection(anchor, candidates, num_negatives)
221            }
222            MiningStrategy::SemiHard => {
223                self.semi_hard_selection(anchor, positive, candidates, num_negatives)
224            }
225            MiningStrategy::DistanceWeighted => {
226                self.distance_weighted_selection(anchor, candidates, num_negatives)
227            }
228        }
229    }
230
231    fn strategy(&self) -> MiningStrategy {
232        self.strategy
233    }
234}
235
/// In-batch negative mining: the other items of the same batch serve as negatives.
///
/// By default both the anchor and its positive are excluded. Call
/// [`InBatchMiner::include_positive`] to keep the positive in the negative set.
#[derive(Debug, Clone, Copy)]
pub struct InBatchMiner {
    /// When `true`, `positive_idx` is left out of the indices returned by
    /// [`InBatchMiner::get_negatives`].
    exclude_positive: bool,
}

impl InBatchMiner {
    /// Creates a miner that excludes both the anchor and the positive.
    pub fn new() -> Self {
        Self {
            exclude_positive: true,
        }
    }

    /// Keeps the positive's index among the returned negatives.
    #[must_use]
    pub fn include_positive(mut self) -> Self {
        self.exclude_positive = false;
        self
    }

    /// Returns, in ascending order, the indices in `0..batch_size` usable as negatives
    /// for the item at `anchor_idx`.
    ///
    /// `anchor_idx` is always excluded. `positive_idx` is excluded unless
    /// [`include_positive`](Self::include_positive) was called. Arguments outside
    /// `0..batch_size` simply match nothing, so they do not panic.
    pub fn get_negatives(&self, anchor_idx: usize, positive_idx: usize, batch_size: usize) -> Vec<usize> {
        (0..batch_size)
            .filter(|&i| i != anchor_idx && !(self.exclude_positive && i == positive_idx))
            .collect()
    }
}

impl Default for InBatchMiner {
    /// Equivalent to [`InBatchMiner::new`]: the anchor and the positive are both excluded.
    fn default() -> Self {
        Self::new()
    }
}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `indices` has no duplicates and every entry is `< bound`.
    fn distinct_and_bounded(indices: &[usize], bound: usize) -> bool {
        let mut sorted = indices.to_vec();
        sorted.sort_unstable();
        sorted.dedup();
        sorted.len() == indices.len() && indices.iter().all(|&i| i < bound)
    }

    /// Borrows each owned vector as a slice, matching the `candidates` parameter type.
    fn as_refs(vectors: &[Vec<f32>]) -> Vec<&[f32]> {
        vectors.iter().map(|v| v.as_slice()).collect()
    }

    #[test]
    fn test_random_mining() {
        let miner = HardNegativeMiner::new(MiningStrategy::Random);

        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0];
        let candidates: Vec<Vec<f32>> = (0..20).map(|i| vec![i as f32 * 0.05; 3]).collect();
        let cand_refs = as_refs(&candidates);

        let selected = miner.mine(&anchor, &positive, &cand_refs, 5);
        assert_eq!(selected.len(), 5);
        assert!(distinct_and_bounded(&selected, 20));
        assert_eq!(miner.strategy(), MiningStrategy::Random);
    }

    #[test]
    fn test_hard_negative_mining() {
        let miner = HardNegativeMiner::new(MiningStrategy::HardNegative);

        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0];
        // Candidates ordered from most to least cosine-similar to the anchor.
        let candidates: Vec<Vec<f32>> = vec![
            vec![0.9, 0.1, 0.0], // cos ~0.994
            vec![0.5, 0.5, 0.0], // cos ~0.707
            vec![0.0, 1.0, 0.0], // orthogonal
            vec![0.0, 0.0, 1.0], // orthogonal
        ];
        let cand_refs = as_refs(&candidates);

        let selected = miner.mine(&anchor, &positive, &cand_refs, 2);

        // The two most similar candidates, hardest first.
        assert_eq!(selected, vec![0, 1]);
    }

    #[test]
    fn test_semi_hard_mining() {
        let miner = HardNegativeMiner::new(MiningStrategy::SemiHard).with_margin(1.0);

        let anchor = vec![0.0, 0.0];
        let positive = vec![0.5, 0.0]; // d_pos = 0.5, so the semi-hard band is (0.5, 1.5)
        let candidates: Vec<Vec<f32>> = vec![
            vec![0.3, 0.0], // Hard (d = 0.3 < d_pos): closer than the positive
            vec![0.7, 0.0], // Semi-hard (0.5 < 0.7 < 1.5)
            vec![1.0, 0.0], // Semi-hard (0.5 < 1.0 < 1.5)
            vec![3.0, 0.0], // Easy (d = 3.0 >= d_pos + margin = 1.5)
        ];
        let cand_refs = as_refs(&candidates);

        let selected = miner.mine(&anchor, &positive, &cand_refs, 2);

        // Exactly the semi-hard candidates, closest first.
        assert_eq!(selected, vec![1, 2]);
    }

    #[test]
    fn test_distance_weighted() {
        let miner = HardNegativeMiner::new(MiningStrategy::DistanceWeighted).with_temperature(0.5);

        let anchor = vec![1.0, 0.0];
        let positive = vec![0.9, 0.1];
        let candidates: Vec<Vec<f32>> = (0..10).map(|i| vec![0.1 * i as f32; 2]).collect();
        let cand_refs = as_refs(&candidates);

        let selected = miner.mine(&anchor, &positive, &cand_refs, 3);
        assert_eq!(selected.len(), 3);
        assert!(distinct_and_bounded(&selected, 10));
    }

    #[test]
    fn test_empty_candidates() {
        let anchor = vec![1.0, 0.0];
        let positive = vec![0.9, 0.1];
        let none: Vec<&[f32]> = Vec::new();

        let strategies = [
            MiningStrategy::Random,
            MiningStrategy::HardNegative,
            MiningStrategy::SemiHard,
            MiningStrategy::DistanceWeighted,
        ];
        for strategy in strategies.iter() {
            let miner = HardNegativeMiner::new(*strategy);
            assert!(miner.mine(&anchor, &positive, &none, 3).is_empty());
        }
    }

    #[test]
    fn test_in_batch_miner() {
        let miner = InBatchMiner::new();

        let negatives = miner.get_negatives(2, 5, 10);

        assert!(!negatives.contains(&2)); // Exclude anchor
        assert!(!negatives.contains(&5)); // Exclude positive
        assert_eq!(negatives.len(), 8);

        let with_positive = InBatchMiner::new().include_positive().get_negatives(2, 5, 10);
        assert!(!with_positive.contains(&2)); // Anchor is always excluded
        assert!(with_positive.contains(&5)); // Positive is kept
        assert_eq!(with_positive.len(), 9);
    }
}