//! NAPI-RS bindings for training utilities
//!
//! Provides Node.js bindings for:
//! - Loss functions (InfoNCE, LocalContrastive, SpectralRegularization)
//! - Optimizers (SGD, Adam, AdamW)
//! - Learning rate schedulers
//! - Curriculum learning
//! - Negative mining

use napi::bindgen_prelude::*;
use napi_derive::napi;
use ruvector_attention::training::{
    Adam as RustAdam, AdamW as RustAdamW, CurriculumScheduler as RustCurriculum,
    CurriculumStage as RustStage, DecayType as RustDecayType, HardNegativeMiner as RustHardMiner,
    InfoNCELoss as RustInfoNCE, LocalContrastiveLoss as RustLocalContrastive, Loss,
    MiningStrategy as RustMiningStrategy, NegativeMiner, Optimizer,
    SpectralRegularization as RustSpectralReg, TemperatureAnnealing as RustTempAnnealing,
    SGD as RustSGD,
};

// ============================================================================
// Loss Functions
// ============================================================================

25/// InfoNCE contrastive loss for representation learning
26#[napi]
27pub struct InfoNCELoss {
28    inner: RustInfoNCE,
29    temperature_value: f32,
30}
31
32#[napi]
33impl InfoNCELoss {
34    /// Create a new InfoNCE loss instance
35    ///
36    /// # Arguments
37    /// * `temperature` - Temperature parameter for softmax (typically 0.07-0.1)
38    #[napi(constructor)]
39    pub fn new(temperature: f64) -> Self {
40        Self {
41            inner: RustInfoNCE::new(temperature as f32),
42            temperature_value: temperature as f32,
43        }
44    }
45
46    /// Compute InfoNCE loss
47    ///
48    /// # Arguments
49    /// * `anchor` - Anchor embedding
50    /// * `positive` - Positive example embedding
51    /// * `negatives` - Array of negative example embeddings
52    #[napi]
53    pub fn compute(
54        &self,
55        anchor: Float32Array,
56        positive: Float32Array,
57        negatives: Vec<Float32Array>,
58    ) -> f64 {
59        let anchor_slice = anchor.as_ref();
60        let positive_slice = positive.as_ref();
61        let negatives_vec: Vec<Vec<f32>> = negatives.into_iter().map(|n| n.to_vec()).collect();
62        let negatives_refs: Vec<&[f32]> = negatives_vec.iter().map(|n| n.as_slice()).collect();
63
64        self.inner
65            .compute(anchor_slice, positive_slice, &negatives_refs) as f64
66    }
67
68    /// Compute InfoNCE loss with gradients
69    ///
70    /// Returns an object with `loss` and `gradients` fields
71    #[napi]
72    pub fn compute_with_gradients(
73        &self,
74        anchor: Float32Array,
75        positive: Float32Array,
76        negatives: Vec<Float32Array>,
77    ) -> LossWithGradients {
78        let anchor_slice = anchor.as_ref();
79        let positive_slice = positive.as_ref();
80        let negatives_vec: Vec<Vec<f32>> = negatives.into_iter().map(|n| n.to_vec()).collect();
81        let negatives_refs: Vec<&[f32]> = negatives_vec.iter().map(|n| n.as_slice()).collect();
82
83        let (loss, gradients) =
84            self.inner
85                .compute_with_gradients(anchor_slice, positive_slice, &negatives_refs);
86
87        LossWithGradients {
88            loss: loss as f64,
89            gradients: Float32Array::new(gradients),
90        }
91    }
92
93    /// Get the temperature
94    #[napi(getter)]
95    pub fn temperature(&self) -> f64 {
96        self.temperature_value as f64
97    }
98}
99
/// Loss computation result with gradients
#[napi(object)]
pub struct LossWithGradients {
    /// Scalar loss value.
    pub loss: f64,
    /// Gradient values returned by the loss implementation; layout matches
    /// the core crate's `compute_with_gradients` output.
    pub gradients: Float32Array,
}
107/// Local contrastive loss for neighborhood preservation
108#[napi]
109pub struct LocalContrastiveLoss {
110    inner: RustLocalContrastive,
111    margin_value: f32,
112}
113
114#[napi]
115impl LocalContrastiveLoss {
116    /// Create a new local contrastive loss instance
117    ///
118    /// # Arguments
119    /// * `margin` - Margin for triplet loss
120    #[napi(constructor)]
121    pub fn new(margin: f64) -> Self {
122        Self {
123            inner: RustLocalContrastive::new(margin as f32),
124            margin_value: margin as f32,
125        }
126    }
127
128    /// Compute local contrastive loss
129    #[napi]
130    pub fn compute(
131        &self,
132        anchor: Float32Array,
133        positive: Float32Array,
134        negatives: Vec<Float32Array>,
135    ) -> f64 {
136        let anchor_slice = anchor.as_ref();
137        let positive_slice = positive.as_ref();
138        let negatives_vec: Vec<Vec<f32>> = negatives.into_iter().map(|n| n.to_vec()).collect();
139        let negatives_refs: Vec<&[f32]> = negatives_vec.iter().map(|n| n.as_slice()).collect();
140
141        self.inner
142            .compute(anchor_slice, positive_slice, &negatives_refs) as f64
143    }
144
145    /// Compute with gradients
146    #[napi]
147    pub fn compute_with_gradients(
148        &self,
149        anchor: Float32Array,
150        positive: Float32Array,
151        negatives: Vec<Float32Array>,
152    ) -> LossWithGradients {
153        let anchor_slice = anchor.as_ref();
154        let positive_slice = positive.as_ref();
155        let negatives_vec: Vec<Vec<f32>> = negatives.into_iter().map(|n| n.to_vec()).collect();
156        let negatives_refs: Vec<&[f32]> = negatives_vec.iter().map(|n| n.as_slice()).collect();
157
158        let (loss, gradients) =
159            self.inner
160                .compute_with_gradients(anchor_slice, positive_slice, &negatives_refs);
161
162        LossWithGradients {
163            loss: loss as f64,
164            gradients: Float32Array::new(gradients),
165        }
166    }
167
168    /// Get the margin
169    #[napi(getter)]
170    pub fn margin(&self) -> f64 {
171        self.margin_value as f64
172    }
173}
174
175/// Spectral regularization for smooth representations
176#[napi]
177pub struct SpectralRegularization {
178    inner: RustSpectralReg,
179    weight_value: f32,
180}
181
182#[napi]
183impl SpectralRegularization {
184    /// Create a new spectral regularization instance
185    ///
186    /// # Arguments
187    /// * `weight` - Regularization weight
188    #[napi(constructor)]
189    pub fn new(weight: f64) -> Self {
190        Self {
191            inner: RustSpectralReg::new(weight as f32),
192            weight_value: weight as f32,
193        }
194    }
195
196    /// Compute spectral regularization for a batch of embeddings
197    #[napi]
198    pub fn compute_batch(&self, embeddings: Vec<Float32Array>) -> f64 {
199        let embeddings_vec: Vec<Vec<f32>> = embeddings.into_iter().map(|e| e.to_vec()).collect();
200        let embeddings_refs: Vec<&[f32]> = embeddings_vec.iter().map(|e| e.as_slice()).collect();
201
202        self.inner.compute_batch(&embeddings_refs) as f64
203    }
204
205    /// Get the weight
206    #[napi(getter)]
207    pub fn weight(&self) -> f64 {
208        self.weight_value as f64
209    }
210}
211
// ============================================================================
// Optimizers
// ============================================================================

