1use napi::bindgen_prelude::*;
11use napi_derive::napi;
12use ruvector_attention::training::{
13 InfoNCELoss as RustInfoNCE,
14 LocalContrastiveLoss as RustLocalContrastive,
15 SpectralRegularization as RustSpectralReg,
16 Loss,
17 SGD as RustSGD,
18 Adam as RustAdam,
19 AdamW as RustAdamW,
20 Optimizer,
21 CurriculumScheduler as RustCurriculum,
22 CurriculumStage as RustStage,
23 TemperatureAnnealing as RustTempAnnealing,
24 DecayType as RustDecayType,
25 HardNegativeMiner as RustHardMiner,
26 MiningStrategy as RustMiningStrategy,
27 NegativeMiner,
28};
29
30#[napi]
36pub struct InfoNCELoss {
37 inner: RustInfoNCE,
38 temperature_value: f32,
39}
40
41#[napi]
42impl InfoNCELoss {
43 #[napi(constructor)]
48 pub fn new(temperature: f64) -> Self {
49 Self {
50 inner: RustInfoNCE::new(temperature as f32),
51 temperature_value: temperature as f32,
52 }
53 }
54
55 #[napi]
62 pub fn compute(&self, anchor: Float32Array, positive: Float32Array, negatives: Vec<Float32Array>) -> f64 {
63 let anchor_slice = anchor.as_ref();
64 let positive_slice = positive.as_ref();
65 let negatives_vec: Vec<Vec<f32>> = negatives.into_iter().map(|n| n.to_vec()).collect();
66 let negatives_refs: Vec<&[f32]> = negatives_vec.iter().map(|n| n.as_slice()).collect();
67
68 self.inner.compute(anchor_slice, positive_slice, &negatives_refs) as f64
69 }
70
71 #[napi]
75 pub fn compute_with_gradients(&self, anchor: Float32Array, positive: Float32Array, negatives: Vec<Float32Array>) -> LossWithGradients {
76 let anchor_slice = anchor.as_ref();
77 let positive_slice = positive.as_ref();
78 let negatives_vec: Vec<Vec<f32>> = negatives.into_iter().map(|n| n.to_vec()).collect();
79 let negatives_refs: Vec<&[f32]> = negatives_vec.iter().map(|n| n.as_slice()).collect();
80
81 let (loss, gradients) = self.inner.compute_with_gradients(anchor_slice, positive_slice, &negatives_refs);
82
83 LossWithGradients {
84 loss: loss as f64,
85 gradients: Float32Array::new(gradients),
86 }
87 }
88
89 #[napi(getter)]
91 pub fn temperature(&self) -> f64 {
92 self.temperature_value as f64
93 }
94}
95
#[napi(object)]
/// Plain JS object returned by the `compute_with_gradients` methods of the
/// loss wrappers: a scalar loss plus the gradient buffer computed alongside it.
pub struct LossWithGradients {
    /// Scalar loss value, converted to `f64` for JavaScript.
    pub loss: f64,
    /// Gradient values returned by the underlying Rust loss implementation.
    pub gradients: Float32Array,
}
102
103#[napi]
105pub struct LocalContrastiveLoss {
106 inner: RustLocalContrastive,
107 margin_value: f32,
108}
109
110#[napi]
111impl LocalContrastiveLoss {
112 #[napi(constructor)]
117 pub fn new(margin: f64) -> Self {
118 Self {
119 inner: RustLocalContrastive::new(margin as f32),
120 margin_value: margin as f32,
121 }
122 }
123
124 #[napi]
126 pub fn compute(&self, anchor: Float32Array, positive: Float32Array, negatives: Vec<Float32Array>) -> f64 {
127 let anchor_slice = anchor.as_ref();
128 let positive_slice = positive.as_ref();
129 let negatives_vec: Vec<Vec<f32>> = negatives.into_iter().map(|n| n.to_vec()).collect();
130 let negatives_refs: Vec<&[f32]> = negatives_vec.iter().map(|n| n.as_slice()).collect();
131
132 self.inner.compute(anchor_slice, positive_slice, &negatives_refs) as f64
133 }
134
135 #[napi]
137 pub fn compute_with_gradients(&self, anchor: Float32Array, positive: Float32Array, negatives: Vec<Float32Array>) -> LossWithGradients {
138 let anchor_slice = anchor.as_ref();
139 let positive_slice = positive.as_ref();
140 let negatives_vec: Vec<Vec<f32>> = negatives.into_iter().map(|n| n.to_vec()).collect();
141 let negatives_refs: Vec<&[f32]> = negatives_vec.iter().map(|n| n.as_slice()).collect();
142
143 let (loss, gradients) = self.inner.compute_with_gradients(anchor_slice, positive_slice, &negatives_refs);
144
145 LossWithGradients {
146 loss: loss as f64,
147 gradients: Float32Array::new(gradients),
148 }
149 }
150
151 #[napi(getter)]
153 pub fn margin(&self) -> f64 {
154 self.margin_value as f64
155 }
156}
157
158#[napi]
160pub struct SpectralRegularization {
161 inner: RustSpectralReg,
162 weight_value: f32,
163}
164
165#[napi]
166impl SpectralRegularization {
167 #[napi(constructor)]
172 pub fn new(weight: f64) -> Self {
173 Self {
174 inner: RustSpectralReg::new(weight as f32),
175 weight_value: weight as f32,
176 }
177 }
178
179 #[napi]
181 pub fn compute_batch(&self, embeddings: Vec<Float32Array>) -> f64 {
182 let embeddings_vec: Vec<Vec<f32>> = embeddings.into_iter().map(|e| e.to_vec()).collect();
183 let embeddings_refs: Vec<&[f32]> = embeddings_vec.iter().map(|e| e.as_slice()).collect();
184
185 self.inner.compute_batch(&embeddings_refs) as f64
186 }
187
188 #[napi(getter)]
190 pub fn weight(&self) -> f64 {
191 self.weight_value as f64
192 }
193}
194
195#[napi]
201pub struct SGDOptimizer {
202 inner: RustSGD,
203}
204
205#[napi]
206impl SGDOptimizer {
207 #[napi(constructor)]
213 pub fn new(param_count: u32, learning_rate: f64) -> Self {
214 Self {
215 inner: RustSGD::new(param_count as usize, learning_rate as f32),
216 }
217 }
218
219 #[napi(factory)]
221 pub fn with_momentum(param_count: u32, learning_rate: f64, momentum: f64) -> Self {
222 Self {
223 inner: RustSGD::new(param_count as usize, learning_rate as f32)
224 .with_momentum(momentum as f32),
225 }
226 }
227
228 #[napi(factory)]
230 pub fn with_weight_decay(param_count: u32, learning_rate: f64, momentum: f64, weight_decay: f64) -> Self {
231 Self {
232 inner: RustSGD::new(param_count as usize, learning_rate as f32)
233 .with_momentum(momentum as f32)
234 .with_weight_decay(weight_decay as f32),
235 }
236 }
237
238 #[napi]
247 pub fn step(&mut self, params: Float32Array, gradients: Float32Array) -> Float32Array {
248 let mut params_vec = params.to_vec();
249 let gradients_slice = gradients.as_ref();
250 self.inner.step(&mut params_vec, gradients_slice);
251 Float32Array::new(params_vec)
252 }
253
254 #[napi]
256 pub fn reset(&mut self) {
257 self.inner.reset();
258 }
259
260 #[napi(getter)]
262 pub fn learning_rate(&self) -> f64 {
263 self.inner.learning_rate() as f64
264 }
265
266 #[napi(setter)]
268 pub fn set_learning_rate(&mut self, lr: f64) {
269 self.inner.set_learning_rate(lr as f32);
270 }
271}
272
273#[napi]
275pub struct AdamOptimizer {
276 inner: RustAdam,
277}
278
279#[napi]
280impl AdamOptimizer {
281 #[napi(constructor)]
287 pub fn new(param_count: u32, learning_rate: f64) -> Self {
288 Self {
289 inner: RustAdam::new(param_count as usize, learning_rate as f32),
290 }
291 }
292
293 #[napi(factory)]
295 pub fn with_betas(param_count: u32, learning_rate: f64, beta1: f64, beta2: f64) -> Self {
296 Self {
297 inner: RustAdam::new(param_count as usize, learning_rate as f32)
298 .with_betas(beta1 as f32, beta2 as f32),
299 }
300 }
301
302 #[napi(factory)]
304 pub fn with_config(param_count: u32, learning_rate: f64, beta1: f64, beta2: f64, epsilon: f64, weight_decay: f64) -> Self {
305 Self {
306 inner: RustAdam::new(param_count as usize, learning_rate as f32)
307 .with_betas(beta1 as f32, beta2 as f32)
308 .with_epsilon(epsilon as f32)
309 .with_weight_decay(weight_decay as f32),
310 }
311 }
312
313 #[napi]
318 pub fn step(&mut self, params: Float32Array, gradients: Float32Array) -> Float32Array {
319 let mut params_vec = params.to_vec();
320 let gradients_slice = gradients.as_ref();
321 self.inner.step(&mut params_vec, gradients_slice);
322 Float32Array::new(params_vec)
323 }
324
325 #[napi]
327 pub fn reset(&mut self) {
328 self.inner.reset();
329 }
330
331 #[napi(getter)]
333 pub fn learning_rate(&self) -> f64 {
334 self.inner.learning_rate() as f64
335 }
336
337 #[napi(setter)]
339 pub fn set_learning_rate(&mut self, lr: f64) {
340 self.inner.set_learning_rate(lr as f32);
341 }
342}
343
344#[napi]
346pub struct AdamWOptimizer {
347 inner: RustAdamW,
348 wd: f32,
349}
350
351#[napi]
352impl AdamWOptimizer {
353 #[napi(constructor)]
360 pub fn new(param_count: u32, learning_rate: f64, weight_decay: f64) -> Self {
361 Self {
362 inner: RustAdamW::new(param_count as usize, learning_rate as f32)
363 .with_weight_decay(weight_decay as f32),
364 wd: weight_decay as f32,
365 }
366 }
367
368 #[napi(factory)]
370 pub fn with_betas(param_count: u32, learning_rate: f64, weight_decay: f64, beta1: f64, beta2: f64) -> Self {
371 Self {
372 inner: RustAdamW::new(param_count as usize, learning_rate as f32)
373 .with_weight_decay(weight_decay as f32)
374 .with_betas(beta1 as f32, beta2 as f32),
375 wd: weight_decay as f32,
376 }
377 }
378
379 #[napi]
384 pub fn step(&mut self, params: Float32Array, gradients: Float32Array) -> Float32Array {
385 let mut params_vec = params.to_vec();
386 let gradients_slice = gradients.as_ref();
387 self.inner.step(&mut params_vec, gradients_slice);
388 Float32Array::new(params_vec)
389 }
390
391 #[napi]
393 pub fn reset(&mut self) {
394 self.inner.reset();
395 }
396
397 #[napi(getter)]
399 pub fn learning_rate(&self) -> f64 {
400 self.inner.learning_rate() as f64
401 }
402
403 #[napi(setter)]
405 pub fn set_learning_rate(&mut self, lr: f64) {
406 self.inner.set_learning_rate(lr as f32);
407 }
408
409 #[napi(getter)]
411 pub fn weight_decay(&self) -> f64 {
412 self.wd as f64
413 }
414}
415
416#[napi]
422pub struct LearningRateScheduler {
423 initial_lr: f32,
424 current_step: usize,
425 warmup_steps: usize,
426 total_steps: usize,
427 min_lr: f32,
428}
429
430#[napi]
431impl LearningRateScheduler {
432 #[napi(constructor)]
439 pub fn new(initial_lr: f64, warmup_steps: u32, total_steps: u32) -> Self {
440 Self {
441 initial_lr: initial_lr as f32,
442 current_step: 0,
443 warmup_steps: warmup_steps as usize,
444 total_steps: total_steps as usize,
445 min_lr: 1e-7,
446 }
447 }
448
449 #[napi(factory)]
451 pub fn with_min_lr(initial_lr: f64, warmup_steps: u32, total_steps: u32, min_lr: f64) -> Self {
452 Self {
453 initial_lr: initial_lr as f32,
454 current_step: 0,
455 warmup_steps: warmup_steps as usize,
456 total_steps: total_steps as usize,
457 min_lr: min_lr as f32,
458 }
459 }
460
461 #[napi]
463 pub fn get_lr(&self) -> f64 {
464 if self.current_step < self.warmup_steps {
465 (self.initial_lr * (self.current_step + 1) as f32 / self.warmup_steps as f32) as f64
467 } else {
468 let progress = (self.current_step - self.warmup_steps) as f32
470 / (self.total_steps - self.warmup_steps).max(1) as f32;
471 let decay = 0.5 * (1.0 + (std::f32::consts::PI * progress.min(1.0)).cos());
472 (self.min_lr + (self.initial_lr - self.min_lr) * decay) as f64
473 }
474 }
475
476 #[napi]
478 pub fn step(&mut self) -> f64 {
479 let lr = self.get_lr();
480 self.current_step += 1;
481 lr
482 }
483
484 #[napi]
486 pub fn reset(&mut self) {
487 self.current_step = 0;
488 }
489
490 #[napi(getter)]
492 pub fn current_step(&self) -> u32 {
493 self.current_step as u32
494 }
495
496 #[napi(getter)]
498 pub fn progress(&self) -> f64 {
499 (self.current_step as f64 / self.total_steps.max(1) as f64).min(1.0)
500 }
501}
502
503#[napi(string_enum)]
509pub enum DecayType {
510 Linear,
511 Exponential,
512 Cosine,
513 Step,
514}
515
516impl From<DecayType> for RustDecayType {
517 fn from(dt: DecayType) -> Self {
518 match dt {
519 DecayType::Linear => RustDecayType::Linear,
520 DecayType::Exponential => RustDecayType::Exponential,
521 DecayType::Cosine => RustDecayType::Cosine,
522 DecayType::Step => RustDecayType::Step,
523 }
524 }
525}
526
527#[napi]
529pub struct TemperatureAnnealing {
530 inner: RustTempAnnealing,
531}
532
533#[napi]
534impl TemperatureAnnealing {
535 #[napi(constructor)]
542 pub fn new(initial_temp: f64, final_temp: f64, steps: u32) -> Self {
543 Self {
544 inner: RustTempAnnealing::new(
545 initial_temp as f32,
546 final_temp as f32,
547 steps as usize,
548 ),
549 }
550 }
551
552 #[napi(factory)]
554 pub fn with_decay(initial_temp: f64, final_temp: f64, steps: u32, decay_type: DecayType) -> Self {
555 Self {
556 inner: RustTempAnnealing::new(
557 initial_temp as f32,
558 final_temp as f32,
559 steps as usize,
560 ).with_decay(decay_type.into()),
561 }
562 }
563
564 #[napi]
566 pub fn get_temp(&self) -> f64 {
567 self.inner.get_temp() as f64
568 }
569
570 #[napi]
572 pub fn step(&mut self) -> f64 {
573 self.inner.step() as f64
574 }
575
576 #[napi]
578 pub fn reset(&mut self) {
579 self.inner.reset();
580 }
581}
582
#[napi(object)]
/// JS-facing description of one curriculum stage. It is accepted by
/// `CurriculumScheduler::add_stage` and returned by `CurriculumScheduler::step`.
pub struct CurriculumStageConfig {
    /// Stage name.
    pub name: String,
    /// Difficulty value passed to the Rust stage (narrowed to `f32`).
    pub difficulty: f64,
    /// Stage length. Presumably measured in steps; confirm against the Rust `CurriculumStage`.
    pub duration: u32,
    /// Temperature associated with this stage (narrowed to `f32`).
    pub temperature: f64,
    /// Number of negatives associated with this stage.
    pub negative_count: u32,
}
596
597#[napi]
599pub struct CurriculumScheduler {
600 inner: RustCurriculum,
601}
602
603#[napi]
604impl CurriculumScheduler {
605 #[napi(constructor)]
607 pub fn new() -> Self {
608 Self {
609 inner: RustCurriculum::new(),
610 }
611 }
612
613 #[napi(factory)]
615 pub fn default_curriculum(total_steps: u32) -> Self {
616 Self {
617 inner: RustCurriculum::default_curriculum(total_steps as usize),
618 }
619 }
620
621 #[napi]
623 pub fn add_stage(&mut self, config: CurriculumStageConfig) {
624 let stage = RustStage::new(&config.name)
625 .difficulty(config.difficulty as f32)
626 .duration(config.duration as usize)
627 .temperature(config.temperature as f32)
628 .negative_count(config.negative_count as usize);
629
630 let new_inner = std::mem::take(&mut self.inner).add_stage(stage);
632 self.inner = new_inner;
633 }
634
635 #[napi]
637 pub fn step(&mut self) -> Option<CurriculumStageConfig> {
638 self.inner.step().map(|s| CurriculumStageConfig {
639 name: s.name.clone(),
640 difficulty: s.difficulty as f64,
641 duration: s.duration as u32,
642 temperature: s.temperature as f64,
643 negative_count: s.negative_count as u32,
644 })
645 }
646
647 #[napi(getter)]
649 pub fn difficulty(&self) -> f64 {
650 self.inner.difficulty() as f64
651 }
652
653 #[napi(getter)]
655 pub fn temperature(&self) -> f64 {
656 self.inner.temperature() as f64
657 }
658
659 #[napi(getter)]
661 pub fn negative_count(&self) -> u32 {
662 self.inner.negative_count() as u32
663 }
664
665 #[napi(getter)]
667 pub fn is_complete(&self) -> bool {
668 self.inner.is_complete()
669 }
670
671 #[napi(getter)]
673 pub fn progress(&self) -> f64 {
674 self.inner.progress() as f64
675 }
676
677 #[napi]
679 pub fn reset(&mut self) {
680 self.inner.reset();
681 }
682}
683
684#[napi(string_enum)]
690pub enum MiningStrategy {
691 Random,
692 HardNegative,
693 SemiHard,
694 DistanceWeighted,
695}
696
697impl From<MiningStrategy> for RustMiningStrategy {
698 fn from(ms: MiningStrategy) -> Self {
699 match ms {
700 MiningStrategy::Random => RustMiningStrategy::Random,
701 MiningStrategy::HardNegative => RustMiningStrategy::HardNegative,
702 MiningStrategy::SemiHard => RustMiningStrategy::SemiHard,
703 MiningStrategy::DistanceWeighted => RustMiningStrategy::DistanceWeighted,
704 }
705 }
706}
707
708#[napi]
710pub struct HardNegativeMiner {
711 inner: RustHardMiner,
712}
713
714#[napi]
715impl HardNegativeMiner {
716 #[napi(constructor)]
721 pub fn new(strategy: MiningStrategy) -> Self {
722 Self {
723 inner: RustHardMiner::new(strategy.into()),
724 }
725 }
726
727 #[napi(factory)]
729 pub fn with_margin(strategy: MiningStrategy, margin: f64) -> Self {
730 Self {
731 inner: RustHardMiner::new(strategy.into())
732 .with_margin(margin as f32),
733 }
734 }
735
736 #[napi(factory)]
738 pub fn with_temperature(strategy: MiningStrategy, temperature: f64) -> Self {
739 Self {
740 inner: RustHardMiner::new(strategy.into())
741 .with_temperature(temperature as f32),
742 }
743 }
744
745 #[napi]
756 pub fn mine(
757 &self,
758 anchor: Float32Array,
759 positive: Float32Array,
760 candidates: Vec<Float32Array>,
761 num_negatives: u32,
762 ) -> Vec<u32> {
763 let anchor_slice = anchor.as_ref();
764 let positive_slice = positive.as_ref();
765 let candidates_vec: Vec<Vec<f32>> = candidates.into_iter().map(|c| c.to_vec()).collect();
766 let candidates_refs: Vec<&[f32]> = candidates_vec.iter().map(|c| c.as_slice()).collect();
767
768 self.inner
769 .mine(anchor_slice, positive_slice, &candidates_refs, num_negatives as usize)
770 .into_iter()
771 .map(|i| i as u32)
772 .collect()
773 }
774}
775
776#[napi]
778pub struct InBatchMiner {
779 exclude_positive: bool,
780}
781
782#[napi]
783impl InBatchMiner {
784 #[napi(constructor)]
786 pub fn new() -> Self {
787 Self {
788 exclude_positive: true,
789 }
790 }
791
792 #[napi(factory)]
794 pub fn include_positive() -> Self {
795 Self {
796 exclude_positive: false,
797 }
798 }
799
800 #[napi]
810 pub fn get_negatives(&self, anchor_idx: u32, positive_idx: u32, batch_size: u32) -> Vec<u32> {
811 (0..batch_size)
812 .filter(|&i| {
813 i != anchor_idx && (!self.exclude_positive || i != positive_idx)
814 })
815 .collect()
816 }
817}