// ruvector_attention_node/training.rs

1//! NAPI-RS bindings for training utilities
2//!
3//! Provides Node.js bindings for:
4//! - Loss functions (InfoNCE, LocalContrastive, SpectralRegularization)
5//! - Optimizers (SGD, Adam, AdamW)
6//! - Learning rate schedulers
7//! - Curriculum learning
8//! - Negative mining
9
10use napi::bindgen_prelude::*;
11use napi_derive::napi;
12use ruvector_attention::training::{
13    InfoNCELoss as RustInfoNCE,
14    LocalContrastiveLoss as RustLocalContrastive,
15    SpectralRegularization as RustSpectralReg,
16    Loss,
17    SGD as RustSGD,
18    Adam as RustAdam,
19    AdamW as RustAdamW,
20    Optimizer,
21    CurriculumScheduler as RustCurriculum,
22    CurriculumStage as RustStage,
23    TemperatureAnnealing as RustTempAnnealing,
24    DecayType as RustDecayType,
25    HardNegativeMiner as RustHardMiner,
26    MiningStrategy as RustMiningStrategy,
27    NegativeMiner,
28};
29
30// ============================================================================
31// Loss Functions
32// ============================================================================
33
34/// InfoNCE contrastive loss for representation learning
35#[napi]
36pub struct InfoNCELoss {
37    inner: RustInfoNCE,
38    temperature_value: f32,
39}
40
41#[napi]
42impl InfoNCELoss {
43    /// Create a new InfoNCE loss instance
44    ///
45    /// # Arguments
46    /// * `temperature` - Temperature parameter for softmax (typically 0.07-0.1)
47    #[napi(constructor)]
48    pub fn new(temperature: f64) -> Self {
49        Self {
50            inner: RustInfoNCE::new(temperature as f32),
51            temperature_value: temperature as f32,
52        }
53    }
54
55    /// Compute InfoNCE loss
56    ///
57    /// # Arguments
58    /// * `anchor` - Anchor embedding
59    /// * `positive` - Positive example embedding
60    /// * `negatives` - Array of negative example embeddings
61    #[napi]
62    pub fn compute(&self, anchor: Float32Array, positive: Float32Array, negatives: Vec<Float32Array>) -> f64 {
63        let anchor_slice = anchor.as_ref();
64        let positive_slice = positive.as_ref();
65        let negatives_vec: Vec<Vec<f32>> = negatives.into_iter().map(|n| n.to_vec()).collect();
66        let negatives_refs: Vec<&[f32]> = negatives_vec.iter().map(|n| n.as_slice()).collect();
67
68        self.inner.compute(anchor_slice, positive_slice, &negatives_refs) as f64
69    }
70
71    /// Compute InfoNCE loss with gradients
72    ///
73    /// Returns an object with `loss` and `gradients` fields
74    #[napi]
75    pub fn compute_with_gradients(&self, anchor: Float32Array, positive: Float32Array, negatives: Vec<Float32Array>) -> LossWithGradients {
76        let anchor_slice = anchor.as_ref();
77        let positive_slice = positive.as_ref();
78        let negatives_vec: Vec<Vec<f32>> = negatives.into_iter().map(|n| n.to_vec()).collect();
79        let negatives_refs: Vec<&[f32]> = negatives_vec.iter().map(|n| n.as_slice()).collect();
80
81        let (loss, gradients) = self.inner.compute_with_gradients(anchor_slice, positive_slice, &negatives_refs);
82
83        LossWithGradients {
84            loss: loss as f64,
85            gradients: Float32Array::new(gradients),
86        }
87    }
88
89    /// Get the temperature
90    #[napi(getter)]
91    pub fn temperature(&self) -> f64 {
92        self.temperature_value as f64
93    }
94}
95
/// Loss computation result with gradients
///
/// Returned by the `compute_with_gradients` methods on the loss types and
/// exposed to JS as a plain object.
#[napi(object)]
pub struct LossWithGradients {
    /// Scalar loss value.
    pub loss: f64,
    /// Gradient vector — presumably w.r.t. the anchor embedding; confirm
    /// against `ruvector_attention::training`'s `compute_with_gradients`.
    pub gradients: Float32Array,
}
102
103/// Local contrastive loss for neighborhood preservation
104#[napi]
105pub struct LocalContrastiveLoss {
106    inner: RustLocalContrastive,
107    margin_value: f32,
108}
109
110#[napi]
111impl LocalContrastiveLoss {
112    /// Create a new local contrastive loss instance
113    ///
114    /// # Arguments
115    /// * `margin` - Margin for triplet loss
116    #[napi(constructor)]
117    pub fn new(margin: f64) -> Self {
118        Self {
119            inner: RustLocalContrastive::new(margin as f32),
120            margin_value: margin as f32,
121        }
122    }
123
124    /// Compute local contrastive loss
125    #[napi]
126    pub fn compute(&self, anchor: Float32Array, positive: Float32Array, negatives: Vec<Float32Array>) -> f64 {
127        let anchor_slice = anchor.as_ref();
128        let positive_slice = positive.as_ref();
129        let negatives_vec: Vec<Vec<f32>> = negatives.into_iter().map(|n| n.to_vec()).collect();
130        let negatives_refs: Vec<&[f32]> = negatives_vec.iter().map(|n| n.as_slice()).collect();
131
132        self.inner.compute(anchor_slice, positive_slice, &negatives_refs) as f64
133    }
134
135    /// Compute with gradients
136    #[napi]
137    pub fn compute_with_gradients(&self, anchor: Float32Array, positive: Float32Array, negatives: Vec<Float32Array>) -> LossWithGradients {
138        let anchor_slice = anchor.as_ref();
139        let positive_slice = positive.as_ref();
140        let negatives_vec: Vec<Vec<f32>> = negatives.into_iter().map(|n| n.to_vec()).collect();
141        let negatives_refs: Vec<&[f32]> = negatives_vec.iter().map(|n| n.as_slice()).collect();
142
143        let (loss, gradients) = self.inner.compute_with_gradients(anchor_slice, positive_slice, &negatives_refs);
144
145        LossWithGradients {
146            loss: loss as f64,
147            gradients: Float32Array::new(gradients),
148        }
149    }
150
151    /// Get the margin
152    #[napi(getter)]
153    pub fn margin(&self) -> f64 {
154        self.margin_value as f64
155    }
156}
157
158/// Spectral regularization for smooth representations
159#[napi]
160pub struct SpectralRegularization {
161    inner: RustSpectralReg,
162    weight_value: f32,
163}
164
165#[napi]
166impl SpectralRegularization {
167    /// Create a new spectral regularization instance
168    ///
169    /// # Arguments
170    /// * `weight` - Regularization weight
171    #[napi(constructor)]
172    pub fn new(weight: f64) -> Self {
173        Self {
174            inner: RustSpectralReg::new(weight as f32),
175            weight_value: weight as f32,
176        }
177    }
178
179    /// Compute spectral regularization for a batch of embeddings
180    #[napi]
181    pub fn compute_batch(&self, embeddings: Vec<Float32Array>) -> f64 {
182        let embeddings_vec: Vec<Vec<f32>> = embeddings.into_iter().map(|e| e.to_vec()).collect();
183        let embeddings_refs: Vec<&[f32]> = embeddings_vec.iter().map(|e| e.as_slice()).collect();
184
185        self.inner.compute_batch(&embeddings_refs) as f64
186    }
187
188    /// Get the weight
189    #[napi(getter)]
190    pub fn weight(&self) -> f64 {
191        self.weight_value as f64
192    }
193}
194
195// ============================================================================
196// Optimizers
197// ============================================================================
198
199/// SGD optimizer with optional momentum and weight decay
200#[napi]
201pub struct SGDOptimizer {
202    inner: RustSGD,
203}
204
205#[napi]
206impl SGDOptimizer {
207    /// Create a new SGD optimizer
208    ///
209    /// # Arguments
210    /// * `param_count` - Number of parameters
211    /// * `learning_rate` - Learning rate
212    #[napi(constructor)]
213    pub fn new(param_count: u32, learning_rate: f64) -> Self {
214        Self {
215            inner: RustSGD::new(param_count as usize, learning_rate as f32),
216        }
217    }
218
219    /// Create with momentum
220    #[napi(factory)]
221    pub fn with_momentum(param_count: u32, learning_rate: f64, momentum: f64) -> Self {
222        Self {
223            inner: RustSGD::new(param_count as usize, learning_rate as f32)
224                .with_momentum(momentum as f32),
225        }
226    }
227
228    /// Create with momentum and weight decay
229    #[napi(factory)]
230    pub fn with_weight_decay(param_count: u32, learning_rate: f64, momentum: f64, weight_decay: f64) -> Self {
231        Self {
232            inner: RustSGD::new(param_count as usize, learning_rate as f32)
233                .with_momentum(momentum as f32)
234                .with_weight_decay(weight_decay as f32),
235        }
236    }
237
238    /// Perform an optimization step
239    ///
240    /// # Arguments
241    /// * `params` - Parameter array
242    /// * `gradients` - Gradient array
243    ///
244    /// # Returns
245    /// Updated parameter array
246    #[napi]
247    pub fn step(&mut self, params: Float32Array, gradients: Float32Array) -> Float32Array {
248        let mut params_vec = params.to_vec();
249        let gradients_slice = gradients.as_ref();
250        self.inner.step(&mut params_vec, gradients_slice);
251        Float32Array::new(params_vec)
252    }
253
254    /// Reset optimizer state
255    #[napi]
256    pub fn reset(&mut self) {
257        self.inner.reset();
258    }
259
260    /// Get current learning rate
261    #[napi(getter)]
262    pub fn learning_rate(&self) -> f64 {
263        self.inner.learning_rate() as f64
264    }
265
266    /// Set learning rate
267    #[napi(setter)]
268    pub fn set_learning_rate(&mut self, lr: f64) {
269        self.inner.set_learning_rate(lr as f32);
270    }
271}
272
273/// Adam optimizer with bias correction
274#[napi]
275pub struct AdamOptimizer {
276    inner: RustAdam,
277}
278
279#[napi]
280impl AdamOptimizer {
281    /// Create a new Adam optimizer
282    ///
283    /// # Arguments
284    /// * `param_count` - Number of parameters
285    /// * `learning_rate` - Learning rate
286    #[napi(constructor)]
287    pub fn new(param_count: u32, learning_rate: f64) -> Self {
288        Self {
289            inner: RustAdam::new(param_count as usize, learning_rate as f32),
290        }
291    }
292
293    /// Create with custom betas
294    #[napi(factory)]
295    pub fn with_betas(param_count: u32, learning_rate: f64, beta1: f64, beta2: f64) -> Self {
296        Self {
297            inner: RustAdam::new(param_count as usize, learning_rate as f32)
298                .with_betas(beta1 as f32, beta2 as f32),
299        }
300    }
301
302    /// Create with full configuration
303    #[napi(factory)]
304    pub fn with_config(param_count: u32, learning_rate: f64, beta1: f64, beta2: f64, epsilon: f64, weight_decay: f64) -> Self {
305        Self {
306            inner: RustAdam::new(param_count as usize, learning_rate as f32)
307                .with_betas(beta1 as f32, beta2 as f32)
308                .with_epsilon(epsilon as f32)
309                .with_weight_decay(weight_decay as f32),
310        }
311    }
312
313    /// Perform an optimization step
314    ///
315    /// # Returns
316    /// Updated parameter array
317    #[napi]
318    pub fn step(&mut self, params: Float32Array, gradients: Float32Array) -> Float32Array {
319        let mut params_vec = params.to_vec();
320        let gradients_slice = gradients.as_ref();
321        self.inner.step(&mut params_vec, gradients_slice);
322        Float32Array::new(params_vec)
323    }
324
325    /// Reset optimizer state (momentum terms)
326    #[napi]
327    pub fn reset(&mut self) {
328        self.inner.reset();
329    }
330
331    /// Get current learning rate
332    #[napi(getter)]
333    pub fn learning_rate(&self) -> f64 {
334        self.inner.learning_rate() as f64
335    }
336
337    /// Set learning rate
338    #[napi(setter)]
339    pub fn set_learning_rate(&mut self, lr: f64) {
340        self.inner.set_learning_rate(lr as f32);
341    }
342}
343
344/// AdamW optimizer (Adam with decoupled weight decay)
345#[napi]
346pub struct AdamWOptimizer {
347    inner: RustAdamW,
348    wd: f32,
349}
350
351#[napi]
352impl AdamWOptimizer {
353    /// Create a new AdamW optimizer
354    ///
355    /// # Arguments
356    /// * `param_count` - Number of parameters
357    /// * `learning_rate` - Learning rate
358    /// * `weight_decay` - Weight decay coefficient
359    #[napi(constructor)]
360    pub fn new(param_count: u32, learning_rate: f64, weight_decay: f64) -> Self {
361        Self {
362            inner: RustAdamW::new(param_count as usize, learning_rate as f32)
363                .with_weight_decay(weight_decay as f32),
364            wd: weight_decay as f32,
365        }
366    }
367
368    /// Create with custom betas
369    #[napi(factory)]
370    pub fn with_betas(param_count: u32, learning_rate: f64, weight_decay: f64, beta1: f64, beta2: f64) -> Self {
371        Self {
372            inner: RustAdamW::new(param_count as usize, learning_rate as f32)
373                .with_weight_decay(weight_decay as f32)
374                .with_betas(beta1 as f32, beta2 as f32),
375            wd: weight_decay as f32,
376        }
377    }
378
379    /// Perform an optimization step
380    ///
381    /// # Returns
382    /// Updated parameter array
383    #[napi]
384    pub fn step(&mut self, params: Float32Array, gradients: Float32Array) -> Float32Array {
385        let mut params_vec = params.to_vec();
386        let gradients_slice = gradients.as_ref();
387        self.inner.step(&mut params_vec, gradients_slice);
388        Float32Array::new(params_vec)
389    }
390
391    /// Reset optimizer state
392    #[napi]
393    pub fn reset(&mut self) {
394        self.inner.reset();
395    }
396
397    /// Get current learning rate
398    #[napi(getter)]
399    pub fn learning_rate(&self) -> f64 {
400        self.inner.learning_rate() as f64
401    }
402
403    /// Set learning rate
404    #[napi(setter)]
405    pub fn set_learning_rate(&mut self, lr: f64) {
406        self.inner.set_learning_rate(lr as f32);
407    }
408
409    /// Get weight decay
410    #[napi(getter)]
411    pub fn weight_decay(&self) -> f64 {
412        self.wd as f64
413    }
414}
415
416// ============================================================================
417// Learning Rate Scheduling
418// ============================================================================
419
420/// Learning rate scheduler with warmup and cosine decay
421#[napi]
422pub struct LearningRateScheduler {
423    initial_lr: f32,
424    current_step: usize,
425    warmup_steps: usize,
426    total_steps: usize,
427    min_lr: f32,
428}
429
430#[napi]
431impl LearningRateScheduler {
432    /// Create a new learning rate scheduler
433    ///
434    /// # Arguments
435    /// * `initial_lr` - Initial/peak learning rate
436    /// * `warmup_steps` - Number of warmup steps
437    /// * `total_steps` - Total training steps
438    #[napi(constructor)]
439    pub fn new(initial_lr: f64, warmup_steps: u32, total_steps: u32) -> Self {
440        Self {
441            initial_lr: initial_lr as f32,
442            current_step: 0,
443            warmup_steps: warmup_steps as usize,
444            total_steps: total_steps as usize,
445            min_lr: 1e-7,
446        }
447    }
448
449    /// Create with minimum learning rate
450    #[napi(factory)]
451    pub fn with_min_lr(initial_lr: f64, warmup_steps: u32, total_steps: u32, min_lr: f64) -> Self {
452        Self {
453            initial_lr: initial_lr as f32,
454            current_step: 0,
455            warmup_steps: warmup_steps as usize,
456            total_steps: total_steps as usize,
457            min_lr: min_lr as f32,
458        }
459    }
460
461    /// Get learning rate for current step
462    #[napi]
463    pub fn get_lr(&self) -> f64 {
464        if self.current_step < self.warmup_steps {
465            // Linear warmup
466            (self.initial_lr * (self.current_step + 1) as f32 / self.warmup_steps as f32) as f64
467        } else {
468            // Cosine decay
469            let progress = (self.current_step - self.warmup_steps) as f32
470                / (self.total_steps - self.warmup_steps).max(1) as f32;
471            let decay = 0.5 * (1.0 + (std::f32::consts::PI * progress.min(1.0)).cos());
472            (self.min_lr + (self.initial_lr - self.min_lr) * decay) as f64
473        }
474    }
475
476    /// Step the scheduler and return current learning rate
477    #[napi]
478    pub fn step(&mut self) -> f64 {
479        let lr = self.get_lr();
480        self.current_step += 1;
481        lr
482    }
483
484    /// Reset scheduler to initial state
485    #[napi]
486    pub fn reset(&mut self) {
487        self.current_step = 0;
488    }
489
490    /// Get current step
491    #[napi(getter)]
492    pub fn current_step(&self) -> u32 {
493        self.current_step as u32
494    }
495
496    /// Get progress (0.0 to 1.0)
497    #[napi(getter)]
498    pub fn progress(&self) -> f64 {
499        (self.current_step as f64 / self.total_steps.max(1) as f64).min(1.0)
500    }
501}
502
503// ============================================================================
504// Temperature Annealing
505// ============================================================================
506
507/// Decay type for temperature annealing
508#[napi(string_enum)]
509pub enum DecayType {
510    Linear,
511    Exponential,
512    Cosine,
513    Step,
514}
515
516impl From<DecayType> for RustDecayType {
517    fn from(dt: DecayType) -> Self {
518        match dt {
519            DecayType::Linear => RustDecayType::Linear,
520            DecayType::Exponential => RustDecayType::Exponential,
521            DecayType::Cosine => RustDecayType::Cosine,
522            DecayType::Step => RustDecayType::Step,
523        }
524    }
525}
526
527/// Temperature annealing scheduler
528#[napi]
529pub struct TemperatureAnnealing {
530    inner: RustTempAnnealing,
531}
532
533#[napi]
534impl TemperatureAnnealing {
535    /// Create a new temperature annealing scheduler
536    ///
537    /// # Arguments
538    /// * `initial_temp` - Starting temperature
539    /// * `final_temp` - Final temperature
540    /// * `steps` - Number of annealing steps
541    #[napi(constructor)]
542    pub fn new(initial_temp: f64, final_temp: f64, steps: u32) -> Self {
543        Self {
544            inner: RustTempAnnealing::new(
545                initial_temp as f32,
546                final_temp as f32,
547                steps as usize,
548            ),
549        }
550    }
551
552    /// Create with specific decay type
553    #[napi(factory)]
554    pub fn with_decay(initial_temp: f64, final_temp: f64, steps: u32, decay_type: DecayType) -> Self {
555        Self {
556            inner: RustTempAnnealing::new(
557                initial_temp as f32,
558                final_temp as f32,
559                steps as usize,
560            ).with_decay(decay_type.into()),
561        }
562    }
563
564    /// Get current temperature
565    #[napi]
566    pub fn get_temp(&self) -> f64 {
567        self.inner.get_temp() as f64
568    }
569
570    /// Step the scheduler and return current temperature
571    #[napi]
572    pub fn step(&mut self) -> f64 {
573        self.inner.step() as f64
574    }
575
576    /// Reset scheduler
577    #[napi]
578    pub fn reset(&mut self) {
579        self.inner.reset();
580    }
581}
582
583// ============================================================================
584// Curriculum Learning
585// ============================================================================
586
/// Curriculum stage configuration
///
/// Plain JS object mirroring `ruvector_attention::training::CurriculumStage`;
/// used both to add stages and as the return value of `CurriculumScheduler::step`.
#[napi(object)]
pub struct CurriculumStageConfig {
    /// Human-readable stage name.
    pub name: String,
    /// Stage difficulty (0.0 to 1.0, per `CurriculumScheduler::difficulty`).
    pub difficulty: f64,
    /// Stage duration in training steps.
    pub duration: u32,
    /// Softmax temperature used during this stage.
    pub temperature: f64,
    /// Number of negatives to sample during this stage.
    pub negative_count: u32,
}
596
597/// Curriculum scheduler for progressive training
598#[napi]
599pub struct CurriculumScheduler {
600    inner: RustCurriculum,
601}
602
603#[napi]
604impl CurriculumScheduler {
605    /// Create an empty curriculum scheduler
606    #[napi(constructor)]
607    pub fn new() -> Self {
608        Self {
609            inner: RustCurriculum::new(),
610        }
611    }
612
613    /// Create a default easy-to-hard curriculum
614    #[napi(factory)]
615    pub fn default_curriculum(total_steps: u32) -> Self {
616        Self {
617            inner: RustCurriculum::default_curriculum(total_steps as usize),
618        }
619    }
620
621    /// Add a stage to the curriculum
622    #[napi]
623    pub fn add_stage(&mut self, config: CurriculumStageConfig) {
624        let stage = RustStage::new(&config.name)
625            .difficulty(config.difficulty as f32)
626            .duration(config.duration as usize)
627            .temperature(config.temperature as f32)
628            .negative_count(config.negative_count as usize);
629
630        // Rebuild with added stage
631        let new_inner = std::mem::take(&mut self.inner).add_stage(stage);
632        self.inner = new_inner;
633    }
634
635    /// Step the curriculum and return current stage info
636    #[napi]
637    pub fn step(&mut self) -> Option<CurriculumStageConfig> {
638        self.inner.step().map(|s| CurriculumStageConfig {
639            name: s.name.clone(),
640            difficulty: s.difficulty as f64,
641            duration: s.duration as u32,
642            temperature: s.temperature as f64,
643            negative_count: s.negative_count as u32,
644        })
645    }
646
647    /// Get current difficulty (0.0 to 1.0)
648    #[napi(getter)]
649    pub fn difficulty(&self) -> f64 {
650        self.inner.difficulty() as f64
651    }
652
653    /// Get current temperature
654    #[napi(getter)]
655    pub fn temperature(&self) -> f64 {
656        self.inner.temperature() as f64
657    }
658
659    /// Get current negative count
660    #[napi(getter)]
661    pub fn negative_count(&self) -> u32 {
662        self.inner.negative_count() as u32
663    }
664
665    /// Check if curriculum is complete
666    #[napi(getter)]
667    pub fn is_complete(&self) -> bool {
668        self.inner.is_complete()
669    }
670
671    /// Get overall progress (0.0 to 1.0)
672    #[napi(getter)]
673    pub fn progress(&self) -> f64 {
674        self.inner.progress() as f64
675    }
676
677    /// Reset curriculum
678    #[napi]
679    pub fn reset(&mut self) {
680        self.inner.reset();
681    }
682}
683
684// ============================================================================
685// Negative Mining
686// ============================================================================
687
688/// Mining strategy for negative selection
689#[napi(string_enum)]
690pub enum MiningStrategy {
691    Random,
692    HardNegative,
693    SemiHard,
694    DistanceWeighted,
695}
696
697impl From<MiningStrategy> for RustMiningStrategy {
698    fn from(ms: MiningStrategy) -> Self {
699        match ms {
700            MiningStrategy::Random => RustMiningStrategy::Random,
701            MiningStrategy::HardNegative => RustMiningStrategy::HardNegative,
702            MiningStrategy::SemiHard => RustMiningStrategy::SemiHard,
703            MiningStrategy::DistanceWeighted => RustMiningStrategy::DistanceWeighted,
704        }
705    }
706}
707
708/// Hard negative miner for selecting informative negatives
709#[napi]
710pub struct HardNegativeMiner {
711    inner: RustHardMiner,
712}
713
714#[napi]
715impl HardNegativeMiner {
716    /// Create a new hard negative miner
717    ///
718    /// # Arguments
719    /// * `strategy` - Mining strategy to use
720    #[napi(constructor)]
721    pub fn new(strategy: MiningStrategy) -> Self {
722        Self {
723            inner: RustHardMiner::new(strategy.into()),
724        }
725    }
726
727    /// Create with margin (for semi-hard mining)
728    #[napi(factory)]
729    pub fn with_margin(strategy: MiningStrategy, margin: f64) -> Self {
730        Self {
731            inner: RustHardMiner::new(strategy.into())
732                .with_margin(margin as f32),
733        }
734    }
735
736    /// Create with temperature (for distance-weighted mining)
737    #[napi(factory)]
738    pub fn with_temperature(strategy: MiningStrategy, temperature: f64) -> Self {
739        Self {
740            inner: RustHardMiner::new(strategy.into())
741                .with_temperature(temperature as f32),
742        }
743    }
744
745    /// Mine negative indices from candidates
746    ///
747    /// # Arguments
748    /// * `anchor` - Anchor embedding
749    /// * `positive` - Positive example embedding
750    /// * `candidates` - Array of candidate embeddings
751    /// * `num_negatives` - Number of negatives to select
752    ///
753    /// # Returns
754    /// Array of indices into the candidates array
755    #[napi]
756    pub fn mine(
757        &self,
758        anchor: Float32Array,
759        positive: Float32Array,
760        candidates: Vec<Float32Array>,
761        num_negatives: u32,
762    ) -> Vec<u32> {
763        let anchor_slice = anchor.as_ref();
764        let positive_slice = positive.as_ref();
765        let candidates_vec: Vec<Vec<f32>> = candidates.into_iter().map(|c| c.to_vec()).collect();
766        let candidates_refs: Vec<&[f32]> = candidates_vec.iter().map(|c| c.as_slice()).collect();
767
768        self.inner
769            .mine(anchor_slice, positive_slice, &candidates_refs, num_negatives as usize)
770            .into_iter()
771            .map(|i| i as u32)
772            .collect()
773    }
774}
775
776/// In-batch negative mining utility
777#[napi]
778pub struct InBatchMiner {
779    exclude_positive: bool,
780}
781
782#[napi]
783impl InBatchMiner {
784    /// Create a new in-batch miner
785    #[napi(constructor)]
786    pub fn new() -> Self {
787        Self {
788            exclude_positive: true,
789        }
790    }
791
792    /// Create without excluding positive
793    #[napi(factory)]
794    pub fn include_positive() -> Self {
795        Self {
796            exclude_positive: false,
797        }
798    }
799
800    /// Get negative indices for a given anchor in a batch
801    ///
802    /// # Arguments
803    /// * `anchor_idx` - Index of the anchor in the batch
804    /// * `positive_idx` - Index of the positive in the batch
805    /// * `batch_size` - Total batch size
806    ///
807    /// # Returns
808    /// Array of indices that can be used as negatives
809    #[napi]
810    pub fn get_negatives(&self, anchor_idx: u32, positive_idx: u32, batch_size: u32) -> Vec<u32> {
811        (0..batch_size)
812            .filter(|&i| {
813                i != anchor_idx && (!self.exclude_positive || i != positive_idx)
814            })
815            .collect()
816    }
817}