216/// SGD optimizer with optional momentum and weight decay
217#[napi]
218pub struct SGDOptimizer {
219    inner: RustSGD,
220}
221
222#[napi]
223impl SGDOptimizer {
224    /// Create a new SGD optimizer
225    ///
226    /// # Arguments
227    /// * `param_count` - Number of parameters
228    /// * `learning_rate` - Learning rate
229    #[napi(constructor)]
230    pub fn new(param_count: u32, learning_rate: f64) -> Self {
231        Self {
232            inner: RustSGD::new(param_count as usize, learning_rate as f32),
233        }
234    }
235
236    /// Create with momentum
237    #[napi(factory)]
238    pub fn with_momentum(param_count: u32, learning_rate: f64, momentum: f64) -> Self {
239        Self {
240            inner: RustSGD::new(param_count as usize, learning_rate as f32)
241                .with_momentum(momentum as f32),
242        }
243    }
244
245    /// Create with momentum and weight decay
246    #[napi(factory)]
247    pub fn with_weight_decay(
248        param_count: u32,
249        learning_rate: f64,
250        momentum: f64,
251        weight_decay: f64,
252    ) -> Self {
253        Self {
254            inner: RustSGD::new(param_count as usize, learning_rate as f32)
255                .with_momentum(momentum as f32)
256                .with_weight_decay(weight_decay as f32),
257        }
258    }
259
260    /// Perform an optimization step
261    ///
262    /// # Arguments
263    /// * `params` - Parameter array
264    /// * `gradients` - Gradient array
265    ///
266    /// # Returns
267    /// Updated parameter array
268    #[napi]
269    pub fn step(&mut self, params: Float32Array, gradients: Float32Array) -> Float32Array {
270        let mut params_vec = params.to_vec();
271        let gradients_slice = gradients.as_ref();
272        self.inner.step(&mut params_vec, gradients_slice);
273        Float32Array::new(params_vec)
274    }
275
276    /// Reset optimizer state
277    #[napi]
278    pub fn reset(&mut self) {
279        self.inner.reset();
280    }
281
282    /// Get current learning rate
283    #[napi(getter)]
284    pub fn learning_rate(&self) -> f64 {
285        self.inner.learning_rate() as f64
286    }
287
288    /// Set learning rate
289    #[napi(setter)]
290    pub fn set_learning_rate(&mut self, lr: f64) {
291        self.inner.set_learning_rate(lr as f32);
292    }
293}
294
295/// Adam optimizer with bias correction
296#[napi]
297pub struct AdamOptimizer {
298    inner: RustAdam,
299}
300
301#[napi]
302impl AdamOptimizer {
303    /// Create a new Adam optimizer
304    ///
305    /// # Arguments
306    /// * `param_count` - Number of parameters
307    /// * `learning_rate` - Learning rate
308    #[napi(constructor)]
309    pub fn new(param_count: u32, learning_rate: f64) -> Self {
310        Self {
311            inner: RustAdam::new(param_count as usize, learning_rate as f32),
312        }
313    }
314
315    /// Create with custom betas
316    #[napi(factory)]
317    pub fn with_betas(param_count: u32, learning_rate: f64, beta1: f64, beta2: f64) -> Self {
318        Self {
319            inner: RustAdam::new(param_count as usize, learning_rate as f32)
320                .with_betas(beta1 as f32, beta2 as f32),
321        }
322    }
323
324    /// Create with full configuration
325    #[napi(factory)]
326    pub fn with_config(
327        param_count: u32,
328        learning_rate: f64,
329        beta1: f64,
330        beta2: f64,
331        epsilon: f64,
332        weight_decay: f64,
333    ) -> Self {
334        Self {
335            inner: RustAdam::new(param_count as usize, learning_rate as f32)
336                .with_betas(beta1 as f32, beta2 as f32)
337                .with_epsilon(epsilon as f32)
338                .with_weight_decay(weight_decay as f32),
339        }
340    }
341
342    /// Perform an optimization step
343    ///
344    /// # Returns
345    /// Updated parameter array
346    #[napi]
347    pub fn step(&mut self, params: Float32Array, gradients: Float32Array) -> Float32Array {
348        let mut params_vec = params.to_vec();
349        let gradients_slice = gradients.as_ref();
350        self.inner.step(&mut params_vec, gradients_slice);
351        Float32Array::new(params_vec)
352    }
353
354    /// Reset optimizer state (momentum terms)
355    #[napi]
356    pub fn reset(&mut self) {
357        self.inner.reset();
358    }
359
360    /// Get current learning rate
361    #[napi(getter)]
362    pub fn learning_rate(&self) -> f64 {
363        self.inner.learning_rate() as f64
364    }
365
366    /// Set learning rate
367    #[napi(setter)]
368    pub fn set_learning_rate(&mut self, lr: f64) {
369        self.inner.set_learning_rate(lr as f32);
370    }
371}
372
373/// AdamW optimizer (Adam with decoupled weight decay)
374#[napi]
375pub struct AdamWOptimizer {
376    inner: RustAdamW,
377    wd: f32,
378}
379
380#[napi]
381impl AdamWOptimizer {
382    /// Create a new AdamW optimizer
383    ///
384    /// # Arguments
385    /// * `param_count` - Number of parameters
386    /// * `learning_rate` - Learning rate
387    /// * `weight_decay` - Weight decay coefficient
388    #[napi(constructor)]
389    pub fn new(param_count: u32, learning_rate: f64, weight_decay: f64) -> Self {
390        Self {
391            inner: RustAdamW::new(param_count as usize, learning_rate as f32)
392                .with_weight_decay(weight_decay as f32),
393            wd: weight_decay as f32,
394        }
395    }
396
397    /// Create with custom betas
398    #[napi(factory)]
399    pub fn with_betas(
400        param_count: u32,
401        learning_rate: f64,
402        weight_decay: f64,
403        beta1: f64,
404        beta2: f64,
405    ) -> Self {
406        Self {
407            inner: RustAdamW::new(param_count as usize, learning_rate as f32)
408                .with_weight_decay(weight_decay as f32)
409                .with_betas(beta1 as f32, beta2 as f32),
410            wd: weight_decay as f32,
411        }
412    }
413
414    /// Perform an optimization step
415    ///
416    /// # Returns
417    /// Updated parameter array
418    #[napi]
419    pub fn step(&mut self, params: Float32Array, gradients: Float32Array) -> Float32Array {
420        let mut params_vec = params.to_vec();
421        let gradients_slice = gradients.as_ref();
422        self.inner.step(&mut params_vec, gradients_slice);
423        Float32Array::new(params_vec)
424    }
425
426    /// Reset optimizer state
427    #[napi]
428    pub fn reset(&mut self) {
429        self.inner.reset();
430    }
431
432    /// Get current learning rate
433    #[napi(getter)]
434    pub fn learning_rate(&self) -> f64 {
435        self.inner.learning_rate() as f64
436    }
437
438    /// Set learning rate
439    #[napi(setter)]
440    pub fn set_learning_rate(&mut self, lr: f64) {
441        self.inner.set_learning_rate(lr as f32);
442    }
443
444    /// Get weight decay
445    #[napi(getter)]
446    pub fn weight_decay(&self) -> f64 {
447        self.wd as f64
448    }
449}
450
// ============================================================================
// Learning Rate Scheduling
// ============================================================================

