// ruvector_attention/training/mining.rs

/// How a [`HardNegativeMiner`] picks negatives out of a candidate pool.
///
/// Derives `Eq` and `Hash` (fieldless enum) so it can be used as a map/set key.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum MiningStrategy {
    /// Seeded shuffle of the candidate indices (deterministic for a fixed seed).
    #[default]
    Random,
    /// Candidates with the highest cosine similarity to the anchor.
    HardNegative,
    /// Candidates farther (Euclidean) from the anchor than the positive, but
    /// within `margin` of it; back-filled with hard negatives if too few qualify.
    SemiHard,
    /// Sampling without replacement, weighted by a softmax over
    /// `cosine_similarity / temperature`.
    DistanceWeighted,
}
14
15pub trait NegativeMiner: Send + Sync {
17 fn mine(
19 &self,
20 anchor: &[f32],
21 positive: &[f32],
22 candidates: &[&[f32]],
23 num_negatives: usize,
24 ) -> Vec<usize>;
25
26 fn strategy(&self) -> MiningStrategy;
28}
29
30pub struct HardNegativeMiner {
32 strategy: MiningStrategy,
33 margin: f32,
34 temperature: f32,
35}
36
37impl HardNegativeMiner {
38 pub fn new(strategy: MiningStrategy) -> Self {
39 Self {
40 strategy,
41 margin: 0.1,
42 temperature: 1.0,
43 }
44 }
45
46 pub fn with_margin(mut self, margin: f32) -> Self {
47 self.margin = margin;
48 self
49 }
50
51 pub fn with_temperature(mut self, temp: f32) -> Self {
52 self.temperature = temp;
53 self
54 }
55
56 fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
57 a.iter()
58 .zip(b.iter())
59 .map(|(x, y)| (x - y).powi(2))
60 .sum::<f32>()
61 .sqrt()
62 }
63
64 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
65 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
66 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
67 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
68 dot / (norm_a * norm_b)
69 }
70
71 fn random_selection(num_candidates: usize, num_select: usize, seed: u64) -> Vec<usize> {
73 let mut indices: Vec<usize> = (0..num_candidates).collect();
74 let mut current_seed = seed;
75
76 for i in (1..indices.len()).rev() {
78 current_seed = current_seed
79 .wrapping_mul(6364136223846793005)
80 .wrapping_add(1);
81 let j = (current_seed as usize) % (i + 1);
82 indices.swap(i, j);
83 }
84
85 indices.truncate(num_select.min(num_candidates));
86 indices
87 }
88
89 fn hard_negative_selection(
91 &self,
92 anchor: &[f32],
93 candidates: &[&[f32]],
94 num_select: usize,
95 ) -> Vec<usize> {
96 let mut indexed_sims: Vec<(usize, f32)> = candidates
97 .iter()
98 .enumerate()
99 .map(|(i, c)| (i, Self::cosine_similarity(anchor, c)))
100 .collect();
101
102 indexed_sims.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
104
105 indexed_sims
106 .into_iter()
107 .take(num_select.min(candidates.len()))
108 .map(|(i, _)| i)
109 .collect()
110 }
111
112 fn semi_hard_selection(
114 &self,
115 anchor: &[f32],
116 positive: &[f32],
117 candidates: &[&[f32]],
118 num_select: usize,
119 ) -> Vec<usize> {
120 let d_pos = Self::euclidean_distance(anchor, positive);
121
122 let mut semi_hard: Vec<(usize, f32)> = candidates
123 .iter()
124 .enumerate()
125 .filter_map(|(i, c)| {
126 let d_neg = Self::euclidean_distance(anchor, c);
127 if d_neg > d_pos && d_neg < d_pos + self.margin {
129 Some((i, d_neg))
130 } else {
131 None
132 }
133 })
134 .collect();
135
136 semi_hard.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
138
139 let mut result: Vec<usize> = semi_hard.into_iter().map(|(i, _)| i).collect();
140
141 if result.len() < num_select {
143 let hard = self.hard_negative_selection(anchor, candidates, num_select - result.len());
144 for idx in hard {
145 if !result.contains(&idx) {
146 result.push(idx);
147 }
148 }
149 }
150
151 result.truncate(num_select);
152 result
153 }
154
155 fn distance_weighted_selection(
157 &self,
158 anchor: &[f32],
159 candidates: &[&[f32]],
160 num_select: usize,
161 ) -> Vec<usize> {
162 if candidates.is_empty() {
163 return vec![];
164 }
165
166 let sims: Vec<f32> = candidates
168 .iter()
169 .map(|c| Self::cosine_similarity(anchor, c) / self.temperature)
170 .collect();
171
172 let max_sim = sims.iter().copied().fold(f32::NEG_INFINITY, f32::max);
174 let exp_sims: Vec<f32> = sims.iter().map(|s| (s - max_sim).exp()).collect();
175 let sum_exp: f32 = exp_sims.iter().sum();
176 let probs: Vec<f32> = exp_sims.iter().map(|e| e / sum_exp).collect();
177
178 let mut remaining: Vec<(usize, f32)> = probs.into_iter().enumerate().collect();
180 let mut selected = Vec::with_capacity(num_select);
181 let mut seed = 42u64;
182
183 while selected.len() < num_select && !remaining.is_empty() {
184 seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
186 let r = (seed as f32) / (u64::MAX as f32);
187
188 let total: f32 = remaining.iter().map(|(_, p)| p).sum();
190 let mut cumsum = 0.0;
191 let mut select_idx = 0;
192
193 for (i, (_, p)) in remaining.iter().enumerate() {
194 cumsum += p / total;
195 if r < cumsum {
196 select_idx = i;
197 break;
198 }
199 }
200
201 let (orig_idx, _) = remaining.remove(select_idx);
202 selected.push(orig_idx);
203 }
204
205 selected
206 }
207}
208
209impl NegativeMiner for HardNegativeMiner {
210 fn mine(
211 &self,
212 anchor: &[f32],
213 positive: &[f32],
214 candidates: &[&[f32]],
215 num_negatives: usize,
216 ) -> Vec<usize> {
217 match self.strategy {
218 MiningStrategy::Random => Self::random_selection(candidates.len(), num_negatives, 42),
219 MiningStrategy::HardNegative => {
220 self.hard_negative_selection(anchor, candidates, num_negatives)
221 }
222 MiningStrategy::SemiHard => {
223 self.semi_hard_selection(anchor, positive, candidates, num_negatives)
224 }
225 MiningStrategy::DistanceWeighted => {
226 self.distance_weighted_selection(anchor, candidates, num_negatives)
227 }
228 }
229 }
230
231 fn strategy(&self) -> MiningStrategy {
232 self.strategy
233 }
234}
235
/// Produces in-batch negatives: indices of other items in the same batch.
///
/// Derives the cheap standard traits so the miner can be copied, compared and
/// logged. `Default` is implemented by hand to mirror `new`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InBatchMiner {
    /// When `true`, the positive's index is left out of the negatives.
    exclude_positive: bool,
}
240
241impl InBatchMiner {
242 pub fn new() -> Self {
243 Self {
244 exclude_positive: true,
245 }
246 }
247
248 pub fn include_positive(mut self) -> Self {
249 self.exclude_positive = false;
250 self
251 }
252
253 pub fn get_negatives(
255 &self,
256 anchor_idx: usize,
257 positive_idx: usize,
258 batch_size: usize,
259 ) -> Vec<usize> {
260 (0..batch_size)
261 .filter(|&i| i != anchor_idx && (!self.exclude_positive || i != positive_idx))
262 .collect()
263 }
264}
265
266impl Default for InBatchMiner {
267 fn default() -> Self {
268 Self::new()
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when every index in `v` appears exactly once.
    fn all_distinct(v: &[usize]) -> bool {
        let mut sorted = v.to_vec();
        sorted.sort_unstable();
        sorted.dedup();
        sorted.len() == v.len()
    }

    #[test]
    fn test_random_mining() {
        let miner = HardNegativeMiner::new(MiningStrategy::Random);

        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0];
        let candidates: Vec<Vec<f32>> = (0..20).map(|i| vec![i as f32 * 0.05; 3]).collect();
        let cand_refs: Vec<&[f32]> = candidates.iter().map(|c| c.as_slice()).collect();

        let selected = miner.mine(&anchor, &positive, &cand_refs, 5);
        assert_eq!(selected.len(), 5);
        assert!(selected.iter().all(|&i| i < candidates.len()));
        assert!(all_distinct(&selected));
    }

    #[test]
    fn test_hard_negative_mining() {
        let miner = HardNegativeMiner::new(MiningStrategy::HardNegative);

        let anchor = vec![1.0, 0.0, 0.0];
        let positive = vec![0.9, 0.1, 0.0];
        let candidates: Vec<Vec<f32>> = vec![
            vec![0.9, 0.1, 0.0],
            vec![0.5, 0.5, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        let cand_refs: Vec<&[f32]> = candidates.iter().map(|c| c.as_slice()).collect();

        // Ranked by cosine similarity to the anchor. Index 0 duplicates the
        // positive and is not filtered out, so it ranks hardest.
        let selected = miner.mine(&anchor, &positive, &cand_refs, 2);
        assert_eq!(selected, vec![0, 1]);

        // Asking for more than the pool returns every candidate once.
        let all = miner.mine(&anchor, &positive, &cand_refs, 10);
        assert_eq!(all.len(), 4);
        assert!(all_distinct(&all));
    }

    #[test]
    fn test_semi_hard_mining() {
        let miner = HardNegativeMiner::new(MiningStrategy::SemiHard).with_margin(1.0);

        let anchor = vec![0.0, 0.0];
        // d_pos = 0.5, so the semi-hard window is (0.5, 1.5).
        let positive = vec![0.5, 0.0];
        let candidates: Vec<Vec<f32>> = vec![
            vec![0.3, 0.0],
            vec![0.7, 0.0],
            vec![1.0, 0.0],
            vec![3.0, 0.0],
        ];
        let cand_refs: Vec<&[f32]> = candidates.iter().map(|c| c.as_slice()).collect();

        // Candidates 1 (d = 0.7) and 2 (d = 1.0) qualify, nearest first.
        let selected = miner.mine(&anchor, &positive, &cand_refs, 2);
        assert_eq!(selected, vec![1, 2]);
    }

    #[test]
    fn test_distance_weighted() {
        let miner = HardNegativeMiner::new(MiningStrategy::DistanceWeighted).with_temperature(0.5);

        let anchor = vec![1.0, 0.0];
        let positive = vec![0.9, 0.1];
        let candidates: Vec<Vec<f32>> = (0..10).map(|i| vec![0.1 * i as f32; 2]).collect();
        let cand_refs: Vec<&[f32]> = candidates.iter().map(|c| c.as_slice()).collect();

        let selected = miner.mine(&anchor, &positive, &cand_refs, 3);
        assert_eq!(selected.len(), 3);
        assert!(selected.iter().all(|&i| i < candidates.len()));
        assert!(all_distinct(&selected));
    }

    #[test]
    fn test_empty_candidates_yield_no_negatives() {
        let anchor = [1.0f32, 0.0];
        let positive = [0.5f32, 0.5];
        for strategy in [
            MiningStrategy::Random,
            MiningStrategy::HardNegative,
            MiningStrategy::SemiHard,
            MiningStrategy::DistanceWeighted,
        ] {
            let miner = HardNegativeMiner::new(strategy);
            assert!(
                miner.mine(&anchor, &positive, &[], 3).is_empty(),
                "{:?}",
                strategy
            );
        }
    }

    #[test]
    fn test_strategy_accessor_and_default() {
        assert_eq!(MiningStrategy::default(), MiningStrategy::Random);
        let miner = HardNegativeMiner::new(MiningStrategy::SemiHard);
        assert_eq!(miner.strategy(), MiningStrategy::SemiHard);
    }

    #[test]
    fn test_in_batch_miner() {
        let miner = InBatchMiner::new();

        let negatives = miner.get_negatives(2, 5, 10);

        assert!(!negatives.contains(&2));
        assert!(!negatives.contains(&5));
        assert_eq!(negatives.len(), 8);
    }

    #[test]
    fn test_in_batch_miner_include_positive_and_default() {
        let negatives = InBatchMiner::new().include_positive().get_negatives(2, 5, 10);

        assert!(!negatives.contains(&2));
        assert!(negatives.contains(&5));
        assert_eq!(negatives.len(), 9);
        assert_eq!(
            InBatchMiner::default().get_negatives(2, 5, 10),
            InBatchMiner::new().get_negatives(2, 5, 10)
        );
    }
}