// ruvector_attention/training/mining.rs

//! Hard negative mining strategies.
//!
//! Provides various methods for selecting informative negative samples.
4
/// Strategy used to choose negative samples from a candidate pool.
///
/// The enum is a plain tag: it is `Copy`, comparable with `==`, and hashable,
/// so it can be used as a key in maps and sets. [`MiningStrategy::Random`] is
/// the default.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum MiningStrategy {
    /// Pick candidates without looking at their embeddings.
    #[default]
    Random,
    /// Prefer the candidates that are most similar to the anchor.
    HardNegative,
    /// Prefer candidates that are farther from the anchor than the positive,
    /// but by less than a margin.
    SemiHard,
    /// Sample candidates with probability weighted by their similarity to the
    /// anchor.
    DistanceWeighted,
}
14
15/// Trait for negative sample mining
16pub trait NegativeMiner: Send + Sync {
17    /// Mine negatives for an anchor from a candidate pool
18    fn mine(
19        &self,
20        anchor: &[f32],
21        positive: &[f32],
22        candidates: &[&[f32]],
23        num_negatives: usize,
24    ) -> Vec<usize>;
25
26    /// Get mining strategy
27    fn strategy(&self) -> MiningStrategy;
28}
29
30/// Hard negative miner that selects closest negatives
31pub struct HardNegativeMiner {
32    strategy: MiningStrategy,
33    margin: f32,
34    temperature: f32,
35}
36
37impl HardNegativeMiner {
38    pub fn new(strategy: MiningStrategy) -> Self {
39        Self {
40            strategy,
41            margin: 0.1,
42            temperature: 1.0,
43        }
44    }
45
46    pub fn with_margin(mut self, margin: f32) -> Self {
47        self.margin = margin;
48        self
49    }
50
51    pub fn with_temperature(mut self, temp: f32) -> Self {
52        self.temperature = temp;
53        self
54    }
55
56    fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
57        a.iter()
58            .zip(b.iter())
59            .map(|(x, y)| (x - y).powi(2))
60            .sum::<f32>()
61            .sqrt()
62    }
63
64    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
65        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
66        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
67        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
68        dot / (norm_a * norm_b)
69    }
70
71    /// Select random indices
72    fn random_selection(num_candidates: usize, num_select: usize, seed: u64) -> Vec<usize> {
73        let mut indices: Vec<usize> = (0..num_candidates).collect();
74        let mut current_seed = seed;
75
76        // Fisher-Yates shuffle
77        for i in (1..indices.len()).rev() {
78            current_seed = current_seed
79                .wrapping_mul(6364136223846793005)
80                .wrapping_add(1);
81            let j = (current_seed as usize) % (i + 1);
82            indices.swap(i, j);
83        }
84
85        indices.truncate(num_select.min(num_candidates));
86        indices
87    }
88
89    /// Select hardest negatives (closest to anchor)
90    fn hard_negative_selection(
91        &self,
92        anchor: &[f32],
93        candidates: &[&[f32]],
94        num_select: usize,
95    ) -> Vec<usize> {
96        let mut indexed_sims: Vec<(usize, f32)> = candidates
97            .iter()
98            .enumerate()
99            .map(|(i, c)| (i, Self::cosine_similarity(anchor, c)))
100            .collect();
101
102        // Sort by similarity descending (higher sim = harder negative)
103        indexed_sims.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
104
105        indexed_sims
106            .into_iter()
107            .take(num_select.min(candidates.len()))
108            .map(|(i, _)| i)
109            .collect()
110    }
111
112    /// Select semi-hard negatives (within margin of positive)
113    fn semi_hard_selection(
114        &self,
115        anchor: &[f32],
116        positive: &[f32],
117        candidates: &[&[f32]],
118        num_select: usize,
119    ) -> Vec<usize> {
120        let d_pos = Self::euclidean_distance(anchor, positive);
121
122        let mut semi_hard: Vec<(usize, f32)> = candidates
123            .iter()
124            .enumerate()
125            .filter_map(|(i, c)| {
126                let d_neg = Self::euclidean_distance(anchor, c);
127                // Semi-hard: d_pos < d_neg < d_pos + margin
128                if d_neg > d_pos && d_neg < d_pos + self.margin {
129                    Some((i, d_neg))
130                } else {
131                    None
132                }
133            })
134            .collect();
135
136        // Sort by distance (prefer harder ones)
137        semi_hard.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
138
139        let mut result: Vec<usize> = semi_hard.into_iter().map(|(i, _)| i).collect();
140
141        // If not enough semi-hard, fill with hard negatives
142        if result.len() < num_select {
143            let hard = self.hard_negative_selection(anchor, candidates, num_select - result.len());
144            for idx in hard {
145                if !result.contains(&idx) {
146                    result.push(idx);
147                }
148            }
149        }
150
151        result.truncate(num_select);
152        result
153    }
154
155    /// Distance-weighted sampling
156    fn distance_weighted_selection(
157        &self,
158        anchor: &[f32],
159        candidates: &[&[f32]],
160        num_select: usize,
161    ) -> Vec<usize> {
162        if candidates.is_empty() {
163            return vec![];
164        }
165
166        // Compute weights based on similarity (closer = higher weight)
167        let sims: Vec<f32> = candidates
168            .iter()
169            .map(|c| Self::cosine_similarity(anchor, c) / self.temperature)
170            .collect();
171
172        // Softmax weights
173        let max_sim = sims.iter().copied().fold(f32::NEG_INFINITY, f32::max);
174        let exp_sims: Vec<f32> = sims.iter().map(|s| (s - max_sim).exp()).collect();
175        let sum_exp: f32 = exp_sims.iter().sum();
176        let probs: Vec<f32> = exp_sims.iter().map(|e| e / sum_exp).collect();
177
178        // Sample without replacement using the probabilities
179        let mut remaining: Vec<(usize, f32)> = probs.into_iter().enumerate().collect();
180        let mut selected = Vec::with_capacity(num_select);
181        let mut seed = 42u64;
182
183        while selected.len() < num_select && !remaining.is_empty() {
184            // Random value
185            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
186            let r = (seed as f32) / (u64::MAX as f32);
187
188            // Select based on cumulative probability
189            let total: f32 = remaining.iter().map(|(_, p)| p).sum();
190            let mut cumsum = 0.0;
191            let mut select_idx = 0;
192
193            for (i, (_, p)) in remaining.iter().enumerate() {
194                cumsum += p / total;
195                if r < cumsum {
196                    select_idx = i;
197                    break;
198                }
199            }
200
201            let (orig_idx, _) = remaining.remove(select_idx);
202            selected.push(orig_idx);
203        }
204
205        selected
206    }
207}
208
209impl NegativeMiner for HardNegativeMiner {
210    fn mine(
211        &self,
212        anchor: &[f32],
213        positive: &[f32],
214        candidates: &[&[f32]],
215        num_negatives: usize,
216    ) -> Vec<usize> {
217        match self.strategy {
218            MiningStrategy::Random => Self::random_selection(candidates.len(), num_negatives, 42),
219            MiningStrategy::HardNegative => {
220                self.hard_negative_selection(anchor, candidates, num_negatives)
221            }
222            MiningStrategy::SemiHard => {
223                self.semi_hard_selection(anchor, positive, candidates, num_negatives)
224            }
225            MiningStrategy::DistanceWeighted => {
226                self.distance_weighted_selection(anchor, candidates, num_negatives)
227            }
228        }
229    }
230
231    fn strategy(&self) -> MiningStrategy {
232        self.strategy
233    }
234}
235
/// In-batch negative mining: the other items of a batch serve as negatives.
#[derive(Debug, Clone)]
pub struct InBatchMiner {
    /// When `true`, the positive's index is left out of the negatives.
    exclude_positive: bool,
}
240
241impl InBatchMiner {
242    pub fn new() -> Self {
243        Self {
244            exclude_positive: true,
245        }
246    }
247
248    pub fn include_positive(mut self) -> Self {
249        self.exclude_positive = false;
250        self
251    }
252
253    /// Get negative indices from a batch for a given anchor index
254    pub fn get_negatives(
255        &self,
256        anchor_idx: usize,
257        positive_idx: usize,
258        batch_size: usize,
259    ) -> Vec<usize> {
260        (0..batch_size)
261            .filter(|&i| i != anchor_idx && (!self.exclude_positive || i != positive_idx))
262            .collect()
263    }
264}
265
266impl Default for InBatchMiner {
267    fn default() -> Self {
268        Self::new()
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `indices` holds exactly `expected_len` distinct values,
    /// all of them valid positions in a pool of `pool_size` candidates.
    fn assert_distinct_in_range(indices: &[usize], expected_len: usize, pool_size: usize) {
        assert_eq!(indices.len(), expected_len);
        assert!(indices.iter().all(|&i| i < pool_size));

        let mut unique = indices.to_vec();
        unique.sort_unstable();
        unique.dedup();
        assert_eq!(unique.len(), expected_len);
    }

    #[test]
    fn test_random_mining() {
        let miner = HardNegativeMiner::new(MiningStrategy::Random);

        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0];
        let candidates: Vec<Vec<f32>> = (0..20).map(|i| vec![i as f32 * 0.05; 3]).collect();
        let cand_refs: Vec<&[f32]> = candidates.iter().map(|c| c.as_slice()).collect();

        let selected = miner.mine(&anchor, &positive, &cand_refs, 5);
        assert_distinct_in_range(&selected, 5, cand_refs.len());
        assert_eq!(miner.strategy(), MiningStrategy::Random);
    }

    #[test]
    fn test_hard_negative_mining() {
        let miner = HardNegativeMiner::new(MiningStrategy::HardNegative);

        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0];
        // Candidates ordered from most to least similar to the anchor.
        let candidates: Vec<Vec<f32>> = vec![
            vec![0.9, 0.1, 0.0], // Very similar to anchor
            vec![0.5, 0.5, 0.0], // Medium
            vec![0.0, 1.0, 0.0], // Orthogonal
            vec![0.0, 0.0, 1.0], // Orthogonal
        ];
        let cand_refs: Vec<&[f32]> = candidates.iter().map(|c| c.as_slice()).collect();

        let selected = miner.mine(&anchor, &positive, &cand_refs, 2);

        // Hardest first: the two candidates with the highest similarity.
        assert_eq!(selected, vec![0, 1]);
    }

    #[test]
    fn test_semi_hard_mining() {
        let miner = HardNegativeMiner::new(MiningStrategy::SemiHard).with_margin(1.0);

        let anchor = vec![0.0, 0.0];
        let positive = vec![0.5, 0.0]; // Distance 0.5
        let candidates: Vec<Vec<f32>> = vec![
            vec![0.3, 0.0], // Hard: closer than the positive (d = 0.3 < 0.5)
            vec![0.7, 0.0], // Semi-hard (0.5 < 0.7 < 1.5)
            vec![1.0, 0.0], // Semi-hard (0.5 < 1.0 < 1.5)
            vec![3.0, 0.0], // Easy: beyond the margin (d = 3.0 > 1.5)
        ];
        let cand_refs: Vec<&[f32]> = candidates.iter().map(|c| c.as_slice()).collect();

        let selected = miner.mine(&anchor, &positive, &cand_refs, 2);

        // Exactly the two semi-hard candidates, closest first.
        assert_eq!(selected, vec![1, 2]);
    }

    #[test]
    fn test_distance_weighted() {
        let miner = HardNegativeMiner::new(MiningStrategy::DistanceWeighted).with_temperature(0.5);

        let anchor = vec![1.0, 0.0];
        let positive = vec![0.9, 0.1];
        let candidates: Vec<Vec<f32>> = (0..10).map(|i| vec![0.1 * i as f32; 2]).collect();
        let cand_refs: Vec<&[f32]> = candidates.iter().map(|c| c.as_slice()).collect();

        let selected = miner.mine(&anchor, &positive, &cand_refs, 3);
        assert_distinct_in_range(&selected, 3, cand_refs.len());

        // An empty pool yields no negatives.
        let empty: Vec<&[f32]> = Vec::new();
        assert!(miner.mine(&anchor, &positive, &empty, 3).is_empty());
    }

    #[test]
    fn test_in_batch_miner() {
        let miner = InBatchMiner::new();

        let negatives = miner.get_negatives(2, 5, 10);

        // Anchor (2) and positive (5) are both excluded.
        assert_eq!(negatives, vec![0, 1, 3, 4, 6, 7, 8, 9]);
    }

    #[test]
    fn test_in_batch_miner_include_positive() {
        let miner = InBatchMiner::default().include_positive();

        let negatives = miner.get_negatives(2, 5, 10);

        // Only the anchor (2) is excluded; the positive (5) stays in.
        assert_eq!(negatives, vec![0, 1, 3, 4, 5, 6, 7, 8, 9]);
    }
}