455/// Learning rate scheduler with warmup and cosine decay
456#[napi]
457pub struct LearningRateScheduler {
458    initial_lr: f32,
459    current_step: usize,
460    warmup_steps: usize,
461    total_steps: usize,
462    min_lr: f32,
463}
464
465#[napi]
466impl LearningRateScheduler {
467    /// Create a new learning rate scheduler
468    ///
469    /// # Arguments
470    /// * `initial_lr` - Initial/peak learning rate
471    /// * `warmup_steps` - Number of warmup steps
472    /// * `total_steps` - Total training steps
473    #[napi(constructor)]
474    pub fn new(initial_lr: f64, warmup_steps: u32, total_steps: u32) -> Self {
475        Self {
476            initial_lr: initial_lr as f32,
477            current_step: 0,
478            warmup_steps: warmup_steps as usize,
479            total_steps: total_steps as usize,
480            min_lr: 1e-7,
481        }
482    }
483
484    /// Create with minimum learning rate
485    #[napi(factory)]
486    pub fn with_min_lr(initial_lr: f64, warmup_steps: u32, total_steps: u32, min_lr: f64) -> Self {
487        Self {
488            initial_lr: initial_lr as f32,
489            current_step: 0,
490            warmup_steps: warmup_steps as usize,
491            total_steps: total_steps as usize,
492            min_lr: min_lr as f32,
493        }
494    }
495
496    /// Get learning rate for current step
497    #[napi]
498    pub fn get_lr(&self) -> f64 {
499        if self.current_step < self.warmup_steps {
500            // Linear warmup
501            (self.initial_lr * (self.current_step + 1) as f32 / self.warmup_steps as f32) as f64
502        } else {
503            // Cosine decay
504            let progress = (self.current_step - self.warmup_steps) as f32
505                / (self.total_steps - self.warmup_steps).max(1) as f32;
506            let decay = 0.5 * (1.0 + (std::f32::consts::PI * progress.min(1.0)).cos());
507            (self.min_lr + (self.initial_lr - self.min_lr) * decay) as f64
508        }
509    }
510
511    /// Step the scheduler and return current learning rate
512    #[napi]
513    pub fn step(&mut self) -> f64 {
514        let lr = self.get_lr();
515        self.current_step += 1;
516        lr
517    }
518
519    /// Reset scheduler to initial state
520    #[napi]
521    pub fn reset(&mut self) {
522        self.current_step = 0;
523    }
524
525    /// Get current step
526    #[napi(getter)]
527    pub fn current_step(&self) -> u32 {
528        self.current_step as u32
529    }
530
531    /// Get progress (0.0 to 1.0)
532    #[napi(getter)]
533    pub fn progress(&self) -> f64 {
534        (self.current_step as f64 / self.total_steps.max(1) as f64).min(1.0)
535    }
536}
537
// ============================================================================
// Temperature Annealing
// ============================================================================

