//! Information-theoretic semi-supervised learning estimators.
//!
//! The estimators in this module treat the label `-1` as "unlabeled" and use
//! information-theoretic criteria to propagate labels to unlabeled samples.

use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::random::Random;
use sklears_core::error::{Result as SklResult, SklearsError};
use sklears_core::traits::{Estimator, Fit, Predict, Untrained};
use sklears_core::types::Float;
use std::collections::HashMap;

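/// Semi-supervised classifier based on mutual information maximization.
///
/// `fit` learns a square linear transformation of the features by
/// finite-difference gradient ascent on a histogram estimate of the mutual
/// information between the transformed features and the observed labels,
/// then assigns pseudo-labels to the unlabeled samples (label `-1`) with a
/// small k-nearest-neighbour vote in the transformed space.
///
/// A minimal usage sketch (illustrative values; mirrors the tests at the
/// bottom of this file):
///
/// ```ignore
/// use scirs2_core::array;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![0, 1, -1, -1]; // -1 marks unlabeled samples
///
/// let fitted = MutualInformationMaximization::new()
///     .n_bins(5)
///     .max_iter(10)
///     .random_state(42)
///     .fit(&X.view(), &y.view())?;
/// let labels = fitted.predict(&X.view())?;
/// ```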
#[derive(Debug, Clone)]
pub struct MutualInformationMaximization<S = Untrained> {
    state: S,
    n_bins: usize,
    max_iter: usize,
    learning_rate: f64,
    temperature: f64,
    regularization: f64,
    random_state: Option<u64>,
}

impl MutualInformationMaximization<Untrained> {
    /// Create a new estimator with default hyperparameters.
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_bins: 20,
            max_iter: 100,
            learning_rate: 0.01,
            temperature: 1.0,
            regularization: 0.01,
            random_state: None,
        }
    }

    /// Set the number of histogram bins used for mutual information estimation.
    pub fn n_bins(mut self, n_bins: usize) -> Self {
        self.n_bins = n_bins;
        self
    }

    /// Set the maximum number of optimization iterations.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the learning rate for the gradient updates.
    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Set the temperature parameter.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = temperature;
        self
    }

    /// Set the L2 regularization strength applied to the transformation.
    pub fn regularization(mut self, regularization: f64) -> Self {
        self.regularization = regularization;
        self
    }

    /// Set the random seed for reproducible initialization.
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }
}

impl Default for MutualInformationMaximization<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for MutualInformationMaximization<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

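// Training procedure (see `fit` below): starting from a near-identity linear
// transformation, each iteration estimates the mutual information between the
// transformed features and the labels, approximates its gradient with respect
// to every transformation entry by finite differences, and takes a gradient
// ascent step with L2 shrinkage. Unlabeled samples are then pseudo-labeled by
// a 3-nearest-neighbour majority vote among the labeled samples in the
// transformed space.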
impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for MutualInformationMaximization<Untrained> {
    type Fitted = MutualInformationMaximization<MutualInformationTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let y = y.to_owned();
        let (n_samples, n_features) = X.dim();

        // Separate labeled and unlabeled samples (label -1 means "unlabeled").
        let mut labeled_indices = Vec::new();
        let mut unlabeled_indices = Vec::new();
        let mut classes = std::collections::HashSet::new();

        for (i, &label) in y.iter().enumerate() {
            if label == -1 {
                unlabeled_indices.push(i);
            } else {
                labeled_indices.push(i);
                classes.insert(label);
            }
        }

        if labeled_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No labeled samples provided".to_string(),
            ));
        }

        // Sort for a deterministic class ordering (HashSet iteration order is not stable).
        let mut classes: Vec<i32> = classes.into_iter().collect();
        classes.sort_unstable();
        let n_classes = classes.len();

        let mut rng = if let Some(seed) = self.random_state {
            Random::seed(seed)
        } else {
            Random::seed(
                std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap()
                    .as_secs(),
            )
        };

        // Initialize the transformation close to the identity with small random
        // off-diagonal entries.
        let mut transformation = Array2::<f64>::zeros((n_features, n_features));
        for i in 0..n_features {
            transformation[[i, i]] = 1.0;
            for j in 0..n_features {
                if i != j {
                    transformation[[i, j]] = rng.random_range(-0.1..0.1);
                }
            }
        }

        for _iter in 0..self.max_iter {
            let X_transformed = X.dot(&transformation);

            // Mutual information of the current transformation.
            let mi =
                self.estimate_mutual_information(&X_transformed, &y, &labeled_indices, &classes)?;

            // Finite-difference approximation of the MI gradient w.r.t. each entry.
            let mut gradient = Array2::<f64>::zeros((n_features, n_features));
            let epsilon = 1e-6;

            for i in 0..n_features {
                for j in 0..n_features {
                    transformation[[i, j]] += epsilon;
                    let X_perturbed = X.dot(&transformation);
                    let mi_perturbed = self.estimate_mutual_information(
                        &X_perturbed,
                        &y,
                        &labeled_indices,
                        &classes,
                    )?;
                    gradient[[i, j]] = (mi_perturbed - mi) / epsilon;
                    transformation[[i, j]] -= epsilon;
                }
            }

            // Gradient ascent on MI with L2 shrinkage.
            for i in 0..n_features {
                for j in 0..n_features {
                    transformation[[i, j]] += self.learning_rate * gradient[[i, j]]
                        - self.regularization * transformation[[i, j]];
                }
            }
        }

        // Pseudo-label the unlabeled samples with a k-nearest-neighbour vote in the
        // transformed space.
        let X_final = X.dot(&transformation);
        let mut final_labels = y.clone();

        for &unlabeled_idx in &unlabeled_indices {
            let mut distances = Vec::new();
            for &labeled_idx in &labeled_indices {
                let dist = (&X_final.row(unlabeled_idx) - &X_final.row(labeled_idx))
                    .mapv(|x| x * x)
                    .sum()
                    .sqrt();
                distances.push((labeled_idx, dist));
            }

            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

            let k = 3.min(labeled_indices.len());
            let mut class_votes = HashMap::new();

            for &(labeled_idx, _) in distances.iter().take(k) {
                *class_votes.entry(y[labeled_idx]).or_insert(0) += 1;
            }

            if let Some((&predicted_class, _)) = class_votes.iter().max_by_key(|&(_, count)| count)
            {
                final_labels[unlabeled_idx] = predicted_class;
            }
        }

        Ok(MutualInformationMaximization {
            state: MutualInformationTrained {
                X_train: X,
                y_train: final_labels,
                classes: Array1::from(classes),
                transformation,
                n_bins: self.n_bins,
            },
            n_bins: self.n_bins,
            max_iter: self.max_iter,
            learning_rate: self.learning_rate,
            temperature: self.temperature,
            regularization: self.regularization,
            random_state: self.random_state,
        })
    }
}

impl MutualInformationMaximization<Untrained> {
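    /// Histogram (plug-in) estimate of the mutual information between the
    /// transformed features and the class labels, computed over the labeled
    /// samples only:
    ///
    /// MI(X; Y) = sum over (x, y) of p(x, y) * ln( p(x, y) / (p(x) * p(y)) )
    ///
    /// Feature values are discretized into `n_bins` equal-width bins. Note
    /// that only the first feature's bins currently enter the joint counts.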
    fn estimate_mutual_information(
        &self,
        X: &Array2<f64>,
        y: &Array1<i32>,
        labeled_indices: &[usize],
        classes: &[i32],
    ) -> SklResult<f64> {
        if labeled_indices.is_empty() {
            return Ok(0.0);
        }

        // Discretize each feature into equal-width bins over the labeled samples.
        let mut feature_bins = Vec::new();
        for j in 0..X.ncols() {
            let labeled_features: Vec<f64> = labeled_indices.iter().map(|&i| X[[i, j]]).collect();

            if labeled_features.is_empty() {
                continue;
            }

            let min_val = labeled_features
                .iter()
                .fold(f64::INFINITY, |a, &b| a.min(b));
            let max_val = labeled_features
                .iter()
                .fold(f64::NEG_INFINITY, |a, &b| a.max(b));

            if (max_val - min_val).abs() < 1e-10 {
                // Constant feature: every sample falls into the same bin.
                feature_bins.push(vec![0; labeled_indices.len()]);
                continue;
            }

            let bin_width = (max_val - min_val) / self.n_bins as f64;
            let bins: Vec<usize> = labeled_features
                .iter()
                .map(|&val| {
                    ((val - min_val) / bin_width)
                        .floor()
                        .min((self.n_bins - 1) as f64) as usize
                })
                .collect();
            feature_bins.push(bins);
        }

        if feature_bins.is_empty() {
            return Ok(0.0);
        }

        // Joint and marginal counts over (first-feature bin, label) pairs.
        let mut joint_counts = HashMap::new();
        let mut feature_counts = HashMap::new();
        let mut label_counts = HashMap::new();

        for (sample_idx, &global_idx) in labeled_indices.iter().enumerate() {
            let label = y[global_idx];

            let feature_bin = if !feature_bins.is_empty() && sample_idx < feature_bins[0].len() {
                feature_bins[0][sample_idx]
            } else {
                0
            };

            *joint_counts.entry((feature_bin, label)).or_insert(0) += 1;
            *feature_counts.entry(feature_bin).or_insert(0) += 1;
            *label_counts.entry(label).or_insert(0) += 1;
        }

        let n_labeled = labeled_indices.len() as f64;
        let mut mi = 0.0;

        // MI = sum_{x,y} p(x,y) * ln(p(x,y) / (p(x) * p(y)))
        for (&(feature_bin, label), &joint_count) in &joint_counts {
            let p_xy = joint_count as f64 / n_labeled;
            let p_x = feature_counts[&feature_bin] as f64 / n_labeled;
            let p_y = label_counts[&label] as f64 / n_labeled;

            if p_xy > 0.0 && p_x > 0.0 && p_y > 0.0 {
                mi += p_xy * (p_xy / (p_x * p_y)).ln();
            }
        }

        Ok(mi)
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for MutualInformationMaximization<MutualInformationTrained>
{
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let mut predictions = Array1::zeros(n_test);

        let X_transformed = X.dot(&self.state.transformation);
        // Transform the training data once, rather than inside the inner loop.
        let X_train_transformed = self.state.X_train.dot(&self.state.transformation);

        for i in 0..n_test {
            // Nearest-neighbour prediction in the transformed space.
            let mut min_dist = f64::INFINITY;
            let mut best_label = self.state.classes[0];

            for j in 0..self.state.X_train.nrows() {
                let diff = &X_transformed.row(i) - &X_train_transformed.row(j);
                let dist = diff.mapv(|x| x * x).sum().sqrt();

                if dist < min_dist {
                    min_dist = dist;
                    best_label = self.state.y_train[j];
                }
            }

            predictions[i] = best_label;
        }

        Ok(predictions)
    }
}

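/// Semi-supervised classifier based on a simplified information bottleneck.
///
/// `fit` learns a linear projection onto `n_components` dimensions by
/// iteratively nudging a random initialization to reduce the reconstruction
/// error of the projected data; `predict` assigns each test sample the label
/// of its nearest training sample in the projected space.
///
/// A minimal usage sketch (illustrative values; mirrors the tests at the
/// bottom of this file):
///
/// ```ignore
/// use scirs2_core::array;
///
/// let X = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]];
/// let y = array![0, 1, -1]; // -1 marks an unlabeled sample
///
/// let fitted = InformationBottleneck::new()
///     .n_components(2)
///     .max_iter(10)
///     .random_state(42)
///     .fit(&X.view(), &y.view())?;
/// let labels = fitted.predict(&X.view())?;
/// ```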
#[derive(Debug, Clone)]
pub struct InformationBottleneck<S = Untrained> {
    state: S,
    beta: f64,
    n_components: usize,
    max_iter: usize,
    tol: f64,
    random_state: Option<u64>,
}
410
411impl InformationBottleneck<Untrained> {
412 pub fn new() -> Self {
414 Self {
415 state: Untrained,
416 beta: 1.0,
417 n_components: 10,
418 max_iter: 100,
419 tol: 1e-4,
420 random_state: None,
421 }
422 }
423
424 pub fn beta(mut self, beta: f64) -> Self {
426 self.beta = beta;
427 self
428 }
429
430 pub fn n_components(mut self, n_components: usize) -> Self {
432 self.n_components = n_components;
433 self
434 }
435
436 pub fn max_iter(mut self, max_iter: usize) -> Self {
438 self.max_iter = max_iter;
439 self
440 }
441
442 pub fn tol(mut self, tol: f64) -> Self {
444 self.tol = tol;
445 self
446 }
447
448 pub fn random_state(mut self, random_state: u64) -> Self {
450 self.random_state = Some(random_state);
451 self
452 }
453}
454
455impl Default for InformationBottleneck<Untrained> {
456 fn default() -> Self {
457 Self::new()
458 }
459}
460
461impl Estimator for InformationBottleneck<Untrained> {
462 type Config = ();
463 type Error = SklearsError;
464 type Float = Float;
465
466 fn config(&self) -> &Self::Config {
467 &()
468 }
469}
470
impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for InformationBottleneck<Untrained> {
    type Fitted = InformationBottleneck<InformationBottleneckTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let y = y.to_owned();
        let (n_samples, n_features) = X.dim();

        // Collect the labeled samples and the set of observed classes.
        let mut labeled_indices = Vec::new();
        let mut classes = std::collections::HashSet::new();

        for (i, &label) in y.iter().enumerate() {
            if label != -1 {
                labeled_indices.push(i);
                classes.insert(label);
            }
        }

        if labeled_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No labeled samples provided".to_string(),
            ));
        }

        // Sort for a deterministic class ordering.
        let mut classes: Vec<i32> = classes.into_iter().collect();
        classes.sort_unstable();

        let mut rng = if let Some(seed) = self.random_state {
            Random::seed(seed)
        } else {
            Random::seed(
                std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap()
                    .as_secs(),
            )
        };

        // Random initialization of the projection matrix.
        let mut projection = Array2::<f64>::zeros((n_features, self.n_components));
        for i in 0..n_features {
            for j in 0..self.n_components {
                projection[[i, j]] = rng.random_range(-0.1..0.1);
            }
        }

        for _iter in 0..self.max_iter {
            let X_projected = X.dot(&projection);

            let reconstruction_loss =
                self.compute_reconstruction_loss(&X, &X_projected, &projection)?;

            // Crude update: shrink every entry by a step proportional to the loss.
            for i in 0..n_features {
                for j in 0..self.n_components {
                    let gradient = reconstruction_loss / (n_samples as f64);
                    projection[[i, j]] -= 0.001 * gradient;
                }
            }
        }

        Ok(InformationBottleneck {
            state: InformationBottleneckTrained {
                X_train: X,
                y_train: y,
                classes: Array1::from(classes),
                projection,
            },
            beta: self.beta,
            n_components: self.n_components,
            max_iter: self.max_iter,
            tol: self.tol,
            random_state: self.random_state,
        })
    }
}

impl InformationBottleneck<Untrained> {
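    /// Mean squared reconstruction error of the projection: the projected data
    /// is mapped back to the original space with the transpose of the
    /// projection matrix and compared against the original features,
    /// i.e. mean((X - X P P^T)^2).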
    fn compute_reconstruction_loss(
        &self,
        X_original: &Array2<f64>,
        X_projected: &Array2<f64>,
        projection: &Array2<f64>,
    ) -> SklResult<f64> {
        let reconstruction = X_projected.dot(&projection.t());
        let diff = X_original - &reconstruction;
        let mse = diff.mapv(|x| x * x).mean().unwrap_or(0.0);
        Ok(mse)
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for InformationBottleneck<InformationBottleneckTrained>
{
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let mut predictions = Array1::zeros(n_test);

        let X_test_projected = X.dot(&self.state.projection);
        let X_train_projected = self.state.X_train.dot(&self.state.projection);

        for i in 0..n_test {
            // Nearest-neighbour prediction in the projected space.
            let mut min_dist = f64::INFINITY;
            let mut best_label = self.state.classes[0];

            for j in 0..self.state.X_train.nrows() {
                let diff = &X_test_projected.row(i) - &X_train_projected.row(j);
                let dist = diff.mapv(|x| x * x).sum().sqrt();

                if dist < min_dist {
                    min_dist = dist;
                    best_label = self.state.y_train[j];
                }
            }

            predictions[i] = best_label;
        }

        Ok(predictions)
    }
}

/// Trained state for [`MutualInformationMaximization`].
#[derive(Debug, Clone)]
pub struct MutualInformationTrained {
    /// Training data in the original feature space.
    pub X_train: Array2<f64>,
    /// Training labels after pseudo-labeling of the unlabeled samples.
    pub y_train: Array1<i32>,
    /// Observed class labels.
    pub classes: Array1<i32>,
    /// Learned linear transformation of the features.
    pub transformation: Array2<f64>,
    /// Number of histogram bins used for MI estimation.
    pub n_bins: usize,
}

/// Trained state for [`InformationBottleneck`].
#[derive(Debug, Clone)]
pub struct InformationBottleneckTrained {
    /// Training data in the original feature space.
    pub X_train: Array2<f64>,
    /// Training labels (unlabeled samples keep the `-1` marker).
    pub y_train: Array1<i32>,
    /// Observed class labels.
    pub classes: Array1<i32>,
    /// Learned projection matrix.
    pub projection: Array2<f64>,
}

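/// Semi-supervised classifier based on entropy-regularized label propagation.
///
/// `fit` builds a k-nearest-neighbour similarity graph over all samples,
/// initializes class-probability distributions (one-hot for labeled samples,
/// uniform for unlabeled ones), and iteratively smooths them over the graph
/// while pushing uncertain distributions toward lower entropy. Unlabeled
/// samples receive the argmax class of their final distribution.
///
/// A minimal usage sketch (illustrative values; mirrors the tests at the
/// bottom of this file):
///
/// ```ignore
/// use scirs2_core::array;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![0, 1, -1, -1];
///
/// let fitted = EntropyRegularizedSemiSupervised::new()
///     .entropy_weight(0.5)
///     .max_iter(10)
///     .random_state(42)
///     .fit(&X.view(), &y.view())?;
/// let labels = fitted.predict(&X.view())?;
/// ```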
#[derive(Debug, Clone)]
pub struct EntropyRegularizedSemiSupervised<S = Untrained> {
    state: S,
    entropy_weight: f64,
    max_iter: usize,
    learning_rate: f64,
    n_neighbors: usize,
    random_state: Option<u64>,
}
652
653impl EntropyRegularizedSemiSupervised<Untrained> {
654 pub fn new() -> Self {
656 Self {
657 state: Untrained,
658 entropy_weight: 0.5,
659 max_iter: 100,
660 learning_rate: 0.01,
661 n_neighbors: 5,
662 random_state: None,
663 }
664 }
665
666 pub fn entropy_weight(mut self, weight: f64) -> Self {
668 self.entropy_weight = weight;
669 self
670 }
671
672 pub fn max_iter(mut self, max_iter: usize) -> Self {
674 self.max_iter = max_iter;
675 self
676 }
677
678 pub fn learning_rate(mut self, learning_rate: f64) -> Self {
680 self.learning_rate = learning_rate;
681 self
682 }
683
684 pub fn n_neighbors(mut self, n_neighbors: usize) -> Self {
686 self.n_neighbors = n_neighbors;
687 self
688 }
689
690 pub fn random_state(mut self, random_state: u64) -> Self {
692 self.random_state = Some(random_state);
693 self
694 }
695}
696
697impl Default for EntropyRegularizedSemiSupervised<Untrained> {
698 fn default() -> Self {
699 Self::new()
700 }
701}
702
703impl Estimator for EntropyRegularizedSemiSupervised<Untrained> {
704 type Config = ();
705 type Error = SklearsError;
706 type Float = Float;
707
708 fn config(&self) -> &Self::Config {
709 &()
710 }
711}
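
// Training procedure (see `fit` below): build a k-nearest-neighbour adjacency
// matrix with Gaussian edge weights and row-normalize it; initialize the
// class-probability matrix one-hot for labeled samples and uniform for
// unlabeled ones; then repeatedly smooth the unlabeled rows over the graph,
// take an entropy-reducing gradient step, clamp to non-negative values and
// renormalize, until the change falls below a small threshold or `max_iter`
// is reached. Unlabeled samples finally receive the argmax class.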
712
713impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>>
714 for EntropyRegularizedSemiSupervised<Untrained>
715{
716 type Fitted = EntropyRegularizedSemiSupervised<EntropyRegularizedTrained>;
717
718 #[allow(non_snake_case)]
719 fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
720 let X = X.to_owned();
721 let y = y.to_owned();
722 let (n_samples, n_features) = X.dim();
723
724 let mut labeled_indices = Vec::new();
726 let mut unlabeled_indices = Vec::new();
727 let mut classes = std::collections::HashSet::new();
728
729 for (i, &label) in y.iter().enumerate() {
730 if label == -1 {
731 unlabeled_indices.push(i);
732 } else {
733 labeled_indices.push(i);
734 classes.insert(label);
735 }
736 }
737
738 if labeled_indices.is_empty() {
739 return Err(SklearsError::InvalidInput(
740 "No labeled samples provided".to_string(),
741 ));
742 }
743
744 let classes: Vec<i32> = classes.into_iter().collect();
745 let n_classes = classes.len();
746
747 let mut prob_distributions = Array2::<f64>::zeros((n_samples, n_classes));
749
750 for &idx in &labeled_indices {
752 if let Some(class_idx) = classes.iter().position(|&c| c == y[idx]) {
753 prob_distributions[[idx, class_idx]] = 1.0;
754 }
755 }
756
757 for &idx in &unlabeled_indices {
759 for class_idx in 0..n_classes {
760 prob_distributions[[idx, class_idx]] = 1.0 / n_classes as f64;
761 }
762 }
763
764 let mut adjacency = Array2::<f64>::zeros((n_samples, n_samples));
766 for i in 0..n_samples {
767 let mut distances: Vec<(usize, f64)> = Vec::new();
768 for j in 0..n_samples {
769 if i != j {
770 let diff = &X.row(i) - &X.row(j);
771 let dist = diff.mapv(|x| x * x).sum().sqrt();
772 distances.push((j, dist));
773 }
774 }
775 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
776
777 for &(j, dist) in distances.iter().take(self.n_neighbors) {
778 let weight = (-dist.powi(2) / 2.0).exp();
779 adjacency[[i, j]] = weight;
780 adjacency[[j, i]] = weight;
781 }
782 }
783
784 for i in 0..n_samples {
786 let row_sum: f64 = adjacency.row(i).sum();
787 if row_sum > 0.0 {
788 for j in 0..n_samples {
789 adjacency[[i, j]] /= row_sum;
790 }
791 }
792 }
793
794 for _iter in 0..self.max_iter {
796 let prev_probs = prob_distributions.clone();
797
798 for &idx in &unlabeled_indices {
800 let mut smooth_dist = Array1::<f64>::zeros(n_classes);
802 for j in 0..n_samples {
803 for k in 0..n_classes {
804 smooth_dist[k] += adjacency[[idx, j]] * prob_distributions[[j, k]];
805 }
806 }
807
808 let mut entropy_grad = Array1::<f64>::zeros(n_classes);
810 for k in 0..n_classes {
811 let p = prob_distributions[[idx, k]].max(1e-10);
812 entropy_grad[k] = -(p.ln() + 1.0);
813 }
814
815 for k in 0..n_classes {
817 prob_distributions[[idx, k]] =
818 smooth_dist[k] - self.learning_rate * self.entropy_weight * entropy_grad[k];
819 prob_distributions[[idx, k]] = prob_distributions[[idx, k]].max(0.0);
820 }
821
822 let row_sum: f64 = prob_distributions.row(idx).sum();
824 if row_sum > 0.0 {
825 for k in 0..n_classes {
826 prob_distributions[[idx, k]] /= row_sum;
827 }
828 }
829 }
830
831 let diff = (&prob_distributions - &prev_probs).mapv(|x| x.abs()).sum();
833 if diff < 1e-6 {
834 break;
835 }
836 }
837
838 let mut final_labels = y.clone();
840 for &idx in &unlabeled_indices {
841 let class_idx = prob_distributions
842 .row(idx)
843 .iter()
844 .enumerate()
845 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
846 .unwrap()
847 .0;
848 final_labels[idx] = classes[class_idx];
849 }
850
851 Ok(EntropyRegularizedSemiSupervised {
852 state: EntropyRegularizedTrained {
853 X_train: X,
854 y_train: final_labels,
855 classes: Array1::from(classes),
856 prob_distributions,
857 adjacency,
858 },
859 entropy_weight: self.entropy_weight,
860 max_iter: self.max_iter,
861 learning_rate: self.learning_rate,
862 n_neighbors: self.n_neighbors,
863 random_state: self.random_state,
864 })
865 }
866}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for EntropyRegularizedSemiSupervised<EntropyRegularizedTrained>
{
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let mut predictions = Array1::zeros(n_test);

        for i in 0..n_test {
            // Nearest-neighbour prediction in the original feature space.
            let mut min_dist = f64::INFINITY;
            let mut best_label = self.state.classes[0];

            for j in 0..self.state.X_train.nrows() {
                let diff = &X.row(i) - &self.state.X_train.row(j);
                let dist = diff.mapv(|x| x * x).sum().sqrt();

                if dist < min_dist {
                    min_dist = dist;
                    best_label = self.state.y_train[j];
                }
            }

            predictions[i] = best_label;
        }

        Ok(predictions)
    }
}

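/// Semi-supervised linear classifier trained with a KL-divergence consistency term.
///
/// `fit` learns a linear softmax classifier (with temperature) by gradient
/// descent: labeled samples contribute a cross-entropy-style gradient, while
/// each unlabeled sample contributes a KL-divergence term between its
/// prediction and the prediction for a slightly perturbed copy of itself.
/// Unlabeled samples are then pseudo-labeled with the trained classifier.
///
/// A minimal usage sketch (illustrative values; mirrors the tests at the
/// bottom of this file):
///
/// ```ignore
/// use scirs2_core::array;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![0, 1, -1, -1];
///
/// let fitted = KLDivergenceOptimization::new()
///     .max_iter(10)
///     .temperature(1.0)
///     .random_state(42)
///     .fit(&X.view(), &y.view())?;
/// let labels = fitted.predict(&X.view())?;
/// ```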
#[derive(Debug, Clone)]
pub struct KLDivergenceOptimization<S = Untrained> {
    state: S,
    temperature: f64,
    max_iter: usize,
    learning_rate: f64,
    kl_weight: f64,
    random_state: Option<u64>,
}
918
919impl KLDivergenceOptimization<Untrained> {
920 pub fn new() -> Self {
922 Self {
923 state: Untrained,
924 temperature: 1.0,
925 max_iter: 100,
926 learning_rate: 0.01,
927 kl_weight: 1.0,
928 random_state: None,
929 }
930 }
931
932 pub fn temperature(mut self, temperature: f64) -> Self {
934 self.temperature = temperature;
935 self
936 }
937
938 pub fn max_iter(mut self, max_iter: usize) -> Self {
940 self.max_iter = max_iter;
941 self
942 }
943
944 pub fn learning_rate(mut self, learning_rate: f64) -> Self {
946 self.learning_rate = learning_rate;
947 self
948 }
949
950 pub fn kl_weight(mut self, weight: f64) -> Self {
952 self.kl_weight = weight;
953 self
954 }
955
956 pub fn random_state(mut self, random_state: u64) -> Self {
958 self.random_state = Some(random_state);
959 self
960 }
961}
962
963impl Default for KLDivergenceOptimization<Untrained> {
964 fn default() -> Self {
965 Self::new()
966 }
967}
968
969impl Estimator for KLDivergenceOptimization<Untrained> {
970 type Config = ();
971 type Error = SklearsError;
972 type Float = Float;
973
974 fn config(&self) -> &Self::Config {
975 &()
976 }
977}
978
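// Training procedure (see `fit` below): the weights of a linear softmax
// classifier are initialized with small random values. Each iteration computes
// temperature-scaled softmax predictions for all samples, accumulates a
// cross-entropy-style gradient over the labeled samples, adds a KL-divergence
// consistency term for each unlabeled sample (comparing its prediction to the
// prediction for a randomly perturbed copy), and applies a scaled gradient
// step. Unlabeled samples are finally pseudo-labeled by the argmax logit.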
impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for KLDivergenceOptimization<Untrained> {
    type Fitted = KLDivergenceOptimization<KLDivergenceTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let y = y.to_owned();
        let (n_samples, n_features) = X.dim();

        // Separate labeled and unlabeled samples.
        let mut labeled_indices = Vec::new();
        let mut unlabeled_indices = Vec::new();
        let mut classes = std::collections::HashSet::new();

        for (i, &label) in y.iter().enumerate() {
            if label == -1 {
                unlabeled_indices.push(i);
            } else {
                labeled_indices.push(i);
                classes.insert(label);
            }
        }

        if labeled_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No labeled samples provided".to_string(),
            ));
        }

        // Sort for a deterministic class ordering.
        let mut classes: Vec<i32> = classes.into_iter().collect();
        classes.sort_unstable();
        let n_classes = classes.len();

        let mut rng = if let Some(seed) = self.random_state {
            Random::seed(seed)
        } else {
            Random::seed(
                std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap()
                    .as_secs(),
            )
        };

        // Random initialization of the classifier weights.
        let mut weights = Array2::<f64>::zeros((n_features, n_classes));
        for i in 0..n_features {
            for j in 0..n_classes {
                weights[[i, j]] = rng.random_range(-0.1..0.1);
            }
        }

        for _iter in 0..self.max_iter {
            let logits = X.dot(&weights);
            let mut predictions = Array2::<f64>::zeros((n_samples, n_classes));

            // Temperature-scaled softmax, stabilized by subtracting the max logit.
            for i in 0..n_samples {
                let max_logit = logits
                    .row(i)
                    .iter()
                    .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                let mut exp_sum = 0.0;

                for j in 0..n_classes {
                    let exp_val = ((logits[[i, j]] - max_logit) / self.temperature).exp();
                    predictions[[i, j]] = exp_val;
                    exp_sum += exp_val;
                }

                if exp_sum > 0.0 {
                    for j in 0..n_classes {
                        predictions[[i, j]] /= exp_sum;
                    }
                }
            }

            let mut gradient = Array2::<f64>::zeros((n_features, n_classes));

            // Supervised term: cross-entropy-style gradient over labeled samples.
            for &idx in &labeled_indices {
                if let Some(class_idx) = classes.iter().position(|&c| c == y[idx]) {
                    for j in 0..n_features {
                        for k in 0..n_classes {
                            let target = if k == class_idx { 1.0 } else { 0.0 };
                            gradient[[j, k]] += X[[idx, j]] * (predictions[[idx, k]] - target);
                        }
                    }
                }
            }

            // Consistency term: KL divergence between the prediction for each
            // unlabeled sample and the prediction for a slightly perturbed copy.
            for &idx in &unlabeled_indices {
                let mut X_aug = X.row(idx).to_owned();
                for j in 0..n_features {
                    X_aug[j] += rng.random_range(-0.01..0.01);
                }

                let logits_aug = X_aug.dot(&weights);
                let max_logit = logits_aug.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                let mut pred_aug = Array1::<f64>::zeros(n_classes);
                let mut exp_sum = 0.0;

                for j in 0..n_classes {
                    let exp_val = ((logits_aug[j] - max_logit) / self.temperature).exp();
                    pred_aug[j] = exp_val;
                    exp_sum += exp_val;
                }

                if exp_sum > 0.0 {
                    pred_aug /= exp_sum;
                }

                for j in 0..n_features {
                    for k in 0..n_classes {
                        let p = predictions[[idx, k]].max(1e-10);
                        let q = pred_aug[k].max(1e-10);
                        let kl_grad = p * (p / q).ln();
                        gradient[[j, k]] += self.kl_weight * X[[idx, j]] * kl_grad;
                    }
                }
            }

            // Gradient descent step, scaled by the number of samples.
            let scale = self.learning_rate / n_samples as f64;
            for i in 0..n_features {
                for j in 0..n_classes {
                    weights[[i, j]] -= scale * gradient[[i, j]];
                }
            }
        }

        // Pseudo-label the unlabeled samples with the trained classifier.
        let logits = X.dot(&weights);
        let mut final_labels = y.clone();

        for &idx in &unlabeled_indices {
            let class_idx = logits
                .row(idx)
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .unwrap()
                .0;
            final_labels[idx] = classes[class_idx];
        }

        Ok(KLDivergenceOptimization {
            state: KLDivergenceTrained {
                X_train: X,
                y_train: final_labels,
                classes: Array1::from(classes),
                weights,
            },
            temperature: self.temperature,
            max_iter: self.max_iter,
            learning_rate: self.learning_rate,
            kl_weight: self.kl_weight,
            random_state: self.random_state,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>> for KLDivergenceOptimization<KLDivergenceTrained> {
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let mut predictions = Array1::zeros(n_test);

        let logits = X.dot(&self.state.weights);

        for i in 0..n_test {
            // Argmax over the class logits.
            let class_idx = logits
                .row(i)
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .unwrap()
                .0;
            predictions[i] = self.state.classes[class_idx];
        }

        Ok(predictions)
    }
}

/// Trained state for [`EntropyRegularizedSemiSupervised`].
#[derive(Debug, Clone)]
pub struct EntropyRegularizedTrained {
    /// Training data.
    pub X_train: Array2<f64>,
    /// Training labels after pseudo-labeling of the unlabeled samples.
    pub y_train: Array1<i32>,
    /// Observed class labels.
    pub classes: Array1<i32>,
    /// Final class-probability distribution for every training sample.
    pub prob_distributions: Array2<f64>,
    /// Row-normalized k-nearest-neighbour adjacency matrix.
    pub adjacency: Array2<f64>,
}

/// Trained state for [`KLDivergenceOptimization`].
#[derive(Debug, Clone)]
pub struct KLDivergenceTrained {
    /// Training data.
    pub X_train: Array2<f64>,
    /// Training labels after pseudo-labeling of the unlabeled samples.
    pub y_train: Array1<i32>,
    /// Observed class labels.
    pub classes: Array1<i32>,
    /// Learned classifier weights (n_features x n_classes).
    pub weights: Array2<f64>,
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    #[test]
    #[allow(non_snake_case)]
    fn test_mutual_information_maximization() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1];

        let mim = MutualInformationMaximization::new()
            .n_bins(5)
            .max_iter(10)
            .random_state(42);

        let fitted = mim.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 4);
        assert!(predictions.iter().all(|&p| p >= 0 && p <= 1));

        // Labeled samples keep their original labels.
        assert_eq!(predictions[0], 0);
        assert_eq!(predictions[1], 1);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_information_bottleneck() {
        let X = array![
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0]
        ];
        let y = array![0, 1, -1, -1];

        let ib = InformationBottleneck::new()
            .n_components(2)
            .max_iter(10)
            .random_state(42);

        let fitted = ib.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 4);
        // The bottleneck keeps the original labels (including -1) for its training data.
        assert!(predictions.iter().all(|&p| p == -1 || p == 0 || p == 1));
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_mutual_information_estimation() {
        let mim = MutualInformationMaximization::new().n_bins(5);
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![0, 1];
        let labeled_indices = vec![0, 1];
        let classes = vec![0, 1];

        let mi = mim
            .estimate_mutual_information(&X, &y, &labeled_indices, &classes)
            .unwrap();
        // Mutual information is non-negative.
        assert!(mi >= 0.0);
    }

    #[test]
    fn test_information_bottleneck_parameters() {
        let ib = InformationBottleneck::new()
            .beta(0.5)
            .n_components(5)
            .max_iter(50)
            .tol(1e-5);

        assert_eq!(ib.beta, 0.5);
        assert_eq!(ib.n_components, 5);
        assert_eq!(ib.max_iter, 50);
        assert_eq!(ib.tol, 1e-5);
    }

    #[test]
    fn test_mutual_information_maximization_parameters() {
        let mim = MutualInformationMaximization::new()
            .n_bins(15)
            .max_iter(200)
            .learning_rate(0.05)
            .temperature(2.0)
            .regularization(0.02);

        assert_eq!(mim.n_bins, 15);
        assert_eq!(mim.max_iter, 200);
        assert_eq!(mim.learning_rate, 0.05);
        assert_eq!(mim.temperature, 2.0);
        assert_eq!(mim.regularization, 0.02);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_empty_labeled_samples_error() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        // Every sample is unlabeled, so fitting must fail.
        let y = array![-1, -1];

        let mim = MutualInformationMaximization::new();
        let result = mim.fit(&X.view(), &y.view());

        assert!(result.is_err());
        if let Err(SklearsError::InvalidInput(msg)) = result {
            assert_eq!(msg, "No labeled samples provided");
        } else {
            panic!("Expected InvalidInput error");
        }
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_single_class_stability() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        // Only one class is present among the labeled samples.
        let y = array![0, 0, -1, -1];

        let mim = MutualInformationMaximization::new()
            .max_iter(5)
            .random_state(42);

        let fitted = mim.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 4);
        // With a single observed class, every prediction must be that class.
        assert!(predictions.iter().all(|&p| p == 0));
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_entropy_regularized_basic() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1];

        let er = EntropyRegularizedSemiSupervised::new()
            .entropy_weight(0.5)
            .max_iter(10)
            .random_state(42);

        let fitted = er.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 4);
        assert!(predictions.iter().all(|&p| p >= 0 && p <= 1));
    }

    #[test]
    fn test_entropy_regularized_parameters() {
        let er = EntropyRegularizedSemiSupervised::new()
            .entropy_weight(0.5)
            .max_iter(50)
            .learning_rate(0.001)
            .n_neighbors(10);

        assert_eq!(er.entropy_weight, 0.5);
        assert_eq!(er.max_iter, 50);
        assert_eq!(er.learning_rate, 0.001);
        assert_eq!(er.n_neighbors, 10);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_kl_divergence_optimization_basic() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1];

        let kl = KLDivergenceOptimization::new()
            .max_iter(10)
            .temperature(1.0)
            .random_state(42);

        let fitted = kl.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 4);
        assert!(predictions.iter().all(|&p| p >= 0 && p <= 1));
    }

    #[test]
    fn test_kl_divergence_parameters() {
        let kl = KLDivergenceOptimization::new()
            .temperature(2.0)
            .max_iter(200)
            .learning_rate(0.001)
            .kl_weight(0.5);

        assert_eq!(kl.temperature, 2.0);
        assert_eq!(kl.max_iter, 200);
        assert_eq!(kl.learning_rate, 0.001);
        assert_eq!(kl.kl_weight, 0.5);
    }
}