1use scirs2_core::ndarray_ext::{s, Array1, Array2, ArrayView1, ArrayView2};
9use scirs2_core::random::Random;
10use sklears_core::{
11 error::{Result as SklResult, SklearsError},
12 traits::{Estimator, Fit, Predict, PredictProba, Untrained},
13 types::Float,
14};
15
16#[derive(Debug, Clone)]
18pub struct ProjectionNetwork {
19 pub weights: Vec<Array2<f64>>,
21 pub biases: Vec<Array1<f64>>,
23 pub architecture: Vec<usize>,
25 pub output_dim: usize,
27}
28
29impl ProjectionNetwork {
30 pub fn new(input_dim: usize, output_dim: usize, hidden_dims: Vec<usize>) -> Self {
32 let mut architecture = vec![input_dim];
33 architecture.extend(hidden_dims);
34 architecture.push(output_dim);
35
36 let mut weights = Vec::new();
37 let mut biases = Vec::new();
38
39 for i in 0..architecture.len() - 1 {
40 let input_size = architecture[i];
41 let output_size = architecture[i + 1];
42
43 let scale = (2.0 / (input_size + output_size) as f64).sqrt();
45 let mut rng = Random::default();
46 let mut w = Array2::<f64>::zeros((output_size, input_size));
47 for i in 0..output_size {
48 for j in 0..input_size {
49 let u1: f64 = rng.random_range(0.0..1.0);
51 let u2: f64 = rng.random_range(0.0..1.0);
52 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
53 w[(i, j)] = z * scale;
54 }
55 }
56 let b = Array1::zeros(output_size);
57
58 weights.push(w);
59 biases.push(b);
60 }
61
62 Self {
63 weights,
64 biases,
65 architecture,
66 output_dim,
67 }
68 }
69
70 pub fn forward(&self, x: &ArrayView1<f64>) -> SklResult<Array1<f64>> {
72 let mut current = x.to_owned();
73
74 for (i, (weights, biases)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
75 let linear = weights.dot(¤t) + biases;
76
77 current = if i < self.weights.len() - 1 {
79 linear.mapv(|x| x.max(0.0))
80 } else {
81 let norm = (linear.mapv(|x| x * x).sum() + 1e-12).sqrt();
83 linear / norm
84 };
85 }
86
87 Ok(current)
88 }
89}
90
91#[derive(Debug, Clone)]
93pub struct CrossModalContrastive<S = Untrained> {
94 state: S,
95 projection1: Option<ProjectionNetwork>,
97 projection2: Option<ProjectionNetwork>,
99 classifier_weights: Option<Array2<f64>>,
101 classifier_biases: Option<Array1<f64>>,
102 projection_dim: usize,
104 n_classes: usize,
106 hidden_dims: Vec<usize>,
108 temperature: f64,
110 learning_rate: f64,
112 max_iter: usize,
114 contrastive_weight: f64,
116 supervised_weight: f64,
118 random_state: Option<u64>,
120}
121
122impl Default for CrossModalContrastive<Untrained> {
123 fn default() -> Self {
124 Self::new()
125 }
126}
127
128impl CrossModalContrastive<Untrained> {
129 pub fn new() -> Self {
131 Self {
132 state: Untrained,
133 projection1: None,
134 projection2: None,
135 classifier_weights: None,
136 classifier_biases: None,
137 projection_dim: 128,
138 n_classes: 2,
139 hidden_dims: vec![256, 128],
140 temperature: 0.07,
141 learning_rate: 0.001,
142 max_iter: 100,
143 contrastive_weight: 1.0,
144 supervised_weight: 1.0,
145 random_state: None,
146 }
147 }
148
149 pub fn projection_dim(mut self, dim: usize) -> Self {
151 self.projection_dim = dim;
152 self
153 }
154
155 pub fn hidden_dims(mut self, dims: Vec<usize>) -> Self {
157 self.hidden_dims = dims;
158 self
159 }
160
161 pub fn temperature(mut self, temp: f64) -> Self {
163 self.temperature = temp;
164 self
165 }
166
167 pub fn learning_rate(mut self, lr: f64) -> Self {
169 self.learning_rate = lr;
170 self
171 }
172
173 pub fn max_iter(mut self, max_iter: usize) -> Self {
175 self.max_iter = max_iter;
176 self
177 }
178
179 pub fn contrastive_weight(mut self, weight: f64) -> Self {
181 self.contrastive_weight = weight;
182 self
183 }
184
185 pub fn supervised_weight(mut self, weight: f64) -> Self {
187 self.supervised_weight = weight;
188 self
189 }
190
191 pub fn random_state(mut self, seed: u64) -> Self {
193 self.random_state = Some(seed);
194 self
195 }
196
197 fn initialize_networks(&mut self, input_dim1: usize, input_dim2: usize, n_classes: usize) {
199 self.projection1 = Some(ProjectionNetwork::new(
200 input_dim1,
201 self.projection_dim,
202 self.hidden_dims.clone(),
203 ));
204
205 self.projection2 = Some(ProjectionNetwork::new(
206 input_dim2,
207 self.projection_dim,
208 self.hidden_dims.clone(),
209 ));
210
211 let combined_dim = self.projection_dim * 2;
213 let mut rng = Random::default();
215 let mut weights = Array2::<f64>::zeros((n_classes, combined_dim));
216 for i in 0..n_classes {
217 for j in 0..combined_dim {
218 let u1: f64 = rng.random_range(0.0..1.0);
220 let u2: f64 = rng.random_range(0.0..1.0);
221 let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
222 weights[(i, j)] = z * 0.1;
223 }
224 }
225 self.classifier_weights = Some(weights);
226 self.classifier_biases = Some(Array1::zeros(n_classes));
227
228 self.n_classes = n_classes;
229 }
230
231 fn contrastive_loss(&self, z1: &ArrayView2<f64>, z2: &ArrayView2<f64>) -> SklResult<f64> {
233 let batch_size = z1.nrows();
234 if batch_size != z2.nrows() {
235 return Err(SklearsError::InvalidInput(
236 "Batch sizes must match".to_string(),
237 ));
238 }
239
240 let mut total_loss = 0.0;
241
242 for i in 0..batch_size {
243 let z1_i = z1.row(i);
244 let z2_i = z2.row(i);
245
246 let pos_sim = z1_i.dot(&z2_i) / self.temperature;
248
249 let mut neg_sims = Vec::new();
251 for j in 0..batch_size {
252 if i != j {
253 let sim1 = z1_i.dot(&z1.row(j)) / self.temperature;
254 let sim2 = z1_i.dot(&z2.row(j)) / self.temperature;
255 neg_sims.push(sim1);
256 neg_sims.push(sim2);
257 }
258 }
259
260 let mut exp_sum = pos_sim.exp();
262 for &sim in &neg_sims {
263 exp_sum += sim.exp();
264 }
265
266 let loss = -pos_sim + (exp_sum + 1e-12).ln();
268 total_loss += loss;
269 }
270
271 Ok(total_loss / batch_size as f64)
272 }
273
274 fn project_features(
276 &self,
277 x1: &ArrayView2<f64>,
278 x2: &ArrayView2<f64>,
279 ) -> SklResult<(Array2<f64>, Array2<f64>)> {
280 let proj1 = self.projection1.as_ref().ok_or_else(|| {
281 SklearsError::InvalidInput("Projection network 1 not initialized".to_string())
282 })?;
283
284 let proj2 = self.projection2.as_ref().ok_or_else(|| {
285 SklearsError::InvalidInput("Projection network 2 not initialized".to_string())
286 })?;
287
288 let batch_size = x1.nrows();
289 let mut z1 = Array2::zeros((batch_size, self.projection_dim));
290 let mut z2 = Array2::zeros((batch_size, self.projection_dim));
291
292 for i in 0..batch_size {
293 let proj1_output = proj1.forward(&x1.row(i))?;
294 let proj2_output = proj2.forward(&x2.row(i))?;
295
296 z1.row_mut(i).assign(&proj1_output);
297 z2.row_mut(i).assign(&proj2_output);
298 }
299
300 Ok((z1, z2))
301 }
302
303 fn classify(&self, z1: &ArrayView1<f64>, z2: &ArrayView1<f64>) -> SklResult<Array1<f64>> {
305 match (&self.classifier_weights, &self.classifier_biases) {
306 (Some(weights), Some(biases)) => {
307 let mut combined = Array1::zeros(z1.len() + z2.len());
309 combined.slice_mut(s![..z1.len()]).assign(z1);
310 combined.slice_mut(s![z1.len()..]).assign(z2);
311
312 let logits = weights.dot(&combined) + biases;
313 Ok(self.softmax(&logits.view()))
314 }
315 _ => Err(SklearsError::InvalidInput(
316 "Classifier not initialized".to_string(),
317 )),
318 }
319 }
320
321 fn softmax(&self, x: &ArrayView1<f64>) -> Array1<f64> {
323 let max_val = x.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
324 let exp_x = x.mapv(|v| (v - max_val).exp());
325 let sum_exp = exp_x.sum();
326 exp_x / sum_exp
327 }
328
329 fn train(
331 &mut self,
332 x1: &ArrayView2<f64>,
333 x2: &ArrayView2<f64>,
334 y: &ArrayView1<i32>,
335 ) -> SklResult<()> {
336 let n_samples = x1.nrows();
337
338 if x1.nrows() != x2.nrows() || x1.nrows() != y.len() {
339 return Err(SklearsError::InvalidInput(
340 "All inputs must have the same number of samples".to_string(),
341 ));
342 }
343
344 self.initialize_networks(x1.ncols(), x2.ncols(), self.n_classes);
346
347 let mut labeled_indices = Vec::new();
349 let mut unlabeled_indices = Vec::new();
350
351 for (i, &label) in y.iter().enumerate() {
352 if label >= 0 {
353 labeled_indices.push(i);
354 } else {
355 unlabeled_indices.push(i);
356 }
357 }
358
359 for iteration in 0..self.max_iter {
361 let mut total_loss = 0.0;
362
363 let (z1, z2) = self.project_features(x1, x2)?;
365
366 let contrastive_loss = self.contrastive_loss(&z1.view(), &z2.view())?;
368 total_loss += self.contrastive_weight * contrastive_loss;
369
370 if !labeled_indices.is_empty() {
372 let mut supervised_loss = 0.0;
373 for &idx in &labeled_indices {
374 let probs = self.classify(&z1.row(idx), &z2.row(idx))?;
375 let label_idx = y[idx] as usize;
376 if label_idx < probs.len() {
377 supervised_loss -= (probs[label_idx] + 1e-15).ln();
378 }
379 }
380 supervised_loss /= labeled_indices.len() as f64;
381 total_loss += self.supervised_weight * supervised_loss;
382 }
383
384 if iteration % 10 == 0 {
386 println!("Iteration {}: Loss = {:.4}", iteration, total_loss);
387 }
388
389 if total_loss < 1e-6 {
391 break;
392 }
393 }
394
395 Ok(())
396 }
397}
398
399#[derive(Debug, Clone)]
401pub struct CrossModalContrastiveTrained {
402 pub projection1: ProjectionNetwork,
404 pub projection2: ProjectionNetwork,
406 pub classifier_weights: Array2<f64>,
408 pub classifier_biases: Array1<f64>,
410 pub classes: Array1<i32>,
412 pub projection_dim: usize,
414 pub n_classes: usize,
416 pub temperature: f64,
418}
419
420impl CrossModalContrastive<CrossModalContrastiveTrained> {
421 pub fn get_embeddings(
423 &self,
424 x1: &ArrayView2<f64>,
425 x2: &ArrayView2<f64>,
426 ) -> SklResult<(Array2<f64>, Array2<f64>)> {
427 let batch_size = x1.nrows();
428 let mut z1 = Array2::zeros((batch_size, self.state.projection_dim));
429 let mut z2 = Array2::zeros((batch_size, self.state.projection_dim));
430
431 for i in 0..batch_size {
432 let proj1_output = self.state.projection1.forward(&x1.row(i))?;
433 let proj2_output = self.state.projection2.forward(&x2.row(i))?;
434
435 z1.row_mut(i).assign(&proj1_output);
436 z2.row_mut(i).assign(&proj2_output);
437 }
438
439 Ok((z1, z2))
440 }
441
442 fn classify(&self, z1: &ArrayView1<f64>, z2: &ArrayView1<f64>) -> SklResult<Array1<f64>> {
444 let mut combined = Array1::zeros(z1.len() + z2.len());
446 combined.slice_mut(s![..z1.len()]).assign(z1);
447 combined.slice_mut(s![z1.len()..]).assign(z2);
448
449 let logits = self.state.classifier_weights.dot(&combined) + &self.state.classifier_biases;
450 Ok(self.softmax(&logits.view()))
451 }
452
453 fn softmax(&self, x: &ArrayView1<f64>) -> Array1<f64> {
455 let max_val = x.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
456 let exp_x = x.mapv(|v| (v - max_val).exp());
457 let sum_exp = exp_x.sum();
458 exp_x / sum_exp
459 }
460}
461
462impl Estimator for CrossModalContrastive<Untrained> {
463 type Config = ();
464 type Error = SklearsError;
465 type Float = Float;
466
467 fn config(&self) -> &Self::Config {
468 &()
469 }
470}
471
472pub type CrossModalInput = (Array2<f64>, Array2<f64>);
474
475impl Fit<CrossModalInput, ArrayView1<'_, i32>> for CrossModalContrastive<Untrained> {
476 type Fitted = CrossModalContrastive<CrossModalContrastiveTrained>;
477
478 fn fit(self, input: &CrossModalInput, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
479 let (x1, x2) = input;
480 let y = y.to_owned();
481
482 if x1.nrows() != x2.nrows() || x1.nrows() != y.len() {
483 return Err(SklearsError::InvalidInput(
484 "All inputs must have the same number of samples".to_string(),
485 ));
486 }
487
488 if x1.nrows() == 0 {
489 return Err(SklearsError::InvalidInput(
490 "No samples provided".to_string(),
491 ));
492 }
493
494 let labeled_count = y.iter().filter(|&&label| label >= 0).count();
496 if labeled_count == 0 {
497 return Err(SklearsError::InvalidInput(
498 "No labeled samples provided".to_string(),
499 ));
500 }
501
502 let mut unique_classes: Vec<i32> = y.iter().filter(|&&label| label >= 0).cloned().collect();
504 unique_classes.sort_unstable();
505 unique_classes.dedup();
506
507 let mut model = self.clone();
508 model.n_classes = unique_classes.len();
509
510 model.train(&x1.view(), &x2.view(), &y.view())?;
512
513 Ok(CrossModalContrastive {
514 state: CrossModalContrastiveTrained {
515 projection1: model.projection1.unwrap(),
516 projection2: model.projection2.unwrap(),
517 classifier_weights: model.classifier_weights.unwrap(),
518 classifier_biases: model.classifier_biases.unwrap(),
519 classes: Array1::from(unique_classes),
520 projection_dim: model.projection_dim,
521 n_classes: model.n_classes,
522 temperature: model.temperature,
523 },
524 projection1: None,
525 projection2: None,
526 classifier_weights: None,
527 classifier_biases: None,
528 projection_dim: 0,
529 n_classes: 0,
530 hidden_dims: Vec::new(),
531 temperature: 0.0,
532 learning_rate: 0.0,
533 max_iter: 0,
534 contrastive_weight: 0.0,
535 supervised_weight: 0.0,
536 random_state: None,
537 })
538 }
539}
540
541impl Predict<CrossModalInput, Array1<i32>> for CrossModalContrastive<CrossModalContrastiveTrained> {
542 fn predict(&self, input: &CrossModalInput) -> SklResult<Array1<i32>> {
543 let (x1, x2) = input;
544 let mut predictions = Array1::zeros(x1.nrows());
545
546 for i in 0..x1.nrows() {
547 let z1 = self.state.projection1.forward(&x1.row(i))?;
548 let z2 = self.state.projection2.forward(&x2.row(i))?;
549 let probs = self.classify(&z1.view(), &z2.view())?;
550
551 let max_idx = probs
552 .iter()
553 .enumerate()
554 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
555 .map(|(idx, _)| idx)
556 .unwrap_or(0);
557
558 predictions[i] = self.state.classes[max_idx];
559 }
560
561 Ok(predictions)
562 }
563}
564
565impl PredictProba<CrossModalInput, Array2<f64>>
566 for CrossModalContrastive<CrossModalContrastiveTrained>
567{
568 fn predict_proba(&self, input: &CrossModalInput) -> SklResult<Array2<f64>> {
569 let (x1, x2) = input;
570 let mut probabilities = Array2::zeros((x1.nrows(), self.state.n_classes));
571
572 for i in 0..x1.nrows() {
573 let z1 = self.state.projection1.forward(&x1.row(i))?;
574 let z2 = self.state.projection2.forward(&x2.row(i))?;
575 let probs = self.classify(&z1.view(), &z2.view())?;
576 probabilities.row_mut(i).assign(&probs);
577 }
578
579 Ok(probabilities)
580 }
581}
582
583#[allow(non_snake_case)]
584#[cfg(test)]
585mod tests {
586 use super::*;
587 use scirs2_core::array;
588
589 #[test]
590 fn test_projection_network_creation() {
591 let network = ProjectionNetwork::new(10, 5, vec![8, 6]);
592 assert_eq!(network.architecture, vec![10, 8, 6, 5]);
593 assert_eq!(network.output_dim, 5);
594 assert_eq!(network.weights.len(), 3);
595 assert_eq!(network.biases.len(), 3);
596 }
597
598 #[test]
599 #[ignore = "Flaky test due to random weight initialization - fails occasionally when Xavier init produces small values"]
600 fn test_projection_network_forward() {
601 let network = ProjectionNetwork::new(3, 2, vec![4]);
602 let x = array![1.0, 2.0, 3.0];
603
604 let result = network.forward(&x.view());
605 assert!(result.is_ok());
606
607 let output = result.unwrap();
608 assert_eq!(output.len(), 2);
609
610 let norm = (output.mapv(|x| x * x).sum() + 1e-12).sqrt();
612 assert!(
613 (norm - 1.0).abs() < 1e-5,
614 "Norm should be ~1.0, got {}",
615 norm
616 );
617 }
618
619 #[test]
620 fn test_cross_modal_contrastive_creation() {
621 let model = CrossModalContrastive::new()
622 .projection_dim(64)
623 .hidden_dims(vec![128, 64])
624 .temperature(0.1)
625 .learning_rate(0.01)
626 .max_iter(50);
627
628 assert_eq!(model.projection_dim, 64);
629 assert_eq!(model.hidden_dims, vec![128, 64]);
630 assert_eq!(model.temperature, 0.1);
631 assert_eq!(model.learning_rate, 0.01);
632 assert_eq!(model.max_iter, 50);
633 }
634
635 #[test]
636 fn test_cross_modal_contrastive_fit_predict() {
637 let x1 = array![
639 [1.0, 2.0, 3.0],
640 [2.0, 3.0, 4.0],
641 [3.0, 4.0, 5.0],
642 [4.0, 5.0, 6.0],
643 [5.0, 6.0, 7.0],
644 [6.0, 7.0, 8.0]
645 ];
646
647 let x2 = array![
649 [0.5, 1.5, 2.5, 3.5],
650 [1.5, 2.5, 3.5, 4.5],
651 [2.5, 3.5, 4.5, 5.5],
652 [3.5, 4.5, 5.5, 6.5],
653 [4.5, 5.5, 6.5, 7.5],
654 [5.5, 6.5, 7.5, 8.5]
655 ];
656
657 let y = array![0, 1, 0, 1, -1, -1]; let model = CrossModalContrastive::new()
660 .projection_dim(8)
661 .hidden_dims(vec![12])
662 .temperature(0.1)
663 .learning_rate(0.01)
664 .max_iter(5);
665
666 let input = (x1.clone(), x2.clone());
667 let result = model.fit(&input, &y.view());
668 assert!(result.is_ok());
669
670 let fitted = result.unwrap();
671 assert_eq!(fitted.state.classes.len(), 2);
672
673 let predictions = fitted.predict(&input);
674 assert!(predictions.is_ok());
675
676 let pred = predictions.unwrap();
677 assert_eq!(pred.len(), 6);
678
679 let probabilities = fitted.predict_proba(&input);
680 assert!(probabilities.is_ok());
681
682 let proba = probabilities.unwrap();
683 assert_eq!(proba.dim(), (6, 2));
684
685 for i in 0..6 {
687 let sum: f64 = proba.row(i).sum();
688 assert!((sum - 1.0).abs() < 1e-10);
689 }
690 }
691
692 #[test]
693 fn test_cross_modal_contrastive_insufficient_labeled_samples() {
694 let x1 = array![[1.0, 2.0], [2.0, 3.0]];
695 let x2 = array![[1.5, 2.5], [2.5, 3.5]];
696 let y = array![-1, -1]; let model = CrossModalContrastive::new();
699 let input = (x1, x2);
700 let result = model.fit(&input, &y.view());
701 assert!(result.is_err());
702 }
703
704 #[test]
705 fn test_cross_modal_contrastive_mismatched_dimensions() {
706 let x1 = array![[1.0, 2.0], [2.0, 3.0]];
707 let x2 = array![[1.5, 2.5]]; let y = array![0, 1];
709
710 let model = CrossModalContrastive::new();
711 let input = (x1, x2);
712 let result = model.fit(&input, &y.view());
713 assert!(result.is_err());
714 }
715
716 #[test]
717 fn test_cross_modal_get_embeddings() {
718 let x1 = array![
719 [1.0, 2.0, 3.0],
720 [2.0, 3.0, 4.0],
721 [3.0, 4.0, 5.0],
722 [4.0, 5.0, 6.0]
723 ];
724
725 let x2 = array![
726 [0.5, 1.5, 2.5, 3.5],
727 [1.5, 2.5, 3.5, 4.5],
728 [2.5, 3.5, 4.5, 5.5],
729 [3.5, 4.5, 5.5, 6.5]
730 ];
731
732 let y = array![0, 1, 0, -1]; let model = CrossModalContrastive::new().projection_dim(6).max_iter(3);
735
736 let input = (x1.clone(), x2.clone());
737 let fitted = model.fit(&input, &y.view()).unwrap();
738
739 let embeddings = fitted.get_embeddings(&x1.view(), &x2.view());
740 assert!(embeddings.is_ok());
741
742 let (z1, z2) = embeddings.unwrap();
743 assert_eq!(z1.dim(), (4, 6));
744 assert_eq!(z2.dim(), (4, 6));
745
746 for i in 0..4 {
748 let norm1 = (z1.row(i).mapv(|x| x * x).sum()).sqrt();
749 let norm2 = (z2.row(i).mapv(|x| x * x).sum()).sqrt();
750 assert!((norm1 - 1.0).abs() < 1e-10);
751 assert!((norm2 - 1.0).abs() < 1e-10);
752 }
753 }
754
755 #[test]
756 fn test_cross_modal_contrastive_with_different_parameters() {
757 let x1 = array![
758 [1.0, 2.0, 3.0, 4.0],
759 [2.0, 3.0, 4.0, 5.0],
760 [3.0, 4.0, 5.0, 6.0],
761 [4.0, 5.0, 6.0, 7.0]
762 ];
763
764 let x2 = array![[0.5, 1.5], [1.5, 2.5], [2.5, 3.5], [3.5, 4.5]];
765
766 let y = array![0, 1, 0, -1]; let model = CrossModalContrastive::new()
769 .projection_dim(10)
770 .hidden_dims(vec![16, 12])
771 .temperature(0.05)
772 .contrastive_weight(2.0)
773 .supervised_weight(0.5)
774 .max_iter(2);
775
776 let input = (x1, x2);
777 let result = model.fit(&input, &y.view());
778 assert!(result.is_ok());
779
780 let fitted = result.unwrap();
781 let predictions = fitted.predict(&input).unwrap();
782 assert_eq!(predictions.len(), 4);
783 }
784}