/// Decay type for temperature annealing
#[napi(string_enum)]
pub enum DecayType {
    // Variants mirror `ruvector_attention::training::DecayType` one-to-one;
    // see the core crate for the exact decay-curve semantics.
    Linear,
    Exponential,
    Cosine,
    Step,
}

// Map the N-API string enum onto the core crate's decay enum.
impl From<DecayType> for RustDecayType {
    fn from(dt: DecayType) -> Self {
        match dt {
            DecayType::Linear => RustDecayType::Linear,
            DecayType::Exponential => RustDecayType::Exponential,
            DecayType::Cosine => RustDecayType::Cosine,
            DecayType::Step => RustDecayType::Step,
        }
    }
}
562/// Temperature annealing scheduler
563#[napi]
564pub struct TemperatureAnnealing {
565    inner: RustTempAnnealing,
566}
567
568#[napi]
569impl TemperatureAnnealing {
570    /// Create a new temperature annealing scheduler
571    ///
572    /// # Arguments
573    /// * `initial_temp` - Starting temperature
574    /// * `final_temp` - Final temperature
575    /// * `steps` - Number of annealing steps
576    #[napi(constructor)]
577    pub fn new(initial_temp: f64, final_temp: f64, steps: u32) -> Self {
578        Self {
579            inner: RustTempAnnealing::new(initial_temp as f32, final_temp as f32, steps as usize),
580        }
581    }
582
583    /// Create with specific decay type
584    #[napi(factory)]
585    pub fn with_decay(
586        initial_temp: f64,
587        final_temp: f64,
588        steps: u32,
589        decay_type: DecayType,
590    ) -> Self {
591        Self {
592            inner: RustTempAnnealing::new(initial_temp as f32, final_temp as f32, steps as usize)
593                .with_decay(decay_type.into()),
594        }
595    }
596
597    /// Get current temperature
598    #[napi]
599    pub fn get_temp(&self) -> f64 {
600        self.inner.get_temp() as f64
601    }
602
603    /// Step the scheduler and return current temperature
604    #[napi]
605    pub fn step(&mut self) -> f64 {
606        self.inner.step() as f64
607    }
608
609    /// Reset scheduler
610    #[napi]
611    pub fn reset(&mut self) {
612        self.inner.reset();
613    }
614}
615
// ============================================================================
// Curriculum Learning
// ============================================================================

