sklears_semi_supervised/
information_theory.rs

//! Information theory methods for semi-supervised learning
//!
//! This module provides information-theoretic approaches to semi-supervised learning,
//! including mutual information maximization, information bottleneck principle,
//! and entropy-based methods for feature selection and active learning.

use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::random::Random;
use sklears_core::error::{Result as SklResult, SklearsError};
use sklears_core::traits::{Estimator, Fit, Predict, Untrained};
use sklears_core::types::Float;
use std::collections::HashMap;

/// Mutual Information Maximization for semi-supervised learning
///
/// This method learns representations that maximize mutual information between
/// input features and output labels, using both labeled and unlabeled data
/// to improve the learned representations.
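///
/// Mutual information is estimated with a simple histogram approach over the
/// labeled samples only:
/// `MI(X, Y) = sum_{x,y} p(x, y) * ln(p(x, y) / (p(x) * p(y)))`.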
///
/// # Parameters
///
/// * `n_bins` - Number of bins for discretization in MI estimation
/// * `max_iter` - Maximum number of iterations for optimization
/// * `learning_rate` - Learning rate for gradient-based optimization
/// * `temperature` - Temperature parameter for soft discretization
/// * `regularization` - L2 regularization strength
/// * `random_state` - Random seed for reproducibility
///
/// # Examples
///
/// ```
/// use scirs2_core::array;
/// use sklears_semi_supervised::MutualInformationMaximization;
/// use sklears_core::traits::{Predict, Fit};
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![0, 1, -1, -1]; // -1 indicates unlabeled
///
/// let mim = MutualInformationMaximization::new()
///     .n_bins(10)
///     .max_iter(100)
///     .learning_rate(0.01);
/// let fitted = mim.fit(&X.view(), &y.view()).unwrap();
/// let predictions = fitted.predict(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct MutualInformationMaximization<S = Untrained> {
    state: S,
    n_bins: usize,
    max_iter: usize,
    learning_rate: f64,
    temperature: f64,
    regularization: f64,
    random_state: Option<u64>,
}

impl MutualInformationMaximization<Untrained> {
    /// Create a new MutualInformationMaximization instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_bins: 20,
            max_iter: 100,
            learning_rate: 0.01,
            temperature: 1.0,
            regularization: 0.01,
            random_state: None,
        }
    }

    /// Set the number of bins for discretization
    pub fn n_bins(mut self, n_bins: usize) -> Self {
        self.n_bins = n_bins;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the learning rate
    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Set the temperature parameter
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = temperature;
        self
    }

    /// Set the regularization strength
    pub fn regularization(mut self, regularization: f64) -> Self {
        self.regularization = regularization;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }
}

impl Default for MutualInformationMaximization<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for MutualInformationMaximization<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for MutualInformationMaximization<Untrained> {
    type Fitted = MutualInformationMaximization<MutualInformationTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let y = y.to_owned();
        let (_n_samples, n_features) = X.dim();

        // Identify labeled and unlabeled samples
        let mut labeled_indices = Vec::new();
        let mut unlabeled_indices = Vec::new();
        let mut classes = std::collections::HashSet::new();

        for (i, &label) in y.iter().enumerate() {
            if label == -1 {
                unlabeled_indices.push(i);
            } else {
                labeled_indices.push(i);
                classes.insert(label);
            }
        }

        if labeled_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No labeled samples provided".to_string(),
            ));
        }

        // Sort classes so behaviour is deterministic for a fixed random_state
        let mut classes: Vec<i32> = classes.into_iter().collect();
        classes.sort_unstable();

        // Initialize random number generator
        let mut rng = if let Some(seed) = self.random_state {
            Random::seed(seed)
        } else {
            Random::seed(
                std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap()
                    .as_secs(),
            )
        };

        // Initialize transformation matrix (feature weights)
        let mut transformation = Array2::<f64>::zeros((n_features, n_features));
        for i in 0..n_features {
            transformation[[i, i]] = 1.0; // Start with identity
            for j in 0..n_features {
                if i != j {
                    transformation[[i, j]] = rng.random_range(-0.1..0.1);
                }
            }
        }

        // Gradient-based optimization to maximize mutual information
        for _iter in 0..self.max_iter {
            // Transform features
            let X_transformed = X.dot(&transformation);

            // Estimate mutual information using histograms
            let mi =
                self.estimate_mutual_information(&X_transformed, &y, &labeled_indices, &classes)?;

            // Compute gradient (simplified finite differences)
            let mut gradient = Array2::<f64>::zeros((n_features, n_features));
            let epsilon = 1e-6;

            for i in 0..n_features {
                for j in 0..n_features {
                    // Forward difference
                    transformation[[i, j]] += epsilon;
                    let X_perturbed = X.dot(&transformation);
                    let mi_perturbed = self.estimate_mutual_information(
                        &X_perturbed,
                        &y,
                        &labeled_indices,
                        &classes,
                    )?;
                    gradient[[i, j]] = (mi_perturbed - mi) / epsilon;
                    transformation[[i, j]] -= epsilon; // Reset
                }
            }

            // Update transformation matrix
            for i in 0..n_features {
                for j in 0..n_features {
                    transformation[[i, j]] += self.learning_rate * gradient[[i, j]]
                        - self.regularization * transformation[[i, j]];
                }
            }
        }

        // Final transformation and label prediction for unlabeled samples
        let X_final = X.dot(&transformation);
        let mut final_labels = y.clone();

        // Use k-nearest neighbors on transformed space to predict unlabeled samples
        for &unlabeled_idx in &unlabeled_indices {
            let mut distances = Vec::new();
            for &labeled_idx in &labeled_indices {
                let dist = (&X_final.row(unlabeled_idx) - &X_final.row(labeled_idx))
                    .mapv(|x| x * x)
                    .sum()
                    .sqrt();
                distances.push((labeled_idx, dist));
            }

            // Sort by distance and take majority vote of k=3 nearest neighbors
            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
            let k = 3.min(labeled_indices.len());
            let mut class_votes = HashMap::new();

            for &(labeled_idx, _) in distances.iter().take(k) {
                *class_votes.entry(y[labeled_idx]).or_insert(0) += 1;
            }

            // Assign most voted class
            if let Some((&predicted_class, _)) = class_votes.iter().max_by_key(|&(_, count)| count)
            {
                final_labels[unlabeled_idx] = predicted_class;
            }
        }

        Ok(MutualInformationMaximization {
            state: MutualInformationTrained {
                X_train: X,
                y_train: final_labels,
                classes: Array1::from(classes),
                transformation,
                n_bins: self.n_bins,
            },
            n_bins: self.n_bins,
            max_iter: self.max_iter,
            learning_rate: self.learning_rate,
            temperature: self.temperature,
            regularization: self.regularization,
            random_state: self.random_state,
        })
    }
}

impl MutualInformationMaximization<Untrained> {
    /// Estimate mutual information using histogram-based method
    #[allow(non_snake_case)]
    fn estimate_mutual_information(
        &self,
        X: &Array2<f64>,
        y: &Array1<i32>,
        labeled_indices: &[usize],
        classes: &[i32],
    ) -> SklResult<f64> {
        if labeled_indices.is_empty() {
            return Ok(0.0);
        }

        // Discretize features into bins for labeled samples only
        let mut feature_bins = Vec::new();
        for j in 0..X.ncols() {
            let labeled_features: Vec<f64> = labeled_indices.iter().map(|&i| X[[i, j]]).collect();

            if labeled_features.is_empty() {
                continue;
            }

            let min_val = labeled_features
                .iter()
                .fold(f64::INFINITY, |a, &b| a.min(b));
            let max_val = labeled_features
                .iter()
                .fold(f64::NEG_INFINITY, |a, &b| a.max(b));

            if (max_val - min_val).abs() < 1e-10 {
                feature_bins.push(vec![0; labeled_indices.len()]); // All same bin
                continue;
            }

            let bin_width = (max_val - min_val) / self.n_bins as f64;
            let bins: Vec<usize> = labeled_features
                .iter()
                .map(|&val| {
                    ((val - min_val) / bin_width)
                        .floor()
                        .min((self.n_bins - 1) as f64) as usize
                })
                .collect();
            feature_bins.push(bins);
        }

        if feature_bins.is_empty() {
            return Ok(0.0);
        }

        // Compute joint and marginal distributions
        let mut joint_counts = HashMap::new();
        let mut feature_counts = HashMap::new();
        let mut label_counts = HashMap::new();

        for (sample_idx, &global_idx) in labeled_indices.iter().enumerate() {
            let label = y[global_idx];

            // Multi-dimensional feature bin (use first feature for simplicity)
            let feature_bin = if !feature_bins.is_empty() && sample_idx < feature_bins[0].len() {
                feature_bins[0][sample_idx]
            } else {
                0
            };

            *joint_counts.entry((feature_bin, label)).or_insert(0) += 1;
            *feature_counts.entry(feature_bin).or_insert(0) += 1;
            *label_counts.entry(label).or_insert(0) += 1;
        }

        let n_labeled = labeled_indices.len() as f64;
        let mut mi = 0.0;

        // Calculate mutual information: MI(X,Y) = sum_{x,y} p(x,y) * log(p(x,y) / (p(x) * p(y)))
        for (&(feature_bin, label), &joint_count) in &joint_counts {
            let p_xy = joint_count as f64 / n_labeled;
            let p_x = feature_counts[&feature_bin] as f64 / n_labeled;
            let p_y = label_counts[&label] as f64 / n_labeled;

            if p_xy > 0.0 && p_x > 0.0 && p_y > 0.0 {
                mi += p_xy * (p_xy / (p_x * p_y)).ln();
            }
        }

        Ok(mi)
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for MutualInformationMaximization<MutualInformationTrained>
{
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let mut predictions = Array1::zeros(n_test);

        // Transform test and training features once
        let X_transformed = X.dot(&self.state.transformation);
        let X_train_transformed = self.state.X_train.dot(&self.state.transformation);

        for i in 0..n_test {
            // Find most similar training sample using transformed features
            let mut min_dist = f64::INFINITY;
            let mut best_label = self.state.classes[0];

            for j in 0..self.state.X_train.nrows() {
                let diff = &X_transformed.row(i) - &X_train_transformed.row(j);
                let dist = diff.mapv(|x| x * x).sum().sqrt();

                if dist < min_dist {
                    min_dist = dist;
                    best_label = self.state.y_train[j];
                }
            }

            predictions[i] = best_label;
        }

        Ok(predictions)
    }
}

/// Information Bottleneck principle for semi-supervised learning
///
/// This method learns compressed representations that preserve information
/// about the target labels while discarding irrelevant information.
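///
/// The classical information bottleneck objective is to minimize
/// `I(X; T) - beta * I(T; Y)` over a compressed representation `T`; the
/// implementation here only approximates this with a simple linear projection.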
///
/// # Parameters
///
/// * `beta` - Trade-off parameter between compression and prediction
/// * `n_components` - Number of components in the compressed representation
/// * `max_iter` - Maximum number of iterations
/// * `tol` - Convergence tolerance
/// * `random_state` - Random seed for reproducibility
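///
/// # Examples
///
/// A minimal usage sketch mirroring the unit tests below; it assumes the type is
/// re-exported from the crate root like the other estimators in this module:
///
/// ```
/// use scirs2_core::array;
/// use sklears_semi_supervised::InformationBottleneck;
/// use sklears_core::traits::{Predict, Fit};
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![0, 1, -1, -1]; // -1 indicates unlabeled
///
/// let ib = InformationBottleneck::new()
///     .n_components(2)
///     .max_iter(10)
///     .random_state(42);
/// let fitted = ib.fit(&X.view(), &y.view()).unwrap();
/// let predictions = fitted.predict(&X.view()).unwrap();
/// ```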
#[derive(Debug, Clone)]
pub struct InformationBottleneck<S = Untrained> {
    state: S,
    beta: f64,
    n_components: usize,
    max_iter: usize,
    tol: f64,
    random_state: Option<u64>,
}

impl InformationBottleneck<Untrained> {
    /// Create a new InformationBottleneck instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            beta: 1.0,
            n_components: 10,
            max_iter: 100,
            tol: 1e-4,
            random_state: None,
        }
    }

    /// Set the beta parameter (compression vs prediction trade-off)
    pub fn beta(mut self, beta: f64) -> Self {
        self.beta = beta;
        self
    }

    /// Set the number of components
    pub fn n_components(mut self, n_components: usize) -> Self {
        self.n_components = n_components;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }
}

impl Default for InformationBottleneck<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for InformationBottleneck<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for InformationBottleneck<Untrained> {
    type Fitted = InformationBottleneck<InformationBottleneckTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let y = y.to_owned();
        let (n_samples, n_features) = X.dim();

        // Identify labeled samples
        let mut labeled_indices = Vec::new();
        let mut classes = std::collections::HashSet::new();

        for (i, &label) in y.iter().enumerate() {
            if label != -1 {
                labeled_indices.push(i);
                classes.insert(label);
            }
        }

        if labeled_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No labeled samples provided".to_string(),
            ));
        }

        // Sort classes so behaviour is deterministic for a fixed random_state
        let mut classes: Vec<i32> = classes.into_iter().collect();
        classes.sort_unstable();

        // Initialize random number generator
        let mut rng = if let Some(seed) = self.random_state {
            Random::seed(seed)
        } else {
            Random::seed(
                std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap()
                    .as_secs(),
            )
        };

        // Initialize projection matrix for dimensionality reduction
        let mut projection = Array2::<f64>::zeros((n_features, self.n_components));
        for i in 0..n_features {
            for j in 0..self.n_components {
                projection[[i, j]] = rng.random_range(-0.1..0.1);
            }
        }

        // Simple iterative optimization (simplified information bottleneck)
        for _iter in 0..self.max_iter {
            // Project features to lower dimension
            let X_projected = X.dot(&projection);

            // Compute reconstruction loss (simplified)
            let reconstruction_loss =
                self.compute_reconstruction_loss(&X, &X_projected, &projection)?;

            // Update projection to minimize reconstruction loss while preserving class information
            // This is a simplified implementation - full IB would require more sophisticated optimization
            for i in 0..n_features {
                for j in 0..self.n_components {
                    let gradient = reconstruction_loss / (n_samples as f64);
                    projection[[i, j]] -= 0.001 * gradient; // Simple gradient step
                }
            }
        }

        Ok(InformationBottleneck {
            state: InformationBottleneckTrained {
                X_train: X,
                y_train: y,
                classes: Array1::from(classes),
                projection,
            },
            beta: self.beta,
            n_components: self.n_components,
            max_iter: self.max_iter,
            tol: self.tol,
            random_state: self.random_state,
        })
    }
}

impl InformationBottleneck<Untrained> {
    #[allow(non_snake_case)]
    fn compute_reconstruction_loss(
        &self,
        X_original: &Array2<f64>,
        X_projected: &Array2<f64>,
        projection: &Array2<f64>,
    ) -> SklResult<f64> {
        // Simplified reconstruction loss: MSE between original and reconstructed
        let reconstruction = X_projected.dot(&projection.t());
        let diff = X_original - &reconstruction;
        let mse = diff.mapv(|x| x * x).mean().unwrap_or(0.0);
        Ok(mse)
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for InformationBottleneck<InformationBottleneckTrained>
{
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let mut predictions = Array1::zeros(n_test);

        // Project test data
        let X_test_projected = X.dot(&self.state.projection);
        let X_train_projected = self.state.X_train.dot(&self.state.projection);

        for i in 0..n_test {
            // Find nearest neighbor in projected space
            let mut min_dist = f64::INFINITY;
            let mut best_label = self.state.classes[0];

            for j in 0..self.state.X_train.nrows() {
                let diff = &X_test_projected.row(i) - &X_train_projected.row(j);
                let dist = diff.mapv(|x| x * x).sum().sqrt();

                if dist < min_dist {
                    min_dist = dist;
                    best_label = self.state.y_train[j];
                }
            }

            predictions[i] = best_label;
        }

        Ok(predictions)
    }
}

/// Trained state for MutualInformationMaximization
#[derive(Debug, Clone)]
#[allow(non_snake_case)]
pub struct MutualInformationTrained {
    /// Training feature matrix
    pub X_train: Array2<f64>,
    /// Training labels, with inferred labels for the previously unlabeled samples
    pub y_train: Array1<i32>,
    /// Unique class labels seen during fitting
    pub classes: Array1<i32>,
    /// Learned feature transformation matrix
    pub transformation: Array2<f64>,
    /// Number of bins used for mutual information estimation
    pub n_bins: usize,
}

/// Trained state for InformationBottleneck
#[derive(Debug, Clone)]
#[allow(non_snake_case)]
pub struct InformationBottleneckTrained {
    /// Training feature matrix
    pub X_train: Array2<f64>,
    /// Training labels
    pub y_train: Array1<i32>,
    /// Unique class labels seen during fitting
    pub classes: Array1<i32>,
    /// Learned projection matrix for dimensionality reduction
    pub projection: Array2<f64>,
}

/// Entropy-based Regularization for Semi-Supervised Learning
///
/// This method adds entropy regularization to encourage confident predictions
/// on unlabeled data while minimizing classification error on labeled data.
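///
/// Concretely, the objective penalizes the prediction entropy
/// `H(p) = -sum_k p_k * ln(p_k)` on unlabeled samples, so low-entropy
/// (confident) label distributions are preferred during label propagation.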
///
/// # Parameters
///
/// * `entropy_weight` - Weight for entropy regularization term
/// * `max_iter` - Maximum number of iterations
/// * `learning_rate` - Learning rate for gradient descent
/// * `n_neighbors` - Number of neighbors for graph construction
/// * `random_state` - Random seed for reproducibility
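///
/// # Examples
///
/// A minimal usage sketch mirroring the unit tests below; it assumes the type is
/// re-exported from the crate root like the other estimators in this module:
///
/// ```
/// use scirs2_core::array;
/// use sklears_semi_supervised::EntropyRegularizedSemiSupervised;
/// use sklears_core::traits::{Predict, Fit};
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![0, 1, -1, -1]; // -1 indicates unlabeled
///
/// let er = EntropyRegularizedSemiSupervised::new()
///     .entropy_weight(0.5)
///     .max_iter(10)
///     .random_state(42);
/// let fitted = er.fit(&X.view(), &y.view()).unwrap();
/// let predictions = fitted.predict(&X.view()).unwrap();
/// ```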
#[derive(Debug, Clone)]
pub struct EntropyRegularizedSemiSupervised<S = Untrained> {
    state: S,
    entropy_weight: f64,
    max_iter: usize,
    learning_rate: f64,
    n_neighbors: usize,
    random_state: Option<u64>,
}

impl EntropyRegularizedSemiSupervised<Untrained> {
    /// Create a new EntropyRegularizedSemiSupervised instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            entropy_weight: 0.5,
            max_iter: 100,
            learning_rate: 0.01,
            n_neighbors: 5,
            random_state: None,
        }
    }

    /// Set the entropy weight
    pub fn entropy_weight(mut self, weight: f64) -> Self {
        self.entropy_weight = weight;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the learning rate
    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Set the number of neighbors
    pub fn n_neighbors(mut self, n_neighbors: usize) -> Self {
        self.n_neighbors = n_neighbors;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }
}

impl Default for EntropyRegularizedSemiSupervised<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for EntropyRegularizedSemiSupervised<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>>
    for EntropyRegularizedSemiSupervised<Untrained>
{
    type Fitted = EntropyRegularizedSemiSupervised<EntropyRegularizedTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let y = y.to_owned();
        let (n_samples, _n_features) = X.dim();

        // Identify labeled and unlabeled samples
        let mut labeled_indices = Vec::new();
        let mut unlabeled_indices = Vec::new();
        let mut classes = std::collections::HashSet::new();

        for (i, &label) in y.iter().enumerate() {
            if label == -1 {
                unlabeled_indices.push(i);
            } else {
                labeled_indices.push(i);
                classes.insert(label);
            }
        }

        if labeled_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No labeled samples provided".to_string(),
            ));
        }

        // Sort classes so behaviour is deterministic
        let mut classes: Vec<i32> = classes.into_iter().collect();
        classes.sort_unstable();
        let n_classes = classes.len();

        // Initialize probability distributions
        let mut prob_distributions = Array2::<f64>::zeros((n_samples, n_classes));

        // Set labeled samples
        for &idx in &labeled_indices {
            if let Some(class_idx) = classes.iter().position(|&c| c == y[idx]) {
                prob_distributions[[idx, class_idx]] = 1.0;
            }
        }

        // Initialize unlabeled samples with uniform distribution
        for &idx in &unlabeled_indices {
            for class_idx in 0..n_classes {
                prob_distributions[[idx, class_idx]] = 1.0 / n_classes as f64;
            }
        }

        // Build k-NN graph
        let mut adjacency = Array2::<f64>::zeros((n_samples, n_samples));
        for i in 0..n_samples {
            let mut distances: Vec<(usize, f64)> = Vec::new();
            for j in 0..n_samples {
                if i != j {
                    let diff = &X.row(i) - &X.row(j);
                    let dist = diff.mapv(|x| x * x).sum().sqrt();
                    distances.push((j, dist));
                }
            }
            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

            for &(j, dist) in distances.iter().take(self.n_neighbors) {
                let weight = (-dist.powi(2) / 2.0).exp();
                adjacency[[i, j]] = weight;
                adjacency[[j, i]] = weight;
            }
        }

        // Normalize adjacency matrix
        for i in 0..n_samples {
            let row_sum: f64 = adjacency.row(i).sum();
            if row_sum > 0.0 {
                for j in 0..n_samples {
                    adjacency[[i, j]] /= row_sum;
                }
            }
        }

        // Optimize with entropy regularization
        for _iter in 0..self.max_iter {
            let prev_probs = prob_distributions.clone();

            // Update unlabeled samples
            for &idx in &unlabeled_indices {
                // Smooth labels from neighbors
                let mut smooth_dist = Array1::<f64>::zeros(n_classes);
                for j in 0..n_samples {
                    for k in 0..n_classes {
                        smooth_dist[k] += adjacency[[idx, j]] * prob_distributions[[j, k]];
                    }
                }

                // Compute entropy regularization gradient
                let mut entropy_grad = Array1::<f64>::zeros(n_classes);
                for k in 0..n_classes {
                    let p = prob_distributions[[idx, k]].max(1e-10);
                    entropy_grad[k] = -(p.ln() + 1.0);
                }

                // Update probabilities
                for k in 0..n_classes {
                    prob_distributions[[idx, k]] =
                        smooth_dist[k] - self.learning_rate * self.entropy_weight * entropy_grad[k];
                    prob_distributions[[idx, k]] = prob_distributions[[idx, k]].max(0.0);
                }

                // Normalize
                let row_sum: f64 = prob_distributions.row(idx).sum();
                if row_sum > 0.0 {
                    for k in 0..n_classes {
                        prob_distributions[[idx, k]] /= row_sum;
                    }
                }
            }

            // Check convergence
            let diff = (&prob_distributions - &prev_probs).mapv(|x| x.abs()).sum();
            if diff < 1e-6 {
                break;
            }
        }

        // Generate final labels
        let mut final_labels = y.clone();
        for &idx in &unlabeled_indices {
            let class_idx = prob_distributions
                .row(idx)
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .unwrap()
                .0;
            final_labels[idx] = classes[class_idx];
        }

        Ok(EntropyRegularizedSemiSupervised {
            state: EntropyRegularizedTrained {
                X_train: X,
                y_train: final_labels,
                classes: Array1::from(classes),
                prob_distributions,
                adjacency,
            },
            entropy_weight: self.entropy_weight,
            max_iter: self.max_iter,
            learning_rate: self.learning_rate,
            n_neighbors: self.n_neighbors,
            random_state: self.random_state,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for EntropyRegularizedSemiSupervised<EntropyRegularizedTrained>
{
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let mut predictions = Array1::zeros(n_test);

        for i in 0..n_test {
            let mut min_dist = f64::INFINITY;
            let mut best_label = self.state.classes[0];

            for j in 0..self.state.X_train.nrows() {
                let diff = &X.row(i) - &self.state.X_train.row(j);
                let dist = diff.mapv(|x| x * x).sum().sqrt();

                if dist < min_dist {
                    min_dist = dist;
                    best_label = self.state.y_train[j];
                }
            }

            predictions[i] = best_label;
        }

        Ok(predictions)
    }
}

/// KL-Divergence Optimization for Semi-Supervised Learning
///
/// This method optimizes a classifier by minimizing KL divergence between
/// predictions on differently augmented versions of the same unlabeled data.
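///
/// The consistency term is the divergence `KL(p || q) = sum_k p_k * ln(p_k / q_k)`
/// between the prediction `p` on an unlabeled sample and the prediction `q` on a
/// lightly noise-perturbed copy of that sample.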
///
/// # Parameters
///
/// * `temperature` - Temperature for softmax
/// * `max_iter` - Maximum number of iterations
/// * `learning_rate` - Learning rate for optimization
/// * `kl_weight` - Weight for KL divergence term
/// * `random_state` - Random seed for reproducibility
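///
/// # Examples
///
/// A minimal usage sketch mirroring the unit tests below; it assumes the type is
/// re-exported from the crate root like the other estimators in this module:
///
/// ```
/// use scirs2_core::array;
/// use sklears_semi_supervised::KLDivergenceOptimization;
/// use sklears_core::traits::{Predict, Fit};
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![0, 1, -1, -1]; // -1 indicates unlabeled
///
/// let kl = KLDivergenceOptimization::new()
///     .temperature(1.0)
///     .max_iter(10)
///     .random_state(42);
/// let fitted = kl.fit(&X.view(), &y.view()).unwrap();
/// let predictions = fitted.predict(&X.view()).unwrap();
/// ```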
#[derive(Debug, Clone)]
pub struct KLDivergenceOptimization<S = Untrained> {
    state: S,
    temperature: f64,
    max_iter: usize,
    learning_rate: f64,
    kl_weight: f64,
    random_state: Option<u64>,
}

impl KLDivergenceOptimization<Untrained> {
    /// Create a new KLDivergenceOptimization instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            temperature: 1.0,
            max_iter: 100,
            learning_rate: 0.01,
            kl_weight: 1.0,
            random_state: None,
        }
    }

    /// Set the temperature
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = temperature;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the learning rate
    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Set the KL weight
    pub fn kl_weight(mut self, weight: f64) -> Self {
        self.kl_weight = weight;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }
}

impl Default for KLDivergenceOptimization<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for KLDivergenceOptimization<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for KLDivergenceOptimization<Untrained> {
    type Fitted = KLDivergenceOptimization<KLDivergenceTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let y = y.to_owned();
        let (n_samples, n_features) = X.dim();

        // Identify labeled and unlabeled samples
        let mut labeled_indices = Vec::new();
        let mut unlabeled_indices = Vec::new();
        let mut classes = std::collections::HashSet::new();

        for (i, &label) in y.iter().enumerate() {
            if label == -1 {
                unlabeled_indices.push(i);
            } else {
                labeled_indices.push(i);
                classes.insert(label);
            }
        }

        if labeled_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No labeled samples provided".to_string(),
            ));
        }

        // Sort classes so behaviour is deterministic for a fixed random_state
        let mut classes: Vec<i32> = classes.into_iter().collect();
        classes.sort_unstable();
        let n_classes = classes.len();

        // Initialize RNG
        let mut rng = if let Some(seed) = self.random_state {
            Random::seed(seed)
        } else {
            Random::seed(
                std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap()
                    .as_secs(),
            )
        };

        // Initialize classifier weights
        let mut weights = Array2::<f64>::zeros((n_features, n_classes));
        for i in 0..n_features {
            for j in 0..n_classes {
                weights[[i, j]] = rng.random_range(-0.1..0.1);
            }
        }

        // Training loop with KL divergence minimization
        for _iter in 0..self.max_iter {
            // Compute predictions for all samples
            let logits = X.dot(&weights);
            let mut predictions = Array2::<f64>::zeros((n_samples, n_classes));

            // Apply softmax with temperature
            for i in 0..n_samples {
                let max_logit = logits
                    .row(i)
                    .iter()
                    .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                let mut exp_sum = 0.0;

                for j in 0..n_classes {
                    let exp_val = ((logits[[i, j]] - max_logit) / self.temperature).exp();
                    predictions[[i, j]] = exp_val;
                    exp_sum += exp_val;
                }

                if exp_sum > 0.0 {
                    for j in 0..n_classes {
                        predictions[[i, j]] /= exp_sum;
                    }
                }
            }

            // Compute gradient
            let mut gradient = Array2::<f64>::zeros((n_features, n_classes));

            // Supervised loss gradient (cross-entropy)
            for &idx in &labeled_indices {
                if let Some(class_idx) = classes.iter().position(|&c| c == y[idx]) {
                    for j in 0..n_features {
                        for k in 0..n_classes {
                            let target = if k == class_idx { 1.0 } else { 0.0 };
                            gradient[[j, k]] += X[[idx, j]] * (predictions[[idx, k]] - target);
                        }
                    }
                }
            }

            // KL divergence term for unlabeled samples (encourage confident predictions)
            for &idx in &unlabeled_indices {
                // Create augmented version (simplified: add small noise)
                let mut X_aug = X.row(idx).to_owned();
                for j in 0..n_features {
                    X_aug[j] += rng.random_range(-0.01..0.01);
                }

                // Compute prediction on augmented sample
                let logits_aug = X_aug.dot(&weights);
                let max_logit = logits_aug.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                let mut pred_aug = Array1::<f64>::zeros(n_classes);
                let mut exp_sum = 0.0;

                for j in 0..n_classes {
                    let exp_val = ((logits_aug[j] - max_logit) / self.temperature).exp();
                    pred_aug[j] = exp_val;
                    exp_sum += exp_val;
                }

                if exp_sum > 0.0 {
                    pred_aug /= exp_sum;
                }

                // KL divergence gradient
                for j in 0..n_features {
                    for k in 0..n_classes {
                        let p = predictions[[idx, k]].max(1e-10);
                        let q = pred_aug[k].max(1e-10);
                        let kl_grad = p * (p / q).ln();
                        gradient[[j, k]] += self.kl_weight * X[[idx, j]] * kl_grad;
                    }
                }
            }

            // Update weights
            let scale = self.learning_rate / n_samples as f64;
            for i in 0..n_features {
                for j in 0..n_classes {
                    weights[[i, j]] -= scale * gradient[[i, j]];
                }
            }
        }

        // Generate final predictions
        let logits = X.dot(&weights);
        let mut final_labels = y.clone();

        for &idx in &unlabeled_indices {
            let class_idx = logits
                .row(idx)
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .unwrap()
                .0;
            final_labels[idx] = classes[class_idx];
        }

        Ok(KLDivergenceOptimization {
            state: KLDivergenceTrained {
                X_train: X,
                y_train: final_labels,
                classes: Array1::from(classes),
                weights,
            },
            temperature: self.temperature,
            max_iter: self.max_iter,
            learning_rate: self.learning_rate,
            kl_weight: self.kl_weight,
            random_state: self.random_state,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>> for KLDivergenceOptimization<KLDivergenceTrained> {
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let mut predictions = Array1::zeros(n_test);

        let logits = X.dot(&self.state.weights);

        for i in 0..n_test {
            let class_idx = logits
                .row(i)
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .unwrap()
                .0;
            predictions[i] = self.state.classes[class_idx];
        }

        Ok(predictions)
    }
}

/// Trained state for EntropyRegularizedSemiSupervised
#[derive(Debug, Clone)]
#[allow(non_snake_case)]
pub struct EntropyRegularizedTrained {
    /// Training feature matrix
    pub X_train: Array2<f64>,
    /// Training labels, with inferred labels for the previously unlabeled samples
    pub y_train: Array1<i32>,
    /// Unique class labels seen during fitting
    pub classes: Array1<i32>,
    /// Per-sample class probability distributions after label propagation
    pub prob_distributions: Array2<f64>,
    /// Normalized k-nearest-neighbor adjacency matrix
    pub adjacency: Array2<f64>,
}

/// Trained state for KLDivergenceOptimization
#[derive(Debug, Clone)]
#[allow(non_snake_case)]
pub struct KLDivergenceTrained {
    /// Training feature matrix
    pub X_train: Array2<f64>,
    /// Training labels, with inferred labels for the previously unlabeled samples
    pub y_train: Array1<i32>,
    /// Unique class labels seen during fitting
    pub classes: Array1<i32>,
    /// Learned linear classifier weights
    pub weights: Array2<f64>,
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    #[test]
    #[allow(non_snake_case)]
    fn test_mutual_information_maximization() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1];

        let mim = MutualInformationMaximization::new()
            .n_bins(5)
            .max_iter(10)
            .random_state(42);

        let fitted = mim.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 4);
        assert!(predictions.iter().all(|&p| p >= 0 && p <= 1));

        // Labeled samples should be predicted correctly
        assert_eq!(predictions[0], 0);
        assert_eq!(predictions[1], 1);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_information_bottleneck() {
        let X = array![
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0]
        ];
        let y = array![0, 1, -1, -1];

        let ib = InformationBottleneck::new()
            .n_components(2)
            .max_iter(10)
            .random_state(42);

        let fitted = ib.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 4);
        // Predictions should be valid class labels (including -1 for potentially unlabeled predictions)
        assert!(predictions.iter().all(|&p| p == -1 || p == 0 || p == 1));
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_mutual_information_estimation() {
        let mim = MutualInformationMaximization::new().n_bins(5);
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![0, 1];
        let labeled_indices = vec![0, 1];
        let classes = vec![0, 1];

        let mi = mim
            .estimate_mutual_information(&X, &y, &labeled_indices, &classes)
            .unwrap();
        assert!(mi >= 0.0); // Mutual information should be non-negative
    }

    #[test]
    fn test_information_bottleneck_parameters() {
        let ib = InformationBottleneck::new()
            .beta(0.5)
            .n_components(5)
            .max_iter(50)
            .tol(1e-5);

        assert_eq!(ib.beta, 0.5);
        assert_eq!(ib.n_components, 5);
        assert_eq!(ib.max_iter, 50);
        assert_eq!(ib.tol, 1e-5);
    }

    #[test]
    fn test_mutual_information_maximization_parameters() {
        let mim = MutualInformationMaximization::new()
            .n_bins(15)
            .max_iter(200)
            .learning_rate(0.05)
            .temperature(2.0)
            .regularization(0.02);

        assert_eq!(mim.n_bins, 15);
        assert_eq!(mim.max_iter, 200);
        assert_eq!(mim.learning_rate, 0.05);
        assert_eq!(mim.temperature, 2.0);
        assert_eq!(mim.regularization, 0.02);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_empty_labeled_samples_error() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![-1, -1]; // No labeled samples

        let mim = MutualInformationMaximization::new();
        let result = mim.fit(&X.view(), &y.view());

        assert!(result.is_err());
        if let Err(SklearsError::InvalidInput(msg)) = result {
            assert_eq!(msg, "No labeled samples provided");
        } else {
            panic!("Expected InvalidInput error");
        }
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_single_class_stability() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 0, -1, -1]; // Only one class labeled

        let mim = MutualInformationMaximization::new()
            .max_iter(5)
            .random_state(42);

        let fitted = mim.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 4);
        // With only one labeled class, predictions should be stable
        assert!(predictions.iter().all(|&p| p == 0));
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_entropy_regularized_basic() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1];

        let er = EntropyRegularizedSemiSupervised::new()
            .entropy_weight(0.5)
            .max_iter(10)
            .random_state(42);

        let fitted = er.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 4);
        assert!(predictions.iter().all(|&p| p >= 0 && p <= 1));
    }

    #[test]
    fn test_entropy_regularized_parameters() {
        let er = EntropyRegularizedSemiSupervised::new()
            .entropy_weight(0.5)
            .max_iter(50)
            .learning_rate(0.001)
            .n_neighbors(10);

        assert_eq!(er.entropy_weight, 0.5);
        assert_eq!(er.max_iter, 50);
        assert_eq!(er.learning_rate, 0.001);
        assert_eq!(er.n_neighbors, 10);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_kl_divergence_optimization_basic() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1];

        let kl = KLDivergenceOptimization::new()
            .max_iter(10)
            .temperature(1.0)
            .random_state(42);

        let fitted = kl.fit(&X.view(), &y.view()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 4);
        assert!(predictions.iter().all(|&p| p >= 0 && p <= 1));
    }

    #[test]
    fn test_kl_divergence_parameters() {
        let kl = KLDivergenceOptimization::new()
            .temperature(2.0)
            .max_iter(200)
            .learning_rate(0.001)
            .kl_weight(0.5);

        assert_eq!(kl.temperature, 2.0);
        assert_eq!(kl.max_iter, 200);
        assert_eq!(kl.learning_rate, 0.001);
        assert_eq!(kl.kl_weight, 0.5);
    }
}