1use napi::bindgen_prelude::*;
11use napi_derive::napi;
12use ruvector_attention::training::{
13 Adam as RustAdam, AdamW as RustAdamW, CurriculumScheduler as RustCurriculum,
14 CurriculumStage as RustStage, DecayType as RustDecayType, HardNegativeMiner as RustHardMiner,
15 InfoNCELoss as RustInfoNCE, LocalContrastiveLoss as RustLocalContrastive, Loss,
16 MiningStrategy as RustMiningStrategy, NegativeMiner, Optimizer,
17 SpectralRegularization as RustSpectralReg, TemperatureAnnealing as RustTempAnnealing,
18 SGD as RustSGD,
19};
20
#[napi]
/// JavaScript-facing wrapper around `ruvector_attention`'s InfoNCE loss.
///
/// The temperature given to the constructor is also cached in
/// `temperature_value`, so the `temperature` getter can report it without
/// querying the inner loss.
pub struct InfoNCELoss {
    /// Rust loss implementation that every computation is delegated to.
    inner: RustInfoNCE,
    /// Temperature supplied at construction, already narrowed to `f32`.
    temperature_value: f32,
}
31
32#[napi]
33impl InfoNCELoss {
34 #[napi(constructor)]
39 pub fn new(temperature: f64) -> Self {
40 Self {
41 inner: RustInfoNCE::new(temperature as f32),
42 temperature_value: temperature as f32,
43 }
44 }
45
46 #[napi]
53 pub fn compute(
54 &self,
55 anchor: Float32Array,
56 positive: Float32Array,
57 negatives: Vec<Float32Array>,
58 ) -> f64 {
59 let anchor_slice = anchor.as_ref();
60 let positive_slice = positive.as_ref();
61 let negatives_vec: Vec<Vec<f32>> = negatives.into_iter().map(|n| n.to_vec()).collect();
62 let negatives_refs: Vec<&[f32]> = negatives_vec.iter().map(|n| n.as_slice()).collect();
63
64 self.inner
65 .compute(anchor_slice, positive_slice, &negatives_refs) as f64
66 }
67
68 #[napi]
72 pub fn compute_with_gradients(
73 &self,
74 anchor: Float32Array,
75 positive: Float32Array,
76 negatives: Vec<Float32Array>,
77 ) -> LossWithGradients {
78 let anchor_slice = anchor.as_ref();
79 let positive_slice = positive.as_ref();
80 let negatives_vec: Vec<Vec<f32>> = negatives.into_iter().map(|n| n.to_vec()).collect();
81 let negatives_refs: Vec<&[f32]> = negatives_vec.iter().map(|n| n.as_slice()).collect();
82
83 let (loss, gradients) =
84 self.inner
85 .compute_with_gradients(anchor_slice, positive_slice, &negatives_refs);
86
87 LossWithGradients {
88 loss: loss as f64,
89 gradients: Float32Array::new(gradients),
90 }
91 }
92
93 #[napi(getter)]
95 pub fn temperature(&self) -> f64 {
96 self.temperature_value as f64
97 }
98}
99
#[napi(object)]
/// Plain JS object returned by the `compute_with_gradients` methods: the scalar
/// loss plus the gradient buffer produced by the underlying Rust loss.
pub struct LossWithGradients {
    /// Loss value, widened from `f32` to a JS `number`.
    pub loss: f64,
    /// Gradients returned by the inner `compute_with_gradients` call.
    /// NOTE(review): presumably the gradient w.r.t. the anchor — confirm in `ruvector_attention`.
    pub gradients: Float32Array,
}
106
#[napi]
/// JavaScript-facing wrapper around `ruvector_attention`'s local contrastive loss.
///
/// The margin given to the constructor is cached in `margin_value` so the
/// `margin` getter can report it.
pub struct LocalContrastiveLoss {
    /// Rust loss implementation that every computation is delegated to.
    inner: RustLocalContrastive,
    /// Margin supplied at construction, already narrowed to `f32`.
    margin_value: f32,
}
113
114#[napi]
115impl LocalContrastiveLoss {
116 #[napi(constructor)]
121 pub fn new(margin: f64) -> Self {
122 Self {
123 inner: RustLocalContrastive::new(margin as f32),
124 margin_value: margin as f32,
125 }
126 }
127
128 #[napi]
130 pub fn compute(
131 &self,
132 anchor: Float32Array,
133 positive: Float32Array,
134 negatives: Vec<Float32Array>,
135 ) -> f64 {
136 let anchor_slice = anchor.as_ref();
137 let positive_slice = positive.as_ref();
138 let negatives_vec: Vec<Vec<f32>> = negatives.into_iter().map(|n| n.to_vec()).collect();
139 let negatives_refs: Vec<&[f32]> = negatives_vec.iter().map(|n| n.as_slice()).collect();
140
141 self.inner
142 .compute(anchor_slice, positive_slice, &negatives_refs) as f64
143 }
144
145 #[napi]
147 pub fn compute_with_gradients(
148 &self,
149 anchor: Float32Array,
150 positive: Float32Array,
151 negatives: Vec<Float32Array>,
152 ) -> LossWithGradients {
153 let anchor_slice = anchor.as_ref();
154 let positive_slice = positive.as_ref();
155 let negatives_vec: Vec<Vec<f32>> = negatives.into_iter().map(|n| n.to_vec()).collect();
156 let negatives_refs: Vec<&[f32]> = negatives_vec.iter().map(|n| n.as_slice()).collect();
157
158 let (loss, gradients) =
159 self.inner
160 .compute_with_gradients(anchor_slice, positive_slice, &negatives_refs);
161
162 LossWithGradients {
163 loss: loss as f64,
164 gradients: Float32Array::new(gradients),
165 }
166 }
167
168 #[napi(getter)]
170 pub fn margin(&self) -> f64 {
171 self.margin_value as f64
172 }
173}
174
#[napi]
/// JavaScript-facing wrapper around `ruvector_attention`'s spectral regularization term.
///
/// The weight given to the constructor is cached in `weight_value` so the
/// `weight` getter can report it.
pub struct SpectralRegularization {
    /// Rust regularizer that every computation is delegated to.
    inner: RustSpectralReg,
    /// Weight supplied at construction, already narrowed to `f32`.
    weight_value: f32,
}
181
182#[napi]
183impl SpectralRegularization {
184 #[napi(constructor)]
189 pub fn new(weight: f64) -> Self {
190 Self {
191 inner: RustSpectralReg::new(weight as f32),
192 weight_value: weight as f32,
193 }
194 }
195
196 #[napi]
198 pub fn compute_batch(&self, embeddings: Vec<Float32Array>) -> f64 {
199 let embeddings_vec: Vec<Vec<f32>> = embeddings.into_iter().map(|e| e.to_vec()).collect();
200 let embeddings_refs: Vec<&[f32]> = embeddings_vec.iter().map(|e| e.as_slice()).collect();
201
202 self.inner.compute_batch(&embeddings_refs) as f64
203 }
204
205 #[napi(getter)]
207 pub fn weight(&self) -> f64 {
208 self.weight_value as f64
209 }
210}
211
#[napi]
/// JavaScript-facing wrapper around `ruvector_attention`'s SGD optimizer.
///
/// Built with the constructor (plain SGD) or with the `with_momentum` /
/// `with_weight_decay` factories.
pub struct SGDOptimizer {
    /// Rust optimizer; it also holds whatever per-parameter state `step` accumulates.
    inner: RustSGD,
}
221
222#[napi]
223impl SGDOptimizer {
224 #[napi(constructor)]
230 pub fn new(param_count: u32, learning_rate: f64) -> Self {
231 Self {
232 inner: RustSGD::new(param_count as usize, learning_rate as f32),
233 }
234 }
235
236 #[napi(factory)]
238 pub fn with_momentum(param_count: u32, learning_rate: f64, momentum: f64) -> Self {
239 Self {
240 inner: RustSGD::new(param_count as usize, learning_rate as f32)
241 .with_momentum(momentum as f32),
242 }
243 }
244
245 #[napi(factory)]
247 pub fn with_weight_decay(
248 param_count: u32,
249 learning_rate: f64,
250 momentum: f64,
251 weight_decay: f64,
252 ) -> Self {
253 Self {
254 inner: RustSGD::new(param_count as usize, learning_rate as f32)
255 .with_momentum(momentum as f32)
256 .with_weight_decay(weight_decay as f32),
257 }
258 }
259
260 #[napi]
269 pub fn step(&mut self, params: Float32Array, gradients: Float32Array) -> Float32Array {
270 let mut params_vec = params.to_vec();
271 let gradients_slice = gradients.as_ref();
272 self.inner.step(&mut params_vec, gradients_slice);
273 Float32Array::new(params_vec)
274 }
275
276 #[napi]
278 pub fn reset(&mut self) {
279 self.inner.reset();
280 }
281
282 #[napi(getter)]
284 pub fn learning_rate(&self) -> f64 {
285 self.inner.learning_rate() as f64
286 }
287
288 #[napi(setter)]
290 pub fn set_learning_rate(&mut self, lr: f64) {
291 self.inner.set_learning_rate(lr as f32);
292 }
293}
294
#[napi]
/// JavaScript-facing wrapper around `ruvector_attention`'s Adam optimizer.
///
/// Built with the constructor, or with the `with_betas` / `with_config` factories
/// for custom hyper-parameters.
pub struct AdamOptimizer {
    /// Rust optimizer; it also holds whatever per-parameter state `step` accumulates.
    inner: RustAdam,
}
300
301#[napi]
302impl AdamOptimizer {
303 #[napi(constructor)]
309 pub fn new(param_count: u32, learning_rate: f64) -> Self {
310 Self {
311 inner: RustAdam::new(param_count as usize, learning_rate as f32),
312 }
313 }
314
315 #[napi(factory)]
317 pub fn with_betas(param_count: u32, learning_rate: f64, beta1: f64, beta2: f64) -> Self {
318 Self {
319 inner: RustAdam::new(param_count as usize, learning_rate as f32)
320 .with_betas(beta1 as f32, beta2 as f32),
321 }
322 }
323
324 #[napi(factory)]
326 pub fn with_config(
327 param_count: u32,
328 learning_rate: f64,
329 beta1: f64,
330 beta2: f64,
331 epsilon: f64,
332 weight_decay: f64,
333 ) -> Self {
334 Self {
335 inner: RustAdam::new(param_count as usize, learning_rate as f32)
336 .with_betas(beta1 as f32, beta2 as f32)
337 .with_epsilon(epsilon as f32)
338 .with_weight_decay(weight_decay as f32),
339 }
340 }
341
342 #[napi]
347 pub fn step(&mut self, params: Float32Array, gradients: Float32Array) -> Float32Array {
348 let mut params_vec = params.to_vec();
349 let gradients_slice = gradients.as_ref();
350 self.inner.step(&mut params_vec, gradients_slice);
351 Float32Array::new(params_vec)
352 }
353
354 #[napi]
356 pub fn reset(&mut self) {
357 self.inner.reset();
358 }
359
360 #[napi(getter)]
362 pub fn learning_rate(&self) -> f64 {
363 self.inner.learning_rate() as f64
364 }
365
366 #[napi(setter)]
368 pub fn set_learning_rate(&mut self, lr: f64) {
369 self.inner.set_learning_rate(lr as f32);
370 }
371}
372
#[napi]
/// JavaScript-facing wrapper around `ruvector_attention`'s AdamW optimizer.
pub struct AdamWOptimizer {
    /// Rust optimizer; it also holds whatever per-parameter state `step` accumulates.
    inner: RustAdamW,
    /// Weight decay passed at construction (narrowed to `f32`), kept here so the
    /// `weight_decay` getter can report it.
    wd: f32,
}
379
380#[napi]
381impl AdamWOptimizer {
382 #[napi(constructor)]
389 pub fn new(param_count: u32, learning_rate: f64, weight_decay: f64) -> Self {
390 Self {
391 inner: RustAdamW::new(param_count as usize, learning_rate as f32)
392 .with_weight_decay(weight_decay as f32),
393 wd: weight_decay as f32,
394 }
395 }
396
397 #[napi(factory)]
399 pub fn with_betas(
400 param_count: u32,
401 learning_rate: f64,
402 weight_decay: f64,
403 beta1: f64,
404 beta2: f64,
405 ) -> Self {
406 Self {
407 inner: RustAdamW::new(param_count as usize, learning_rate as f32)
408 .with_weight_decay(weight_decay as f32)
409 .with_betas(beta1 as f32, beta2 as f32),
410 wd: weight_decay as f32,
411 }
412 }
413
414 #[napi]
419 pub fn step(&mut self, params: Float32Array, gradients: Float32Array) -> Float32Array {
420 let mut params_vec = params.to_vec();
421 let gradients_slice = gradients.as_ref();
422 self.inner.step(&mut params_vec, gradients_slice);
423 Float32Array::new(params_vec)
424 }
425
426 #[napi]
428 pub fn reset(&mut self) {
429 self.inner.reset();
430 }
431
432 #[napi(getter)]
434 pub fn learning_rate(&self) -> f64 {
435 self.inner.learning_rate() as f64
436 }
437
438 #[napi(setter)]
440 pub fn set_learning_rate(&mut self, lr: f64) {
441 self.inner.set_learning_rate(lr as f32);
442 }
443
444 #[napi(getter)]
446 pub fn weight_decay(&self) -> f64 {
447 self.wd as f64
448 }
449}
450
#[napi]
/// Step-based learning-rate schedule: linear warmup for `warmup_steps` steps,
/// then cosine decay from `initial_lr` towards `min_lr`, reaching `min_lr` at
/// `total_steps`.
pub struct LearningRateScheduler {
    /// Peak learning rate, reached on the last warmup step.
    initial_lr: f32,
    /// Number of `step()` calls since construction or the last `reset()`.
    current_step: usize,
    /// Length of the linear warmup phase, in steps.
    warmup_steps: usize,
    /// Step at which the cosine decay bottoms out; also the denominator of `progress`.
    total_steps: usize,
    /// Floor of the cosine decay (`1e-7` unless set via `with_min_lr`).
    min_lr: f32,
}
464
465#[napi]
466impl LearningRateScheduler {
467 #[napi(constructor)]
474 pub fn new(initial_lr: f64, warmup_steps: u32, total_steps: u32) -> Self {
475 Self {
476 initial_lr: initial_lr as f32,
477 current_step: 0,
478 warmup_steps: warmup_steps as usize,
479 total_steps: total_steps as usize,
480 min_lr: 1e-7,
481 }
482 }
483
484 #[napi(factory)]
486 pub fn with_min_lr(initial_lr: f64, warmup_steps: u32, total_steps: u32, min_lr: f64) -> Self {
487 Self {
488 initial_lr: initial_lr as f32,
489 current_step: 0,
490 warmup_steps: warmup_steps as usize,
491 total_steps: total_steps as usize,
492 min_lr: min_lr as f32,
493 }
494 }
495
496 #[napi]
498 pub fn get_lr(&self) -> f64 {
499 if self.current_step < self.warmup_steps {
500 (self.initial_lr * (self.current_step + 1) as f32 / self.warmup_steps as f32) as f64
502 } else {
503 let progress = (self.current_step - self.warmup_steps) as f32
505 / (self.total_steps - self.warmup_steps).max(1) as f32;
506 let decay = 0.5 * (1.0 + (std::f32::consts::PI * progress.min(1.0)).cos());
507 (self.min_lr + (self.initial_lr - self.min_lr) * decay) as f64
508 }
509 }
510
511 #[napi]
513 pub fn step(&mut self) -> f64 {
514 let lr = self.get_lr();
515 self.current_step += 1;
516 lr
517 }
518
519 #[napi]
521 pub fn reset(&mut self) {
522 self.current_step = 0;
523 }
524
525 #[napi(getter)]
527 pub fn current_step(&self) -> u32 {
528 self.current_step as u32
529 }
530
531 #[napi(getter)]
533 pub fn progress(&self) -> f64 {
534 (self.current_step as f64 / self.total_steps.max(1) as f64).min(1.0)
535 }
536}
537
#[napi(string_enum)]
/// Decay curve used by `TemperatureAnnealing`, exposed to JS as a string enum.
///
/// Each variant maps one-to-one onto the identically named
/// `ruvector_attention` `DecayType` variant.
pub enum DecayType {
    Linear,
    Exponential,
    Cosine,
    Step,
}
550
551impl From<DecayType> for RustDecayType {
552 fn from(dt: DecayType) -> Self {
553 match dt {
554 DecayType::Linear => RustDecayType::Linear,
555 DecayType::Exponential => RustDecayType::Exponential,
556 DecayType::Cosine => RustDecayType::Cosine,
557 DecayType::Step => RustDecayType::Step,
558 }
559 }
560}
561
#[napi]
/// JavaScript-facing wrapper around `ruvector_attention`'s temperature annealing schedule.
pub struct TemperatureAnnealing {
    /// Rust schedule; it owns the step counter that `step` / `reset` manipulate.
    inner: RustTempAnnealing,
}
567
568#[napi]
569impl TemperatureAnnealing {
570 #[napi(constructor)]
577 pub fn new(initial_temp: f64, final_temp: f64, steps: u32) -> Self {
578 Self {
579 inner: RustTempAnnealing::new(initial_temp as f32, final_temp as f32, steps as usize),
580 }
581 }
582
583 #[napi(factory)]
585 pub fn with_decay(
586 initial_temp: f64,
587 final_temp: f64,
588 steps: u32,
589 decay_type: DecayType,
590 ) -> Self {
591 Self {
592 inner: RustTempAnnealing::new(initial_temp as f32, final_temp as f32, steps as usize)
593 .with_decay(decay_type.into()),
594 }
595 }
596
597 #[napi]
599 pub fn get_temp(&self) -> f64 {
600 self.inner.get_temp() as f64
601 }
602
603 #[napi]
605 pub fn step(&mut self) -> f64 {
606 self.inner.step() as f64
607 }
608
609 #[napi]
611 pub fn reset(&mut self) {
612 self.inner.reset();
613 }
614}
615
#[napi(object)]
/// Plain JS object describing one curriculum stage.
///
/// It is the input to `CurriculumScheduler::add_stage` and the value returned by
/// `CurriculumScheduler::step`. Numeric fields are narrowed to `f32` / `usize`
/// when converted to the Rust stage.
pub struct CurriculumStageConfig {
    /// Stage name, passed to the Rust stage constructor.
    pub name: String,
    /// Difficulty setting for the stage.
    pub difficulty: f64,
    /// Stage length. NOTE(review): presumably in scheduler steps — confirm.
    pub duration: u32,
    /// Temperature to use while this stage is active.
    pub temperature: f64,
    /// Number of negatives to use while this stage is active.
    pub negative_count: u32,
}
629
#[napi]
/// JavaScript-facing wrapper around `ruvector_attention`'s curriculum scheduler.
pub struct CurriculumScheduler {
    /// Rust scheduler. `add_stage` temporarily swaps it out with `std::mem::take`,
    /// so the inner type must implement `Default`.
    inner: RustCurriculum,
}
635
636#[napi]
637impl CurriculumScheduler {
638 #[napi(constructor)]
640 pub fn new() -> Self {
641 Self {
642 inner: RustCurriculum::new(),
643 }
644 }
645
646 #[napi(factory)]
648 pub fn default_curriculum(total_steps: u32) -> Self {
649 Self {
650 inner: RustCurriculum::default_curriculum(total_steps as usize),
651 }
652 }
653
654 #[napi]
656 pub fn add_stage(&mut self, config: CurriculumStageConfig) {
657 let stage = RustStage::new(&config.name)
658 .difficulty(config.difficulty as f32)
659 .duration(config.duration as usize)
660 .temperature(config.temperature as f32)
661 .negative_count(config.negative_count as usize);
662
663 let new_inner = std::mem::take(&mut self.inner).add_stage(stage);
665 self.inner = new_inner;
666 }
667
668 #[napi]
670 pub fn step(&mut self) -> Option<CurriculumStageConfig> {
671 self.inner.step().map(|s| CurriculumStageConfig {
672 name: s.name.clone(),
673 difficulty: s.difficulty as f64,
674 duration: s.duration as u32,
675 temperature: s.temperature as f64,
676 negative_count: s.negative_count as u32,
677 })
678 }
679
680 #[napi(getter)]
682 pub fn difficulty(&self) -> f64 {
683 self.inner.difficulty() as f64
684 }
685
686 #[napi(getter)]
688 pub fn temperature(&self) -> f64 {
689 self.inner.temperature() as f64
690 }
691
692 #[napi(getter)]
694 pub fn negative_count(&self) -> u32 {
695 self.inner.negative_count() as u32
696 }
697
698 #[napi(getter)]
700 pub fn is_complete(&self) -> bool {
701 self.inner.is_complete()
702 }
703
704 #[napi(getter)]
706 pub fn progress(&self) -> f64 {
707 self.inner.progress() as f64
708 }
709
710 #[napi]
712 pub fn reset(&mut self) {
713 self.inner.reset();
714 }
715}
716
#[napi(string_enum)]
/// Negative-mining strategy used by `HardNegativeMiner`, exposed to JS as a string enum.
///
/// Each variant maps one-to-one onto the identically named
/// `ruvector_attention` `MiningStrategy` variant.
pub enum MiningStrategy {
    Random,
    HardNegative,
    SemiHard,
    DistanceWeighted,
}
729
730impl From<MiningStrategy> for RustMiningStrategy {
731 fn from(ms: MiningStrategy) -> Self {
732 match ms {
733 MiningStrategy::Random => RustMiningStrategy::Random,
734 MiningStrategy::HardNegative => RustMiningStrategy::HardNegative,
735 MiningStrategy::SemiHard => RustMiningStrategy::SemiHard,
736 MiningStrategy::DistanceWeighted => RustMiningStrategy::DistanceWeighted,
737 }
738 }
739}
740
#[napi]
/// JavaScript-facing wrapper around `ruvector_attention`'s hard-negative miner.
pub struct HardNegativeMiner {
    /// Rust miner configured with a `MiningStrategy` and optional margin/temperature.
    inner: RustHardMiner,
}
746
747#[napi]
748impl HardNegativeMiner {
749 #[napi(constructor)]
754 pub fn new(strategy: MiningStrategy) -> Self {
755 Self {
756 inner: RustHardMiner::new(strategy.into()),
757 }
758 }
759
760 #[napi(factory)]
762 pub fn with_margin(strategy: MiningStrategy, margin: f64) -> Self {
763 Self {
764 inner: RustHardMiner::new(strategy.into()).with_margin(margin as f32),
765 }
766 }
767
768 #[napi(factory)]
770 pub fn with_temperature(strategy: MiningStrategy, temperature: f64) -> Self {
771 Self {
772 inner: RustHardMiner::new(strategy.into()).with_temperature(temperature as f32),
773 }
774 }
775
776 #[napi]
787 pub fn mine(
788 &self,
789 anchor: Float32Array,
790 positive: Float32Array,
791 candidates: Vec<Float32Array>,
792 num_negatives: u32,
793 ) -> Vec<u32> {
794 let anchor_slice = anchor.as_ref();
795 let positive_slice = positive.as_ref();
796 let candidates_vec: Vec<Vec<f32>> = candidates.into_iter().map(|c| c.to_vec()).collect();
797 let candidates_refs: Vec<&[f32]> = candidates_vec.iter().map(|c| c.as_slice()).collect();
798
799 self.inner
800 .mine(
801 anchor_slice,
802 positive_slice,
803 &candidates_refs,
804 num_negatives as usize,
805 )
806 .into_iter()
807 .map(|i| i as u32)
808 .collect()
809 }
810}
811
#[napi]
/// Produces in-batch negative indices: every other index in the batch except the
/// anchor, and (by default) except the positive.
pub struct InBatchMiner {
    /// When `true` (the constructor default), the positive's index is excluded
    /// from the returned negatives as well as the anchor's.
    exclude_positive: bool,
}
817
818#[napi]
819impl InBatchMiner {
820 #[napi(constructor)]
822 pub fn new() -> Self {
823 Self {
824 exclude_positive: true,
825 }
826 }
827
828 #[napi(factory)]
830 pub fn include_positive() -> Self {
831 Self {
832 exclude_positive: false,
833 }
834 }
835
836 #[napi]
846 pub fn get_negatives(&self, anchor_idx: u32, positive_idx: u32, batch_size: u32) -> Vec<u32> {
847 (0..batch_size)
848 .filter(|&i| i != anchor_idx && (!self.exclude_positive || i != positive_idx))
849 .collect()
850 }
851}