/// Curriculum stage configuration
#[napi(object)]
pub struct CurriculumStageConfig {
    /// Stage name (passed through to the core `CurriculumStage`).
    pub name: String,
    /// Stage difficulty; the scheduler exposes difficulty as 0.0 to 1.0.
    pub difficulty: f64,
    /// Stage duration, in training steps.
    pub duration: u32,
    /// Softmax temperature used during this stage.
    pub temperature: f64,
    /// Number of negatives to use during this stage.
    pub negative_count: u32,
}
630/// Curriculum scheduler for progressive training
631#[napi]
632pub struct CurriculumScheduler {
633    inner: RustCurriculum,
634}
635
636#[napi]
637impl CurriculumScheduler {
638    /// Create an empty curriculum scheduler
639    #[napi(constructor)]
640    pub fn new() -> Self {
641        Self {
642            inner: RustCurriculum::new(),
643        }
644    }
645
646    /// Create a default easy-to-hard curriculum
647    #[napi(factory)]
648    pub fn default_curriculum(total_steps: u32) -> Self {
649        Self {
650            inner: RustCurriculum::default_curriculum(total_steps as usize),
651        }
652    }
653
654    /// Add a stage to the curriculum
655    #[napi]
656    pub fn add_stage(&mut self, config: CurriculumStageConfig) {
657        let stage = RustStage::new(&config.name)
658            .difficulty(config.difficulty as f32)
659            .duration(config.duration as usize)
660            .temperature(config.temperature as f32)
661            .negative_count(config.negative_count as usize);
662
663        // Rebuild with added stage
664        let new_inner = std::mem::take(&mut self.inner).add_stage(stage);
665        self.inner = new_inner;
666    }
667
668    /// Step the curriculum and return current stage info
669    #[napi]
670    pub fn step(&mut self) -> Option<CurriculumStageConfig> {
671        self.inner.step().map(|s| CurriculumStageConfig {
672            name: s.name.clone(),
673            difficulty: s.difficulty as f64,
674            duration: s.duration as u32,
675            temperature: s.temperature as f64,
676            negative_count: s.negative_count as u32,
677        })
678    }
679
680    /// Get current difficulty (0.0 to 1.0)
681    #[napi(getter)]
682    pub fn difficulty(&self) -> f64 {
683        self.inner.difficulty() as f64
684    }
685
686    /// Get current temperature
687    #[napi(getter)]
688    pub fn temperature(&self) -> f64 {
689        self.inner.temperature() as f64
690    }
691
692    /// Get current negative count
693    #[napi(getter)]
694    pub fn negative_count(&self) -> u32 {
695        self.inner.negative_count() as u32
696    }
697
698    /// Check if curriculum is complete
699    #[napi(getter)]
700    pub fn is_complete(&self) -> bool {
701        self.inner.is_complete()
702    }
703
704    /// Get overall progress (0.0 to 1.0)
705    #[napi(getter)]
706    pub fn progress(&self) -> f64 {
707        self.inner.progress() as f64
708    }
709
710    /// Reset curriculum
711    #[napi]
712    pub fn reset(&mut self) {
713        self.inner.reset();
714    }
715}
716
// ============================================================================
// Negative Mining
// ============================================================================

