// Source path: ruvector_attention/training/mining.rs

/// Selects how negatives are picked from a candidate pool.
///
/// The default is [`MiningStrategy::Random`].
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum MiningStrategy {
    /// Shuffle the candidate indices and keep a prefix.
    #[default]
    Random,
    /// Prefer the candidates with the highest cosine similarity to the anchor.
    HardNegative,
    /// Prefer candidates farther from the anchor than the positive, but
    /// still inside the configured margin.
    SemiHard,
    /// Sample candidates with probability driven by their (temperature
    /// scaled) cosine similarity to the anchor.
    DistanceWeighted,
}
14
15pub trait NegativeMiner: Send + Sync {
17 fn mine(
19 &self,
20 anchor: &[f32],
21 positive: &[f32],
22 candidates: &[&[f32]],
23 num_negatives: usize,
24 ) -> Vec<usize>;
25
26 fn strategy(&self) -> MiningStrategy;
28}
29
30pub struct HardNegativeMiner {
32 strategy: MiningStrategy,
33 margin: f32,
34 temperature: f32,
35}
36
37impl HardNegativeMiner {
38 pub fn new(strategy: MiningStrategy) -> Self {
39 Self {
40 strategy,
41 margin: 0.1,
42 temperature: 1.0,
43 }
44 }
45
46 pub fn with_margin(mut self, margin: f32) -> Self {
47 self.margin = margin;
48 self
49 }
50
51 pub fn with_temperature(mut self, temp: f32) -> Self {
52 self.temperature = temp;
53 self
54 }
55
56 fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
57 a.iter()
58 .zip(b.iter())
59 .map(|(x, y)| (x - y).powi(2))
60 .sum::<f32>()
61 .sqrt()
62 }
63
64 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
65 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
66 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
67 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
68 dot / (norm_a * norm_b)
69 }
70
71 fn random_selection(num_candidates: usize, num_select: usize, seed: u64) -> Vec<usize> {
73 let mut indices: Vec<usize> = (0..num_candidates).collect();
74 let mut current_seed = seed;
75
76 for i in (1..indices.len()).rev() {
78 current_seed = current_seed.wrapping_mul(6364136223846793005).wrapping_add(1);
79 let j = (current_seed as usize) % (i + 1);
80 indices.swap(i, j);
81 }
82
83 indices.truncate(num_select.min(num_candidates));
84 indices
85 }
86
87 fn hard_negative_selection(
89 &self,
90 anchor: &[f32],
91 candidates: &[&[f32]],
92 num_select: usize,
93 ) -> Vec<usize> {
94 let mut indexed_sims: Vec<(usize, f32)> = candidates
95 .iter()
96 .enumerate()
97 .map(|(i, c)| (i, Self::cosine_similarity(anchor, c)))
98 .collect();
99
100 indexed_sims.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
102
103 indexed_sims
104 .into_iter()
105 .take(num_select.min(candidates.len()))
106 .map(|(i, _)| i)
107 .collect()
108 }
109
110 fn semi_hard_selection(
112 &self,
113 anchor: &[f32],
114 positive: &[f32],
115 candidates: &[&[f32]],
116 num_select: usize,
117 ) -> Vec<usize> {
118 let d_pos = Self::euclidean_distance(anchor, positive);
119
120 let mut semi_hard: Vec<(usize, f32)> = candidates
121 .iter()
122 .enumerate()
123 .filter_map(|(i, c)| {
124 let d_neg = Self::euclidean_distance(anchor, c);
125 if d_neg > d_pos && d_neg < d_pos + self.margin {
127 Some((i, d_neg))
128 } else {
129 None
130 }
131 })
132 .collect();
133
134 semi_hard.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
136
137 let mut result: Vec<usize> = semi_hard.into_iter().map(|(i, _)| i).collect();
138
139 if result.len() < num_select {
141 let hard = self.hard_negative_selection(anchor, candidates, num_select - result.len());
142 for idx in hard {
143 if !result.contains(&idx) {
144 result.push(idx);
145 }
146 }
147 }
148
149 result.truncate(num_select);
150 result
151 }
152
153 fn distance_weighted_selection(
155 &self,
156 anchor: &[f32],
157 candidates: &[&[f32]],
158 num_select: usize,
159 ) -> Vec<usize> {
160 if candidates.is_empty() {
161 return vec![];
162 }
163
164 let sims: Vec<f32> = candidates
166 .iter()
167 .map(|c| Self::cosine_similarity(anchor, c) / self.temperature)
168 .collect();
169
170 let max_sim = sims.iter().copied().fold(f32::NEG_INFINITY, f32::max);
172 let exp_sims: Vec<f32> = sims.iter().map(|s| (s - max_sim).exp()).collect();
173 let sum_exp: f32 = exp_sims.iter().sum();
174 let probs: Vec<f32> = exp_sims.iter().map(|e| e / sum_exp).collect();
175
176 let mut remaining: Vec<(usize, f32)> = probs.into_iter().enumerate().collect();
178 let mut selected = Vec::with_capacity(num_select);
179 let mut seed = 42u64;
180
181 while selected.len() < num_select && !remaining.is_empty() {
182 seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
184 let r = (seed as f32) / (u64::MAX as f32);
185
186 let total: f32 = remaining.iter().map(|(_, p)| p).sum();
188 let mut cumsum = 0.0;
189 let mut select_idx = 0;
190
191 for (i, (_, p)) in remaining.iter().enumerate() {
192 cumsum += p / total;
193 if r < cumsum {
194 select_idx = i;
195 break;
196 }
197 }
198
199 let (orig_idx, _) = remaining.remove(select_idx);
200 selected.push(orig_idx);
201 }
202
203 selected
204 }
205}
206
207impl NegativeMiner for HardNegativeMiner {
208 fn mine(
209 &self,
210 anchor: &[f32],
211 positive: &[f32],
212 candidates: &[&[f32]],
213 num_negatives: usize,
214 ) -> Vec<usize> {
215 match self.strategy {
216 MiningStrategy::Random => {
217 Self::random_selection(candidates.len(), num_negatives, 42)
218 }
219 MiningStrategy::HardNegative => {
220 self.hard_negative_selection(anchor, candidates, num_negatives)
221 }
222 MiningStrategy::SemiHard => {
223 self.semi_hard_selection(anchor, positive, candidates, num_negatives)
224 }
225 MiningStrategy::DistanceWeighted => {
226 self.distance_weighted_selection(anchor, candidates, num_negatives)
227 }
228 }
229 }
230
231 fn strategy(&self) -> MiningStrategy {
232 self.strategy
233 }
234}
235
/// Picks negatives from the other samples of the same batch.
pub struct InBatchMiner {
    /// When `true`, the positive's index is left out of the negatives.
    exclude_positive: bool,
}
240
241impl InBatchMiner {
242 pub fn new() -> Self {
243 Self {
244 exclude_positive: true,
245 }
246 }
247
248 pub fn include_positive(mut self) -> Self {
249 self.exclude_positive = false;
250 self
251 }
252
253 pub fn get_negatives(&self, anchor_idx: usize, positive_idx: usize, batch_size: usize) -> Vec<usize> {
255 (0..batch_size)
256 .filter(|&i| {
257 i != anchor_idx && (!self.exclude_positive || i != positive_idx)
258 })
259 .collect()
260 }
261}
262
263impl Default for InBatchMiner {
264 fn default() -> Self {
265 Self::new()
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrows each owned candidate as a slice, the shape `mine` expects.
    fn as_slices(owned: &[Vec<f32>]) -> Vec<&[f32]> {
        owned.iter().map(Vec::as_slice).collect()
    }

    /// Random mining returns exactly the requested number of indices when
    /// the pool is large enough.
    #[test]
    fn test_random_mining() {
        let miner = HardNegativeMiner::new(MiningStrategy::Random);

        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0];
        let candidates: Vec<Vec<f32>> = (0..20).map(|i| vec![i as f32 * 0.05; 3]).collect();
        let cand_refs = as_slices(&candidates);

        let selected = miner.mine(&anchor, &positive, &cand_refs, 5);
        assert_eq!(selected.len(), 5);
    }

    /// Candidate 0 has the highest cosine similarity to the anchor, so it
    /// must be among the two hardest negatives.
    #[test]
    fn test_hard_negative_mining() {
        let miner = HardNegativeMiner::new(MiningStrategy::HardNegative);

        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0];
        let candidates: Vec<Vec<f32>> = vec![
            vec![0.9, 0.1, 0.0],
            vec![0.5, 0.5, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        let cand_refs = as_slices(&candidates);

        let selected = miner.mine(&anchor, &positive, &cand_refs, 2);

        assert!(selected.contains(&0));
    }

    /// With a positive at distance 0.5 and a margin of 1.0, the candidates
    /// at distance 0.7 and 1.0 fall inside the semi-hard window.
    #[test]
    fn test_semi_hard_mining() {
        let miner = HardNegativeMiner::new(MiningStrategy::SemiHard).with_margin(1.0);

        let anchor = vec![0.0, 0.0];
        let positive = vec![0.5, 0.0];
        let candidates: Vec<Vec<f32>> = vec![
            vec![0.3, 0.0],
            vec![0.7, 0.0],
            vec![1.0, 0.0],
            vec![3.0, 0.0],
        ];
        let cand_refs = as_slices(&candidates);

        let selected = miner.mine(&anchor, &positive, &cand_refs, 2);
        assert!(!selected.is_empty());
    }

    /// Distance-weighted sampling returns the requested number of indices.
    #[test]
    fn test_distance_weighted() {
        let miner =
            HardNegativeMiner::new(MiningStrategy::DistanceWeighted).with_temperature(0.5);

        let anchor = vec![1.0, 0.0];
        let positive = vec![0.9, 0.1];
        let candidates: Vec<Vec<f32>> = (0..10).map(|i| vec![0.1 * i as f32; 2]).collect();
        let cand_refs = as_slices(&candidates);

        let selected = miner.mine(&anchor, &positive, &cand_refs, 3);
        assert_eq!(selected.len(), 3);
    }

    /// In-batch mining drops the anchor and the positive from a batch of 10.
    #[test]
    fn test_in_batch_miner() {
        let miner = InBatchMiner::new();

        let negatives = miner.get_negatives(2, 5, 10);

        assert!(!negatives.contains(&2));
        assert!(!negatives.contains(&5));
        assert_eq!(negatives.len(), 8);
    }
}