/// Mining strategy for negative selection
#[napi(string_enum)]
pub enum MiningStrategy {
    // Variants mirror `ruvector_attention::training::MiningStrategy`
    // one-to-one; see the core crate for the exact selection semantics.
    Random,
    HardNegative,
    SemiHard,
    DistanceWeighted,
}

// Map the N-API string enum onto the core crate's strategy enum.
impl From<MiningStrategy> for RustMiningStrategy {
    fn from(ms: MiningStrategy) -> Self {
        match ms {
            MiningStrategy::Random => RustMiningStrategy::Random,
            MiningStrategy::HardNegative => RustMiningStrategy::HardNegative,
            MiningStrategy::SemiHard => RustMiningStrategy::SemiHard,
            MiningStrategy::DistanceWeighted => RustMiningStrategy::DistanceWeighted,
        }
    }
}
741/// Hard negative miner for selecting informative negatives
742#[napi]
743pub struct HardNegativeMiner {
744    inner: RustHardMiner,
745}
746
747#[napi]
748impl HardNegativeMiner {
749    /// Create a new hard negative miner
750    ///
751    /// # Arguments
752    /// * `strategy` - Mining strategy to use
753    #[napi(constructor)]
754    pub fn new(strategy: MiningStrategy) -> Self {
755        Self {
756            inner: RustHardMiner::new(strategy.into()),
757        }
758    }
759
760    /// Create with margin (for semi-hard mining)
761    #[napi(factory)]
762    pub fn with_margin(strategy: MiningStrategy, margin: f64) -> Self {
763        Self {
764            inner: RustHardMiner::new(strategy.into()).with_margin(margin as f32),
765        }
766    }
767
768    /// Create with temperature (for distance-weighted mining)
769    #[napi(factory)]
770    pub fn with_temperature(strategy: MiningStrategy, temperature: f64) -> Self {
771        Self {
772            inner: RustHardMiner::new(strategy.into()).with_temperature(temperature as f32),
773        }
774    }
775
776    /// Mine negative indices from candidates
777    ///
778    /// # Arguments
779    /// * `anchor` - Anchor embedding
780    /// * `positive` - Positive example embedding
781    /// * `candidates` - Array of candidate embeddings
782    /// * `num_negatives` - Number of negatives to select
783    ///
784    /// # Returns
785    /// Array of indices into the candidates array
786    #[napi]
787    pub fn mine(
788        &self,
789        anchor: Float32Array,
790        positive: Float32Array,
791        candidates: Vec<Float32Array>,
792        num_negatives: u32,
793    ) -> Vec<u32> {
794        let anchor_slice = anchor.as_ref();
795        let positive_slice = positive.as_ref();
796        let candidates_vec: Vec<Vec<f32>> = candidates.into_iter().map(|c| c.to_vec()).collect();
797        let candidates_refs: Vec<&[f32]> = candidates_vec.iter().map(|c| c.as_slice()).collect();
798
799        self.inner
800            .mine(
801                anchor_slice,
802                positive_slice,
803                &candidates_refs,
804                num_negatives as usize,
805            )
806            .into_iter()
807            .map(|i| i as u32)
808            .collect()
809    }
810}
811
812/// In-batch negative mining utility
813#[napi]
814pub struct InBatchMiner {
815    exclude_positive: bool,
816}
817
818#[napi]
819impl InBatchMiner {
820    /// Create a new in-batch miner
821    #[napi(constructor)]
822    pub fn new() -> Self {
823        Self {
824            exclude_positive: true,
825        }
826    }
827
828    /// Create without excluding positive
829    #[napi(factory)]
830    pub fn include_positive() -> Self {
831        Self {
832            exclude_positive: false,
833        }
834    }
835
836    /// Get negative indices for a given anchor in a batch
837    ///
838    /// # Arguments
839    /// * `anchor_idx` - Index of the anchor in the batch
840    /// * `positive_idx` - Index of the positive in the batch
841    /// * `batch_size` - Total batch size
842    ///
843    /// # Returns
844    /// Array of indices that can be used as negatives
845    #[napi]
846    pub fn get_negatives(&self, anchor_idx: u32, positive_idx: u32, batch_size: u32) -> Vec<u32> {
847        (0..batch_size)
848            .filter(|&i| i != anchor_idx && (!self.exclude_positive || i != positive_idx))
849            .collect()
850    }
851}